# Source code for ax.benchmark.problems.surrogate

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Iterable, List, Set

import numpy as np

import pandas as pd
import torch
from ax.benchmark.benchmark_problem import SingleObjectiveBenchmarkProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.models.torch.botorch_modular.surrogate import Surrogate

from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.result import Err, Ok
from botorch.utils.datasets import SupervisedDataset


class SurrogateBenchmarkProblem(SingleObjectiveBenchmarkProblem):
    """Single-objective benchmark problem whose objective comes from a surrogate.

    The ``from_surrogate`` constructor wires a ``SurrogateMetric`` objective to a
    ``SurrogateRunner`` that fits the given surrogate and queries it for
    predictions.
    """

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare two problems by name only.

        Returns:
            ``False`` if ``other`` is not a ``SurrogateBenchmarkProblem``,
            otherwise whether both problems have the same ``name``.
        """
        if not isinstance(other, SurrogateBenchmarkProblem):
            return False

        # Checking the whole datasets' equality here would be too expensive to be
        # worth it; just check names instead
        return self.name == other.name

    @classmethod
    def from_surrogate(
        cls,
        name: str,
        search_space: SearchSpace,
        surrogate: Surrogate,
        datasets: List[SupervisedDataset],
        minimize: bool,
        optimal_value: float,
        num_trials: int,
    ) -> "SurrogateBenchmarkProblem":
        """Build a problem whose objective is ``surrogate``'s prediction.

        Note that this constructs a ``SurrogateRunner``, which fits
        ``surrogate`` on ``datasets`` immediately.

        Args:
            name: Name of the problem; also used as the runner's name.
            search_space: Search space of the problem; also used by the runner
                when fitting the surrogate.
            surrogate: Model to fit on ``datasets`` and query for predictions.
            datasets: Training data passed to ``surrogate.fit``.
            minimize: Whether the objective is to be minimized.
            optimal_value: Forwarded unchanged to the problem constructor.
            num_trials: Forwarded unchanged to the problem constructor.

        Returns:
            An instance of ``cls``, so subclasses calling this constructor get
            an instance of the subclass.
        """
        return cls(
            name=name,
            search_space=search_space,
            optimization_config=OptimizationConfig(
                objective=Objective(
                    metric=SurrogateMetric(),
                    minimize=minimize,
                )
            ),
            runner=SurrogateRunner(
                name=name,
                surrogate=surrogate,
                datasets=datasets,
                search_space=search_space,
            ),
            optimal_value=optimal_value,
            num_trials=num_trials,
        )
class SurrogateMetric(Metric):
    """Metric named ``"prediction"`` that reads values from trial run metadata.

    It expects ``trial.run_metadata["prediction"]`` to map each arm name to a
    value, which is the format ``SurrogateRunner.run`` returns.
    """

    def __init__(self) -> None:
        super().__init__(name="prediction")

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Turn the stored per-arm predictions of ``trial`` into ``Data``.

        Args:
            trial: Trial whose ``run_metadata["prediction"]`` holds one value per
                arm name.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``Ok(Data)`` with one row per arm of ``trial``. No standard error is
            reported (``sem`` is NaN). Any exception raised while reading the
            metadata or building the frame is returned as
            ``Err(MetricFetchE)`` instead of being raised.
        """
        try:
            # Materialize the names once so the "arm_name" and "mean" columns are
            # built from the same ordered sequence.
            arm_names = list(trial.arms_by_name)
            stored = trial.run_metadata["prediction"]
            prediction = [stored[name] for name in arm_names]
            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": prediction,
                    "sem": np.nan,
                    "trial_index": trial.index,
                }
            )
            return Ok(value=Data(df=df))
        except Exception as e:
            return Err(
                MetricFetchE(
                    message=f"Failed to predict for trial {trial}", exception=e
                )
            )
class SurrogateRunner(Runner):
    """Runner that evaluates trials by querying a fitted surrogate model.

    The surrogate is fitted once, in ``__init__``. ``run`` returns a prediction
    per arm as run metadata, and ``poll_trial_status`` reports every trial as
    completed.
    """

    def __init__(
        self,
        name: str,
        surrogate: Surrogate,
        datasets: List[SupervisedDataset],
        search_space: SearchSpace,
    ) -> None:
        """Store the inputs and fit ``surrogate`` on ``datasets``.

        Args:
            name: Name of this runner.
            surrogate: Model to fit and later query in ``run``.
            datasets: Training data passed to ``surrogate.fit``.
            search_space: Search space whose parameter order defines the feature
                order both at fit time and in ``run``.
        """
        self.name = name
        self.surrogate = surrogate
        self.datasets = datasets
        self.search_space = search_space

        # Not written to anywhere in this class; kept as a public attribute.
        self.results: Dict[int, float] = {}
        # Records the indices of trials passed to ``run``
        # (``poll_trial_status`` does not consult it).
        self.statuses: Dict[int, TrialStatus] = {}

        surrogate.fit(
            datasets=datasets,
            metric_names=["objective"],
            search_space_digest=extract_search_space_digest(
                search_space=search_space,
                param_names=list(search_space.parameters.keys()),
            ),
        )

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Predict a value for every arm in ``trial``.

        Returns:
            Run metadata of the form ``{"prediction": {arm_name: value}}``.
            ``value`` is ``.item()`` of the first element returned by
            ``surrogate.predict`` (presumably the predictive mean — TODO
            confirm against ``Surrogate.predict``).
        """
        self.statuses[trial.index] = TrialStatus.COMPLETED
        # Build features in the same parameter order used when fitting (the
        # search space's order), not in the insertion order of
        # ``arm.parameters``, which may differ.
        param_names = list(self.search_space.parameters.keys())
        predictions: Dict[str, Any] = {}
        for arm in trial.arms:
            arm_params = arm.parameters
            features = torch.tensor([arm_params[p] for p in param_names]).reshape(
                [1, len(param_names)]
            )
            predictions[arm.name] = self.surrogate.predict(X=features)[0].item()
        return {"prediction": predictions}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report every given trial as ``COMPLETED``.

        ``run`` computes predictions synchronously, so there is nothing left to
        wait for.
        """
        return {TrialStatus.COMPLETED: {t.index for t in trials}}

    @classmethod
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        """Serialize the properties needed to initialize the runner.
        Used for storage.

        WARNING: Because of issues with consistently saving and loading BoTorch and
        GPyTorch modules the SurrogateRunner cannot be serialized at this time. At load
        time the runner will be replaced with a SyntheticRunner.
        """
        return {}

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Return no init args; ``serialize_init_args`` stores nothing to restore."""
        return {}