Source code for ax.benchmark.benchmark_test_functions.surrogate
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Callable, Mapping
from dataclasses import dataclass
import torch
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.observation import ObservationFeatures
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import none_throws
from torch import Tensor
[docs]
@dataclass(kw_only=True)
class SurrogateTestFunction(BenchmarkTestFunction):
"""
Data-generating function for surrogate benchmark problems.
Args:
name: The name of the runner.
outcome_names: Names of outcomes to return in `evaluate_true`, if the
surrogate produces more outcomes than are needed.
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
for generating observations. If `None`, `get_surrogate_and_datasets`
must not be None and will be used to generate the surrogate when it
is needed.
_datasets: Either `None`, or the `SupervisedDataset`s used to fit
the surrogate model. If `None`, `get_surrogate_and_datasets` must
not be None and will be used to generate the datasets when they are
needed.
get_surrogate_and_datasets: Function that returns the surrogate and
datasets, to allow for lazy construction. If
`get_surrogate_and_datasets` is not provided, `surrogate` and
`datasets` must be provided, and vice versa.
"""
name: str
outcome_names: list[str]
_surrogate: TorchModelBridge | None = None
_datasets: list[SupervisedDataset] | None = None
get_surrogate_and_datasets: (
None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
) = None
def __post_init__(self) -> None:
if self.get_surrogate_and_datasets is None and (
self._surrogate is None or self._datasets is None
):
raise ValueError(
"If `get_surrogate_and_datasets` is None, `_surrogate` "
"and `_datasets` must not be None, and vice versa."
)
[docs]
def set_surrogate_and_datasets(self) -> None:
self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
@property
def surrogate(self) -> TorchModelBridge:
if self._surrogate is None:
self.set_surrogate_and_datasets()
return none_throws(self._surrogate)
@property
def datasets(self) -> list[SupervisedDataset]:
if self._datasets is None:
self.set_surrogate_and_datasets()
return none_throws(self._datasets)
[docs]
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
# We're ignoring the uncertainty predictions of the surrogate model here and
# use the mean predictions as the outcomes (before potentially adding noise)
means, _ = self.surrogate.predict(
# pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
observation_features=[ObservationFeatures(params)]
)
means = [means[name][0] for name in self.outcome_names]
return torch.tensor(
means,
device=self.surrogate.device,
dtype=self.surrogate.dtype,
)
@equality_typechecker
def __eq__(self, other: Base) -> bool:
if type(other) is not type(self):
return False
# Don't check surrogate, datasets, or callable
return self.name == other.name