Source code for ax.models.torch.botorch_modular.model

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import dataclasses
import warnings
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Any

import numpy.typing as npt
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata, TGenMetadata
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.models.torch.botorch import (
    get_feature_importances_from_botorch_model,
    get_rounding_func,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
    check_outcome_dataset_match,
    choose_botorch_acqf_class,
    construct_acquisition_and_optimizer_options,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


class BoTorchModel(TorchModel, Base):
    """**All classes in 'botorch_modular' directory are under
    construction, incomplete, and should be treated as alpha
    versions only.**

    Modular ``Model`` class for combining BoTorch subcomponents
    in Ax. Specified via ``Surrogate`` and ``Acquisition``, which wrap
    BoTorch ``Model`` and ``AcquisitionFunction``, respectively, for
    convenient use in Ax.

    Args:
        acquisition_class: Type of ``Acquisition`` to be used in
            this model, auto-selected based on experiment and data
            if not specified.
        acquisition_options: Optional dict of kwargs, passed to
            the constructor of BoTorch ``AcquisitionFunction``.
        botorch_acqf_class: Type of ``AcquisitionFunction`` to be
            used in this model, auto-selected based on experiment
            and data if not specified.
        surrogate_spec: An optional ``SurrogateSpec`` object specifying how to
            construct the ``Surrogate`` and the underlying BoTorch ``Model``.
        surrogate_specs: DEPRECATED. Please use ``surrogate_spec`` instead.
        surrogate: In lieu of ``SurrogateSpec``, an instance of ``Surrogate`` may
            be provided. In most cases, ``surrogate_spec`` should be used instead.
        refit_on_cv: Whether to reoptimize model parameters during call to
            ``BoTorchModel.cross_validate``.
        warm_start_refit: Whether to load parameters from either the provided
            state dict or the state dict of the current BoTorch ``Model`` during
            refitting. If False, model parameters will be reoptimized from
            scratch on refit. NOTE: This setting is ignored during
            ``cross_validate`` if ``refit_on_cv`` is False.
    """

    acquisition_class: type[Acquisition]
    acquisition_options: dict[str, Any]
    surrogate_spec: SurrogateSpec | None
    _surrogate: Surrogate | None
    _botorch_acqf_class: type[AcquisitionFunction] | None
    _search_space_digest: SearchSpaceDigest | None = None
    _supports_robust_optimization: bool = True

    def __init__(
        self,
        surrogate_spec: SurrogateSpec | None = None,
        surrogate_specs: Mapping[str, SurrogateSpec] | None = None,
        surrogate: Surrogate | None = None,
        acquisition_class: type[Acquisition] | None = None,
        acquisition_options: dict[str, Any] | None = None,
        botorch_acqf_class: type[AcquisitionFunction] | None = None,
        refit_on_cv: bool = False,
        warm_start_refit: bool = True,
    ) -> None:
        # Check that only one surrogate related option is provided.
        if bool(surrogate_spec) + bool(surrogate_specs) + bool(surrogate) > 1:
            raise UserInputError(
                "Only one of `surrogate_spec`, `surrogate_specs`, and `surrogate` "
                "can be specified. Please use `surrogate_spec`."
            )
        if surrogate_specs is not None:
            if len(surrogate_specs) > 1:
                # Raised as an exception (not merely warned): multiple
                # surrogates cannot be honored at all.
                raise DeprecationWarning(
                    "Support for multiple `Surrogate`s has been deprecated. "
                    "Please use the `surrogate_spec` input in the future to "
                    "specify a single `Surrogate`."
                )
            warnings.warn(
                "The `surrogate_specs` argument is deprecated in favor of "
                "`surrogate_spec`, which accepts a single `SurrogateSpec` object. "
                "Please use `surrogate_spec` in the future.",
                DeprecationWarning,
                stacklevel=2,
            )
            # An empty mapping carries no spec. Fall back to `surrogate_spec`
            # instead of letting `next` raise a bare `StopIteration`; this
            # matches the truthiness-based check above, which already treats
            # an empty mapping as "not provided".
            surrogate_spec = next(iter(surrogate_specs.values()), surrogate_spec)

        self.surrogate_spec = surrogate_spec
        self._surrogate = surrogate
        self.acquisition_class = acquisition_class or Acquisition
        self.acquisition_options = acquisition_options or {}
        self._botorch_acqf_class = botorch_acqf_class
        self.refit_on_cv = refit_on_cv
        self.warm_start_refit = warm_start_refit

    @property
    def surrogate(self) -> Surrogate:
        """Returns the ``Surrogate``, if it has been constructed.

        Raises:
            ValueError: If no ``Surrogate`` has been provided or constructed yet.
        """
        if self._surrogate is None:
            raise ValueError("Surrogate has not yet been constructed.")
        return self._surrogate

    @property
    def Xs(self) -> list[Tensor]:
        """A list of tensors, each of shape ``batch_shape x n_i x d``,
        where `n_i` is the number of training inputs for the i-th model.

        NOTE: This is an accessor for ``self.surrogate.Xs``
        and returns it unchanged.
        """
        return self.surrogate.Xs

    @property
    def botorch_acqf_class(self) -> type[AcquisitionFunction]:
        """BoTorch ``AcquisitionFunction`` class, associated with this model.
        Raises an error if one is not yet set.
        """
        if not self._botorch_acqf_class:
            raise ValueError("BoTorch `AcquisitionFunction` has not yet been set.")
        return self._botorch_acqf_class

    def fit(
        self,
        datasets: Sequence[SupervisedDataset],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: list[list[TCandidateMetadata]] | None = None,
        state_dict: OrderedDict[str, Tensor] | None = None,
        refit: bool = True,
        **additional_model_inputs: Any,
    ) -> None:
        """Fit model to m outcomes.

        Args:
            datasets: A list of ``SupervisedDataset`` containers, each
                corresponding to the data of one or more outcomes.
            search_space_digest: A ``SearchSpaceDigest`` object containing
                metadata on the features in the datasets.
            candidate_metadata: Model-produced metadata for candidates, in
                the order corresponding to the Xs.
            state_dict: An optional model statedict for the underlying
                ``Surrogate``. Primarily used in ``BoTorchModel.cross_validate``.
            refit: Whether to re-optimize model parameters.
            additional_model_inputs: Additional kwargs to pass to the
                model input constructor in ``Surrogate.fit``.
        """
        # Flatten the per-dataset outcome names in a single pass
        # (`sum(..., [])` re-copies the accumulated list on every step).
        outcome_names = [name for ds in datasets for name in ds.outcome_names]
        # Checks for duplicate outcome names
        check_outcome_dataset_match(
            outcome_names=outcome_names, datasets=datasets, exact_match=True
        )

        # Store search space info for later use (e.g. during generation)
        self._search_space_digest = search_space_digest

        # If a surrogate has not been constructed, construct it.
        if self._surrogate is None:
            if self.surrogate_spec is not None:
                self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec)
            else:
                self._surrogate = Surrogate()

        # Fit the surrogate. The extra model inputs are merged into the
        # options of every model config (default and per-metric) first.
        for config in self.surrogate.surrogate_spec.model_configs:
            config.model_options.update(additional_model_inputs)
        for (
            config_list
        ) in self.surrogate.surrogate_spec.metric_to_model_configs.values():
            for config in config_list:
                config.model_options.update(additional_model_inputs)
        self.surrogate.fit(
            datasets=datasets,
            search_space_digest=search_space_digest,
            candidate_metadata=candidate_metadata,
            state_dict=state_dict,
            refit=refit,
        )

    def predict(
        self, X: Tensor, use_posterior_predictive: bool = False
    ) -> tuple[Tensor, Tensor]:
        """Predicts using the underlying ``Surrogate``.

        Args:
            X: (n x d) Tensor of input locations.
            use_posterior_predictive: A boolean indicating if the predictions
                should be from the posterior predictive (i.e. including
                observation noise).

        Returns:
            Tuple of tensors: (n x m) mean, (n x m x m) covariance.
        """
        return self.surrogate.predict(
            X=X, use_posterior_predictive=use_posterior_predictive
        )

    # NOTE: the docstring is supplied by `copy_doc`; do not add one here.
    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        acq_options, opt_options = construct_acquisition_and_optimizer_options(
            acqf_options=self.acquisition_options,
            model_gen_options=torch_opt_config.model_gen_options,
        )
        # update bounds / target values: start from the digest stored at
        # fit time and override only these two fields from the caller's digest.
        search_space_digest = dataclasses.replace(
            self.search_space_digest,
            bounds=search_space_digest.bounds,
            target_values=search_space_digest.target_values or {},
        )

        acqf = self._instantiate_acquisition(
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            acq_options=acq_options,
        )
        botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)
        candidates, expected_acquisition_value, weights = acqf.optimize(
            n=n,
            search_space_digest=search_space_digest,
            inequality_constraints=_to_inequality_constraints(
                linear_constraints=torch_opt_config.linear_constraints
            ),
            fixed_features=torch_opt_config.fixed_features,
            rounding_func=botorch_rounding_func,
            optimizer_options=checked_cast(dict, opt_options),
        )
        gen_metadata = self._get_gen_metadata_from_acqf(
            acqf=acqf,
            torch_opt_config=torch_opt_config,
            expected_acquisition_value=expected_acquisition_value,
        )
        return TorchGenResults(
            points=candidates.detach().cpu(),
            weights=weights,
            gen_metadata=gen_metadata,
        )

    def _get_gen_metadata_from_acqf(
        self,
        acqf: Acquisition,
        torch_opt_config: TorchOptConfig,
        expected_acquisition_value: Tensor,
    ) -> TGenMetadata:
        """Assemble the generation metadata reported alongside candidates.

        Args:
            acqf: The ``Acquisition`` that produced the candidates.
            torch_opt_config: The optimization config used for generation.
            expected_acquisition_value: Acquisition value(s) of the candidates,
                stored as a plain list.

        Returns:
            A metadata dict. It always contains the expected acquisition value;
            objective thresholds / weights are added when more than one
            objective weight is non-zero, and the fixed draw weights are added
            when the acquisition function has a ``FixedSingleSampleModel``
            outcome model.
        """
        gen_metadata: TGenMetadata = {
            Keys.EXPECTED_ACQF_VAL: expected_acquisition_value.tolist()
        }
        # More than one non-zero objective weight: also record the
        # thresholds and weights held by the acquisition.
        if torch_opt_config.objective_weights.nonzero().numel() > 1:
            gen_metadata["objective_thresholds"] = acqf.objective_thresholds
            gen_metadata["objective_weights"] = acqf.objective_weights

        if hasattr(acqf.acqf, "outcome_model"):
            outcome_model = acqf.acqf.outcome_model
            if isinstance(
                outcome_model,
                FixedSingleSampleModel,
            ):
                gen_metadata["outcome_model_fixed_draw_weights"] = outcome_model.w
        return gen_metadata

    # NOTE: the docstring is supplied by `copy_doc`; do not add one here.
    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> Tensor | None:
        try:
            return self.surrogate.best_in_sample_point(
                search_space_digest=search_space_digest,
                torch_opt_config=torch_opt_config,
            )[0]
        except ValueError:
            # Deliberate best-effort: a `ValueError` here is reported to the
            # caller as "no best point" rather than as a failure.
            return None

    # NOTE: the docstring is supplied by `copy_doc`; do not add one here.
    @copy_doc(TorchModel.evaluate_acquisition_function)
    def evaluate_acquisition_function(
        self,
        X: Tensor,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        acq_options: dict[str, Any] | None = None,
    ) -> Tensor:
        acqf = self._instantiate_acquisition(
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            acq_options=acq_options,
        )
        return acqf.evaluate(X=X)

    # NOTE: the docstring is supplied by `copy_doc`; do not add one here.
    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        datasets: Sequence[SupervisedDataset],
        X_test: Tensor,
        search_space_digest: SearchSpaceDigest,
        use_posterior_predictive: bool = False,
        **additional_model_inputs: Any,
    ) -> tuple[Tensor, Tensor]:
        current_surrogate = self.surrogate
        # `fit` below overwrites `_search_space_digest` with the CV digest
        # (which has `robust_digest` stripped). Remember the current one so
        # that it can be restored together with the surrogate.
        current_search_space_digest = self._search_space_digest
        # If we should be refitting but not warm-starting the refit, set
        # `state_dict` to None to avoid loading it.
        state_dict = (
            None
            if self.refit_on_cv and not self.warm_start_refit
            else current_surrogate.model.state_dict()
        )

        # Temporarily set `_surrogate` to cloned surrogate to set
        # the training data on cloned surrogate to train set and
        # use it to predict the test point.
        self._surrogate = current_surrogate.clone_reset()

        # Remove the `robust_digest` since we do not want to use perturbations here.
        search_space_digest = dataclasses.replace(
            search_space_digest,
            robust_digest=None,
        )

        try:
            self.fit(
                datasets=datasets,
                search_space_digest=search_space_digest,
                # pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type
                # but it is actually an OrderedDict[str, Tensor].
                state_dict=state_dict,
                refit=self.refit_on_cv,
                **additional_model_inputs,
            )
            X_test_prediction = self.predict(
                X=X_test,
                use_posterior_predictive=use_posterior_predictive,
            )
        finally:
            # Reset the surrogates back to this model's surrogate, make
            # sure the cloned surrogate doesn't stay around if fit or
            # predict fail. Likewise restore the digest stored by the
            # original fit, so the CV-only digest does not leak into `gen`.
            self._surrogate = current_surrogate
            self._search_space_digest = current_search_space_digest
        return X_test_prediction

    @property
    def dtype(self) -> torch.dtype:
        """Torch data type of the tensors in the training data used in the
        model, as reported by the underlying ``Surrogate``.
        """
        return self.surrogate.dtype

    @property
    def device(self) -> torch.device:
        """Torch device type of the tensors in the training data used in the
        model, as reported by the underlying ``Surrogate``.
        """
        return self.surrogate.device

    def _instantiate_acquisition(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        acq_options: dict[str, Any] | None = None,
    ) -> Acquisition:
        """Set a BoTorch acquisition function class for this model if needed and
        instantiate the ``Acquisition`` around it.

        Args:
            search_space_digest: Search space metadata passed to the
                ``Acquisition`` constructor.
            torch_opt_config: Optimization config; also used to auto-select
                the BoTorch acquisition function class when none is set.
            acq_options: Optional options passed to the ``Acquisition``
                constructor as ``options``.

        Returns:
            An instance of ``self.acquisition_class`` constructed with this
            model's surrogate and BoTorch acquisition function class.

        Raises:
            UnsupportedError: If no BoTorch acquisition function class is set
                and ``torch_opt_config`` specifies a risk measure.
        """
        if not self._botorch_acqf_class:
            if torch_opt_config.risk_measure is not None:
                raise UnsupportedError(
                    "Automated selection of `botorch_acqf_class` is not supported "
                    "for robust optimization with risk measures. Please specify "
                    "`botorch_acqf_class` as part of `model_kwargs`."
                )
            # The selection is cached on the model and reused by later calls.
            self._botorch_acqf_class = choose_botorch_acqf_class(
                torch_opt_config=torch_opt_config
            )

        return self.acquisition_class(
            surrogate=self.surrogate,
            botorch_acqf_class=self.botorch_acqf_class,
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            options=acq_options,
        )

    def feature_importances(self) -> npt.NDArray:
        """Compute feature importances from the model.

        This assumes that we can get model lengthscales from either
        ``covar_module.base_kernel.lengthscale`` or ``covar_module.lengthscale``.

        Returns:
            The feature importances as a numpy array of size len(metrics) x 1 x dim
            where each row sums to 1.
        """
        return get_feature_importances_from_botorch_model(model=self.surrogate.model)

    @property
    def search_space_digest(self) -> SearchSpaceDigest:
        """The ``SearchSpaceDigest`` stored by the most recent ``fit`` call.

        Raises:
            RuntimeError: If the model has not been fit yet.
        """
        if self._search_space_digest is None:
            raise RuntimeError(
                "`search_space_digest` is not initialized. Must `fit` the model first."
            )
        return self._search_space_digest

    @search_space_digest.setter
    def search_space_digest(self, value: SearchSpaceDigest) -> None:
        # The digest is owned by `fit`; direct assignment is always rejected.
        raise RuntimeError("Setting search_space_digest manually is disallowed.")