Source code for ax.benchmark.benchmark_test_functions.surrogate

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Callable, Mapping
from dataclasses import dataclass

import torch
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.observation import ObservationFeatures
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from pyre_extensions import none_throws
from torch import Tensor


@dataclass(kw_only=True)
class SurrogateTestFunction(BenchmarkTestFunction):
    """
    Data-generating function for surrogate benchmark problems.

    Args:
        name: The name of the test function.
        outcome_names: Names of outcomes to return in `evaluate_true`, if the
            surrogate produces more outcomes than are needed. The order of
            this list determines the order of the entries of the tensor
            returned by `evaluate_true`.
        _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
            for generating observations. If `None`, `get_surrogate` must not
            be None and will be used to generate the surrogate when it is
            needed.
        get_surrogate: Function that returns the surrogate, to allow for lazy
            construction. If `get_surrogate` is not provided, `_surrogate`
            must be provided and vice versa.
    """

    name: str
    outcome_names: list[str]
    _surrogate: TorchModelBridge | None = None
    get_surrogate: None | Callable[[], TorchModelBridge] = None

    def __post_init__(self) -> None:
        """Check that at least one way of obtaining the surrogate was given.

        Raises:
            ValueError: If both `get_surrogate` and `_surrogate` are None.
        """
        if self.get_surrogate is None and self._surrogate is None:
            raise ValueError(
                "If `get_surrogate` is None, `_surrogate` must not be None, and"
                " vice versa."
            )

    @property
    def surrogate(self) -> TorchModelBridge:
        """The surrogate model, constructed lazily on first access.

        If no surrogate was passed in, `get_surrogate` is called on first
        access and its result is stored in `_surrogate`, so later accesses
        reuse the same object.
        """
        if self._surrogate is None:
            self._surrogate = none_throws(self.get_surrogate)()
        return none_throws(self._surrogate)

    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
        """Evaluate the surrogate's mean prediction at `params`.

        Args:
            params: Parameterization at which to evaluate the surrogate.

        Returns:
            A tensor with one entry per name in `outcome_names` (in that
            order), placed on the surrogate's device with the surrogate's
            dtype.
        """
        # Resolve the (lazily constructed) surrogate once.
        surrogate = self.surrogate
        # `ObservationFeatures` expects a `dict` of parameters, whereas this
        # method accepts any `Mapping` (e.g. a read-only `MappingProxyType`).
        # Convert explicitly instead of suppressing the type mismatch.
        observation_features = [ObservationFeatures(dict(params))]
        # We're ignoring the uncertainty predictions of the surrogate model here
        # and use the mean predictions as the outcomes (before potentially
        # adding noise).
        means_by_outcome, _ = surrogate.predict(
            observation_features=observation_features
        )
        # Select only the requested outcomes, in `outcome_names` order; the
        # single observation is at index 0.
        outcome_means = [means_by_outcome[name][0] for name in self.outcome_names]
        return torch.tensor(
            outcome_means,
            device=surrogate.device,
            dtype=surrogate.dtype,
        )

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare two surrogate test functions by `name` only.

        The surrogate, datasets, and the `get_surrogate` callable are
        deliberately not compared.
        """
        if type(other) is not type(self):
            return False

        # Don't check surrogate, datasets, or callable.
        # NOTE(review): `outcome_names` is not compared either; confirm that
        # this is intended.
        return self.name == other.name