# Source code for ax.benchmark.runners.surrogate

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Dict, Iterable, List, Optional, Set, Union

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.observation import ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


class SurrogateRunner(BenchmarkRunner):
    def __init__(
        self,
        name: str,
        surrogate: TorchModelBridge,
        datasets: List[SupervisedDataset],
        search_space: SearchSpace,
        outcome_names: List[str],
        noise_stds: Union[float, Dict[str, float]] = 0.0,
    ) -> None:
        """Runner for surrogate benchmark problems.

        Args:
            name: The name of the runner.
            surrogate: The `TorchModelBridge` whose mean predictions are used as
                the ground-truth outcomes of the benchmark problem.
            datasets: The data sets used to fit the surrogate model.
            search_space: The search space of the problem (used for parameter
                transforms).
            outcome_names: The names of the outcomes of the surrogate. The order
                of this list determines the order of the entries in the tensor
                returned by `get_Y_true`.
            noise_stds: Noise standard deviations to add to the surrogate
                output(s). If a single float is provided, noise with that
                standard deviation is added to all outputs. Alternatively, a
                dictionary mapping outcome names to noise standard deviations
                can be provided to specify different noise levels for different
                outputs.
        """
        self.name = name
        self.surrogate = surrogate
        self._outcome_names = outcome_names
        self.datasets = datasets
        self.search_space = search_space
        self.noise_stds = noise_stds
        # Trial index -> status, written by `run` after a successful evaluation.
        # Note: `poll_trial_status` does not consult this mapping.
        self.statuses: Dict[int, TrialStatus] = {}

    @property
    def outcome_names(self) -> List[str]:
        """The outcome names passed at construction, in evaluation order."""
        return self._outcome_names

    def get_noise_stds(self) -> Union[None, float, Dict[str, float]]:
        """Return the noise standard deviation(s) given at construction."""
        return self.noise_stds

    def get_Y_true(self, arm: Arm) -> Tensor:
        """Return the surrogate's mean predictions for ``arm``.

        Args:
            arm: The arm whose parameters are evaluated on the surrogate.

        Returns:
            A 1-d tensor with one mean prediction per entry of `outcome_names`
            (in that order), on the surrogate's ``device`` and ``dtype``.
        """
        # We're ignoring the uncertainty predictions of the surrogate model here
        # and use the mean predictions as the outcomes (before potentially
        # adding noise).
        means, _ = self.surrogate.predict(
            observation_features=[ObservationFeatures(arm.parameters)]
        )
        # `predict` was given a single observation, so index 0 selects its mean.
        means = [means[name][0] for name in self.outcome_names]
        return torch.tensor(
            means,
            device=self.surrogate.device,
            dtype=self.surrogate.dtype,
        )

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Run the trial by evaluating its parameterization(s) on the surrogate
        model.

        Note: After the evaluation succeeds, this also records the status of the
        trial as COMPLETED in `self.statuses`. If the evaluation raises, no
        status is recorded.

        Args:
            trial: The trial to evaluate.

        Returns:
            A dictionary with the following keys:
                - outcome_names: The names of the metrics being evaluated.
                - Ys: A dict mapping arm names to lists of corresponding
                    outcomes, where the order of the outcomes is the same as in
                    `outcome_names`.
                - Ystds: A dict mapping arm names to lists of corresponding
                    outcome noise standard deviations (possibly nan if the noise
                    level is unobserved), where the order of the outcomes is the
                    same as in `outcome_names`.
                - Ys_true: A dict mapping arm names to lists of corresponding
                    ground truth outcomes, where the order of the outcomes is
                    the same as in `outcome_names`.
        """
        run_metadata = super().run(trial=trial)
        run_metadata["outcome_names"] = self.outcome_names
        # Only mark the trial completed once evaluation actually succeeded, so a
        # failed evaluation does not leave a stale COMPLETED entry behind.
        self.statuses[trial.index] = TrialStatus.COMPLETED
        return run_metadata

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report every given trial as COMPLETED.

        Args:
            trials: The trials to poll.

        Returns:
            A dict mapping `TrialStatus.COMPLETED` to the set of all given
            trial indices.
        """
        return {TrialStatus.COMPLETED: {t.index for t in trials}}

    @classmethod
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        """Serialize the properties needed to initialize the runner.
        Used for storage.

        WARNING: Because of issues with consistently saving and loading BoTorch
        and GPyTorch modules the SurrogateRunner cannot be serialized at this
        time. At load time the runner will be replaced with a SyntheticRunner.

        Returns:
            An empty dict; a warning is emitted instead of serializing state.
        """
        warnings.warn(
            "Because of issues with consistently saving and loading BoTorch and "
            f"GPyTorch modules, {cls.__name__} cannot be serialized at this time. "
            "At load time the runner will be replaced with a SyntheticRunner.",
            stacklevel=3,
        )
        return {}

    @classmethod
    def deserialize_init_args(
        cls,
        args: Dict[str, Any],
        decoder_registry: Optional[TDecoderRegistry] = None,
        class_decoder_registry: Optional[TClassDecoderRegistry] = None,
    ) -> Dict[str, Any]:
        """Return empty init args, mirroring `serialize_init_args`, which stores
        no state for this runner.
        """
        return {}