#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata, TGenMetadata
from ax.exceptions.core import UnsupportedError
from ax.models.torch.botorch import get_rounding_func
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.list_surrogate import ListSurrogate
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.utils import (
choose_botorch_acqf_class,
choose_model_class,
construct_acquisition_and_optimizer_options,
convert_to_block_design,
use_model_list,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import checked_cast, not_none
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor
class BoTorchModel(TorchModel, Base):
    """**All classes in 'botorch_modular' directory are under
    construction, incomplete, and should be treated as alpha
    versions only.**

    Modular `Model` class for combining BoTorch subcomponents
    in Ax. Specified via `Surrogate` and `Acquisition`, which wrap
    BoTorch `Model` and `AcquisitionFunction`, respectively, for
    convenient use in Ax.

    Args:
        acquisition_class: Type of `Acquisition` to be used in
            this model, auto-selected based on experiment and data
            if not specified.
        acquisition_options: Optional dict of kwargs, passed to
            the constructor of BoTorch `AcquisitionFunction`.
        botorch_acqf_class: Type of `AcquisitionFunction` to be
            used in this model, auto-selected based on experiment
            and data if not specified.
        surrogate: An instance of `Surrogate` to be used as part of
            this model; if not specified, type of `Surrogate` and
            underlying BoTorch `Model` will be auto-selected based
            on experiment and data, with kwargs in `surrogate_options`
            applied.
        surrogate_options: Optional dict of kwargs for `Surrogate`
            (used only if no pre-instantiated `Surrogate` is passed via
            `surrogate`). Can include:
            - model_options: Dict of options to surrogate's underlying
              BoTorch `Model`,
            - submodel_options or submodel_options_per_outcome:
              Options for submodels in `ListSurrogate`, see documentation
              for `ListSurrogate`.
        refit_on_update: Whether to reoptimize model parameters during call
            to `BoTorchModel.update`. If false, training data for the model
            (used for inference) is still swapped for new training data, but
            model parameters are not reoptimized.
        refit_on_cv: Whether to reoptimize model parameters during call to
            `BoTorchModel.cross_validate`.
        warm_start_refit: Whether to load parameters from either the provided
            state dict or the state dict of the current BoTorch `Model` during
            refitting. If False, model parameters will be reoptimized from
            scratch on refit. NOTE: This setting is ignored during `update` or
            `cross_validate` if the corresponding `refit_on_...` is False.

    Raises:
        ValueError: If both `surrogate` and a non-empty `surrogate_options`
            are passed.
    """

    acquisition_class: Type[Acquisition]
    acquisition_options: Dict[str, Any]
    surrogate_options: Dict[str, Any]
    _surrogate: Optional[Surrogate]
    _botorch_acqf_class: Optional[Type[AcquisitionFunction]]
    # Set by `fit` / `update`; `gen` refuses to run while this is still None.
    _search_space_digest: Optional[SearchSpaceDigest] = None

    def __init__(
        self,
        acquisition_class: Optional[Type[Acquisition]] = None,
        acquisition_options: Optional[Dict[str, Any]] = None,
        botorch_acqf_class: Optional[Type[AcquisitionFunction]] = None,
        surrogate: Optional[Surrogate] = None,
        surrogate_options: Optional[Dict[str, Any]] = None,
        refit_on_update: bool = True,
        refit_on_cv: bool = False,
        warm_start_refit: bool = True,
    ) -> None:
        # `surrogate_options` only configure an auto-selected surrogate, so
        # combining them with an explicit surrogate is ambiguous.
        if surrogate is not None and surrogate_options:
            raise ValueError(  # pragma: no cover
                "`surrogate_options` are only applied when using the default "
                "surrogate, so only one of `surrogate` and `surrogate_options`"
                " arguments is expected."
            )
        self._surrogate = surrogate
        self.surrogate_options = surrogate_options or {}
        self.acquisition_class = acquisition_class or Acquisition
        # `_botorch_acqf_class` can be `None` here. If so, `Model.gen` or `Model.
        # evaluate_acquisition_function` will set it with `choose_botorch_acqf_class`.
        self._botorch_acqf_class = botorch_acqf_class
        self.acquisition_options = acquisition_options or {}
        self.refit_on_update = refit_on_update
        self.refit_on_cv = refit_on_cv
        self.warm_start_refit = warm_start_refit

    @property
    def Xs(self) -> List[Tensor]:
        """A list of tensors, each of shape ``batch_shape x n_i x d``,
        where `n_i` is the number of training inputs for the i-th model.

        NOTE: This is an accessor for ``self.surrogate.Xs``
        and returns it unchanged.
        """
        return self.surrogate.Xs

    @property
    def surrogate(self) -> Surrogate:
        """Ax ``Surrogate`` object (wrapper for BoTorch ``Model``), associated with
        this model. Raises an error if one is not yet set.
        """
        if self._surrogate is None:
            raise ValueError("Surrogate has not yet been set.")
        return self._surrogate

    @property
    def botorch_acqf_class(self) -> Type[AcquisitionFunction]:
        """BoTorch ``AcquisitionFunction`` class, associated with this model.
        Raises an error if one is not yet set.
        """
        if self._botorch_acqf_class is None:
            raise ValueError("BoTorch `AcquisitionFunction` has not yet been set.")
        return self._botorch_acqf_class

    # NOTE: methods decorated with `copy_doc` must not define their own
    # docstring; their documentation is copied from `TorchModel`.
    @copy_doc(TorchModel.fit)
    def fit(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        state_dict: Optional[Dict[str, Tensor]] = None,
        refit: bool = True,
    ) -> None:
        # Each dataset is paired positionally with one metric name.
        if len(datasets) != len(metric_names):
            raise ValueError(
                "Length of datasets and metric_names must match, but your inputs "
                f"are of lengths {len(datasets)} and {len(metric_names)}, "
                "respectively."
            )
        # Store search space info for later use (e.g. during generation).
        self._search_space_digest = search_space_digest
        # Choose `Surrogate` and underlying `Model` based on properties of data.
        if self._surrogate is None:
            self._autoset_surrogate(
                datasets=datasets,
                metric_names=metric_names,
                search_space_digest=search_space_digest,
            )
        if len(datasets) > 1 and not isinstance(self.surrogate, ListSurrogate):
            # Note: If the datasets do not conform to a block design then this
            # will filter the data and drop observations to make sure that it does.
            # This can happen e.g. if only some metrics are observed at some points.
            datasets, metric_names = convert_to_block_design(
                datasets=datasets,
                metric_names=metric_names,
                force=True,
            )
        self.surrogate.fit(
            datasets=datasets,
            metric_names=metric_names,
            search_space_digest=search_space_digest,
            candidate_metadata=candidate_metadata,
            state_dict=state_dict,
            refit=refit,
        )

    @copy_doc(TorchModel.update)
    def update(
        self,
        datasets: List[Optional[SupervisedDataset]],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
    ) -> None:
        if self._surrogate is None:
            raise UnsupportedError("Cannot update model that has not been fitted.")
        # Validate inputs before touching any model state.
        if any(dataset is None for dataset in datasets):
            raise UnsupportedError(
                f"{self.__class__.__name__}.update requires data for all outcomes."
            )
        # Store search space info for later use (e.g. during generation).
        self._search_space_digest = search_space_digest
        # Sometimes the model fit should be restarted from scratch on update, for
        # models that are prone to overfitting. In those cases,
        # `self.warm_start_refit` should be false and `Surrogate.update` will not
        # receive a state dict and will not pass it to the underlying
        # `Surrogate.fit`.
        state_dict = (
            None
            if self.refit_on_update and not self.warm_start_refit
            else self.surrogate.model.state_dict()
        )
        self.surrogate.update(
            datasets=[not_none(dataset) for dataset in datasets],
            metric_names=metric_names,
            search_space_digest=search_space_digest,
            candidate_metadata=candidate_metadata,
            state_dict=state_dict,
            refit=self.refit_on_update,
        )

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        # Thin delegation to the surrogate.
        return self.surrogate.predict(X=X)

    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        fitted_digest = self._search_space_digest
        if fitted_digest is None:
            raise RuntimeError("Must `fit` the model before calling `gen`.")
        acq_options, opt_options = construct_acquisition_and_optimizer_options(
            acqf_options=self.acquisition_options,
            model_gen_options=torch_opt_config.model_gen_options,
        )
        # Start from the digest stored at fit/update time, but take the bounds
        # and target fidelities from the digest passed to this call.
        search_space_digest = dataclasses.replace(
            fitted_digest,
            bounds=search_space_digest.bounds,
            target_fidelities=search_space_digest.target_fidelities or {},
        )
        acqf = self._instantiate_acquisition(
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            acq_options=acq_options,
        )
        botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)
        candidates, expected_acquisition_value = acqf.optimize(
            n=n,
            search_space_digest=search_space_digest,
            inequality_constraints=_to_inequality_constraints(
                linear_constraints=torch_opt_config.linear_constraints
            ),
            fixed_features=torch_opt_config.fixed_features,
            rounding_func=botorch_rounding_func,
            optimizer_options=checked_cast(dict, opt_options),
        )
        gen_metadata = self._get_gen_metadata_from_acqf(
            acqf=acqf,
            torch_opt_config=torch_opt_config,
            expected_acquisition_value=expected_acquisition_value,
        )
        # Every generated candidate receives a weight of 1.
        return TorchGenResults(
            points=candidates.detach().cpu(),
            weights=torch.ones(n, dtype=self.surrogate.dtype),
            gen_metadata=gen_metadata,
        )

    def _get_gen_metadata_from_acqf(
        self,
        acqf: Acquisition,
        torch_opt_config: TorchOptConfig,
        expected_acquisition_value: Tensor,
    ) -> TGenMetadata:
        """Build the generation metadata returned alongside candidates.

        Args:
            acqf: The Ax `Acquisition` that produced the candidates.
            torch_opt_config: Optimization config used for generation.
            expected_acquisition_value: Acquisition value(s) reported by
                `acqf.optimize`; stored as a Python list.

        Returns:
            A dict that always contains `Keys.EXPECTED_ACQF_VAL`, plus
            objective thresholds/weights when more than one objective weight
            is nonzero, and the fixed draw weights of a
            `FixedSingleSampleModel` outcome model when the underlying
            BoTorch acquisition function has one.
        """
        gen_metadata: TGenMetadata = {
            Keys.EXPECTED_ACQF_VAL: expected_acquisition_value.tolist()
        }
        # More than one nonzero objective weight (assuming a 1-D weights
        # tensor) indicates a multi-objective setting.
        if torch_opt_config.objective_weights.nonzero().numel() > 1:
            gen_metadata["objective_thresholds"] = acqf.objective_thresholds
            gen_metadata["objective_weights"] = acqf.objective_weights
        if hasattr(acqf.acqf, "outcome_model"):
            outcome_model = acqf.acqf.outcome_model
            if isinstance(outcome_model, FixedSingleSampleModel):
                gen_metadata["outcome_model_fixed_draw_weights"] = outcome_model.w
        return gen_metadata

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> Optional[Tensor]:
        # A `ValueError` from the surrogate (or from an unset surrogate) is
        # treated as "no best point available".
        try:
            return self.surrogate.best_in_sample_point(
                search_space_digest=search_space_digest,
                torch_opt_config=torch_opt_config,
            )[0]
        except ValueError:
            return None

    @copy_doc(TorchModel.evaluate_acquisition_function)
    def evaluate_acquisition_function(
        self,
        X: Tensor,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        acq_options: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        acqf = self._instantiate_acquisition(
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            acq_options=acq_options,
        )
        return acqf.evaluate(X=X)

    def cross_validate(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        X_test: Tensor,
        search_space_digest: SearchSpaceDigest,
    ) -> Tuple[Tensor, Tensor]:
        """Fit a reset clone of the current surrogate on the training
        ``datasets`` and predict at ``X_test``.

        The model's own surrogate and stored search space digest are restored
        afterwards, even if fitting or prediction raises.

        Args:
            datasets: Training datasets for this fold.
            metric_names: Metric names, one per dataset.
            X_test: Points at which to predict.
            search_space_digest: Search space digest used for fitting.

        Returns:
            The result of `predict` at ``X_test`` from the clone fitted on
            the training data.
        """
        current_surrogate = self.surrogate
        # `fit` overwrites the stored digest; remember it so that
        # cross-validation does not leak into later `gen` calls.
        current_search_space_digest = self._search_space_digest
        # If we should be refitting but not warm-starting the refit, set
        # `state_dict` to None to avoid loading it.
        state_dict = (
            None
            if self.refit_on_cv and not self.warm_start_refit
            else deepcopy(current_surrogate.model.state_dict())
        )
        # Temporarily set `_surrogate` to a cloned surrogate so that the
        # training data is set on the clone, and use it to predict the
        # test points.
        self._surrogate = current_surrogate.clone_reset()
        try:
            self.fit(
                datasets=datasets,
                metric_names=metric_names,
                search_space_digest=search_space_digest,
                state_dict=state_dict,
                refit=self.refit_on_cv,
            )
            X_test_prediction = self.predict(X=X_test)
        finally:
            # Reset this model's surrogate and digest, making sure the cloned
            # surrogate doesn't stay around if fit or predict fail.
            self._surrogate = current_surrogate
            self._search_space_digest = current_search_space_digest
        return X_test_prediction

    def _autoset_surrogate(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
    ) -> None:
        """Sets a default surrogate on this model if one was not explicitly
        provided.

        Uses a `ListSurrogate` (with a submodel class chosen per outcome) when
        `use_model_list` says so for the model class chosen on all datasets,
        and a plain `Surrogate` otherwise. `self.surrogate_options` are passed
        to the surrogate constructor in both cases.
        """
        # To determine whether to use `ListSurrogate`, we need to check for
        # the batched multi-output case, so we first see which model would
        # be chosen given the Yvars and the properties of data.
        botorch_model_class = choose_model_class(
            datasets=datasets,
            search_space_digest=search_space_digest,
        )
        if use_model_list(datasets=datasets, botorch_model_class=botorch_model_class):
            # If using `ListSurrogate` / `ModelListGP`, pick submodels for each
            # outcome.
            botorch_submodel_class_per_outcome = {
                metric_name: choose_model_class(
                    datasets=[dataset],
                    search_space_digest=search_space_digest,
                )
                for dataset, metric_name in zip(datasets, metric_names)
            }
            self._surrogate = ListSurrogate(
                botorch_submodel_class_per_outcome=botorch_submodel_class_per_outcome,
                **self.surrogate_options,
            )
        else:
            # Using regular `Surrogate`, so botorch model picked at the beginning
            # of the function is the one we should use.
            self._surrogate = Surrogate(
                botorch_model_class=botorch_model_class, **self.surrogate_options
            )

    def _instantiate_acquisition(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        acq_options: Optional[Dict[str, Any]] = None,
    ) -> Acquisition:
        """Set a BoTorch acquisition function class for this model if needed and
        instantiate the Ax acquisition wrapper around it.

        If no BoTorch acquisition function class has been set yet, one is
        chosen via `choose_botorch_acqf_class` from `torch_opt_config` and
        stored on the model for subsequent calls.

        Args:
            search_space_digest: Search space digest for the acquisition.
            torch_opt_config: Optimization config (constraints, objective
                weights/thresholds, pending observations, etc.).
            acq_options: Options forwarded as `options` to the `Acquisition`
                constructor.

        Returns:
            An instance of ``self.acquisition_class`` (an Ax ``Acquisition``)
            built around ``self.botorch_acqf_class``.
        """
        if self._botorch_acqf_class is None:
            self._botorch_acqf_class = choose_botorch_acqf_class(
                pending_observations=torch_opt_config.pending_observations,
                outcome_constraints=torch_opt_config.outcome_constraints,
                linear_constraints=torch_opt_config.linear_constraints,
                fixed_features=torch_opt_config.fixed_features,
                objective_thresholds=torch_opt_config.objective_thresholds,
                objective_weights=torch_opt_config.objective_weights,
            )
        return self.acquisition_class(
            surrogate=self.surrogate,
            botorch_acqf_class=self.botorch_acqf_class,
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            options=acq_options,
        )