# Source code for ax.models.torch.botorch_modular.surrogate

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import dataclasses
from logging import Logger
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata, TConfig
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.utils import (
    _to_inequality_constraints,
    pick_best_out_of_sample_point_acqf_class,
    predict_from_model,
)
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, checked_cast_optional, not_none
from botorch.fit import fit_gpytorch_model
from botorch.models.model import Model
from botorch.utils.containers import TrainingData
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


# Error message raised when the surrogate's training data is requested
# before it has been set (i.e. before the surrogate was fit/constructed).
NOT_YET_FIT_MSG = (
    "Underlying BoTorch `Model` has not yet received its "
    "training_data. Please fit the model first."
)


# Module-level logger; `Surrogate` uses it for warnings (e.g. when rebuilding
# a manually constructed model) and debug messages during `fit`.
logger: Logger = get_logger(__name__)


class Surrogate(Base):
    """
    **All classes in 'botorch_modular' directory are under
    construction, incomplete, and should be treated as alpha
    versions only.**

    Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchModel``
    and is not meant to be used outside of it.

    Args:
        botorch_model_class: ``Model`` class to be used as the underlying
            BoTorch model.
        model_options: Dictionary of options / kwargs for the BoTorch
            ``Model`` constructed during ``Surrogate.fit``.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
        kernel_class: ``Kernel`` class, not yet used. Will be used to
            construct custom BoTorch ``Model`` in the future.
        likelihood: ``Likelihood`` class, not yet used. Will be used to
            construct custom BoTorch ``Model`` in the future.
        mll_options: Dictionary of options / kwargs passed to the
            ``mll_class`` constructor when fitting the model.
        kernel_options: Kernel kwargs, not yet used. Will be used to
            construct custom BoTorch ``Model`` in the future.

    Raises:
        NotImplementedError: If ``likelihood``, ``kernel_class`` or
            non-empty ``kernel_options`` are passed, since customizing them
            is not yet supported.
    """

    botorch_model_class: Type[Model]
    mll_class: Type[MarginalLogLikelihood]
    model_options: Dict[str, Any]
    mll_options: Dict[str, Any]
    kernel_class: Optional[Type[Kernel]] = None
    _training_data: Optional[TrainingData] = None
    _model: Optional[Model] = None
    # Special setting for surrogates instantiated via `Surrogate.from_botorch`,
    # to avoid re-constructing the underlying BoTorch model on `Surrogate.fit`
    # when set to `True`.
    _constructed_manually: bool = False

    def __init__(
        self,
        # TODO: make optional when BoTorch model factory is checked in.
        # Construction will then be possible from likelihood, kernel, etc.
        botorch_model_class: Type[Model],
        model_options: Optional[Dict[str, Any]] = None,
        mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
        kernel_class: Optional[Type[Kernel]] = None,  # TODO: use.
        likelihood: Optional[Type[Likelihood]] = None,  # TODO: use.
        mll_options: Optional[Dict[str, Any]] = None,
        kernel_options: Optional[Dict[str, Any]] = None,  # TODO: use.
    ) -> None:
        self.botorch_model_class = botorch_model_class
        self.mll_class = mll_class
        self.model_options = model_options or {}
        self.mll_options = mll_options or {}
        # Temporary validation while we develop these customizations.
        if likelihood is not None:
            raise NotImplementedError("Customizing likelihood not yet implemented.")
        if kernel_class is not None or kernel_options:
            raise NotImplementedError("Customizing kernel not yet implemented.")

    @property
    def model(self) -> Model:
        """The underlying BoTorch ``Model``.

        Raises:
            ValueError: If the model has not been constructed yet.
        """
        if self._model is None:
            raise ValueError(
                "BoTorch `Model` has not yet been constructed, please fit the "
                "surrogate first (done via `BoTorchModel.fit`)."
            )
        return not_none(self._model)

    @property
    def training_data(self) -> TrainingData:
        """Training data the model was last constructed with.

        Raises:
            ValueError: If no training data has been set yet.
        """
        if self._training_data is None:
            raise ValueError(NOT_YET_FIT_MSG)
        return not_none(self._training_data)

    @property
    def training_data_per_outcome(self) -> Dict[str, TrainingData]:
        """Per-outcome training data; only meaningful for ``ListSurrogate``."""
        raise NotImplementedError(  # pragma: no cover
            "`training_data_per_outcome` is only used in `ListSurrogate`."
        )

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the training inputs ``X``."""
        return self.training_data.X.dtype

    @property
    def device(self) -> torch.device:
        """Device of the training inputs ``X``."""
        return self.training_data.X.device

    @classmethod
    def from_botorch(
        cls,
        model: Model,
        mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
        mll_options: Optional[Dict[str, Any]] = None,
    ) -> Surrogate:
        """Instantiate a `Surrogate` from a pre-instantiated Botorch `Model`.

        Args:
            model: Already-constructed BoTorch ``Model`` to wrap.
            mll_class: ``MarginalLogLikelihood`` class to use when fitting.
            mll_options: Optional kwargs for the ``mll_class`` constructor.

        Returns:
            A ``Surrogate`` flagged as manually constructed, so that ``fit``
            does not rebuild the wrapped model and ``update`` is disallowed.
        """
        surrogate = cls(
            botorch_model_class=model.__class__,
            mll_class=mll_class,
            mll_options=mll_options,
        )
        surrogate._model = model
        # Temporarily disallowing `update` for surrogates instantiated from
        # pre-made BoTorch `Model` instances to avoid reconstructing models
        # that were likely pre-constructed for a reason (e.g. if this setup
        # doesn't fully allow to construct them).
        surrogate._constructed_manually = True
        return surrogate

    def clone_reset(self) -> Surrogate:
        """Create a new, unfitted surrogate of the same class with the same
        settings (as produced by ``_serialize_attributes_as_kwargs``).
        """
        return self.__class__(**self._serialize_attributes_as_kwargs())

    def construct(self, training_data: TrainingData, **kwargs: Any) -> None:
        """Constructs the underlying BoTorch ``Model`` using the training data.

        Args:
            training_data: Training data for the model (for one outcome for
                the default `Surrogate`, with the exception of batched
                multi-output case, where training data is formatted with just
                one X and concatenated Ys).
            **kwargs: Optional keyword arguments, merged over
                ``self.model_options`` (taking precedence) and forwarded to
                ``botorch_model_class.construct_inputs``. Expects any of:
                - "fidelity_features": Indices of columns in X that represent
                fidelity.

        Raises:
            ValueError: If ``training_data`` is not a ``TrainingData``.
        """
        if self._constructed_manually:
            logger.warning("Reconstructing a manually constructed `Model`.")
        if not isinstance(training_data, TrainingData):
            raise ValueError(  # pragma: no cover
                "Base `Surrogate` expects training data for single outcome."
            )
        input_constructor_kwargs = {**self.model_options, **(kwargs or {})}
        self._training_data = training_data
        formatted_model_inputs = self.botorch_model_class.construct_inputs(
            training_data=self.training_data, **input_constructor_kwargs
        )
        # pyre-ignore[45]: Py raises informative msg if `model_cls` abstract.
        self._model = self.botorch_model_class(**formatted_model_inputs)

    def fit(
        self,
        training_data: TrainingData,
        search_space_digest: SearchSpaceDigest,
        metric_names: List[str],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        state_dict: Optional[Dict[str, Tensor]] = None,
        refit: bool = True,
    ) -> None:
        """Fits the underlying BoTorch ``Model`` to ``m`` outcomes.

        NOTE: ``state_dict`` and ``refit`` keyword arguments control how the
        underlying BoTorch ``Model`` will be fit: whether its parameters will
        be reoptimized and whether it will be warm-started from a given state.

        There are three possibilities:

        * ``fit(state_dict=None)``: fit model from scratch (optimize model
          parameters and set its training data used for inference),
        * ``fit(state_dict=some_state_dict, refit=True)``: warm-start refit
          with a state dict of parameters (still re-optimize model parameters
          and set the training data),
        * ``fit(state_dict=some_state_dict, refit=False)``: load model
          parameters without refitting, but set new training data (used in
          cross-validation, for example).

        For surrogates created via ``Surrogate.from_botorch``, the model is not
        reconstructed (so no new training data is set); only ``state_dict``
        loading and parameter re-optimization apply.

        Args:
            training_data: BoTorch ``TrainingData`` container with Xs, Ys, and
                possibly Yvars, to be passed to ``Model.construct_inputs`` in
                BoTorch.
            search_space_digest: A SearchSpaceDigest object containing
                metadata on the features in the training data. Its fields are
                forwarded as keyword arguments to ``construct``.
            metric_names: Names of each outcome Y in Ys.
            candidate_metadata: Model-produced metadata for candidates, in
                the order corresponding to the Xs. Not used by the base
                ``Surrogate``.
            state_dict: Optional state dict to load. Only loaded if non-empty.
            refit: Whether to re-optimize model parameters.
        """
        if self._constructed_manually:
            logger.debug(
                "For manually constructed surrogates (via `Surrogate.from_botorch`), "
                "`fit` skips setting the training data on model and only reoptimizes "
                "its parameters if `refit=True`."
            )
        else:
            self.construct(
                training_data=training_data,
                metric_names=metric_names,
                **dataclasses.asdict(search_space_digest),
            )
        if state_dict:
            # pyre-fixme[6]: Expected `OrderedDict[typing.Any, typing.Any]` for 1st
            # param but got `Dict[str, Tensor]`.
            self.model.load_state_dict(not_none(state_dict))
        # Optimize parameters when fitting from scratch or when a warm-start
        # refit was explicitly requested.
        if state_dict is None or refit:
            mll = self.mll_class(self.model.likelihood, self.model, **self.mll_options)
            fit_gpytorch_model(mll)

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predicts outcomes for the given inputs using the underlying model.

        Args:
            X: A ``n x d`` tensor of input parameters.

        Returns:
            Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
            Tensor: The predicted posterior covariance as a ``n x o x o``-dim
                tensor.
        """
        return predict_from_model(model=self.model, X=X)

    def best_in_sample_point(
        self,
        search_space_digest: SearchSpaceDigest,
        objective_weights: Optional[Tensor],
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        options: Optional[TConfig] = None,
    ) -> Tuple[Tensor, float]:
        """Finds the best observed point and the corresponding observed outcome
        values.

        Returns:
            A tuple of the best in-sample point and its observed value.

        Raises:
            ValueError: If no best in-sample point could be obtained.
        """
        best_point_and_observed_value = best_in_sample_point(
            Xs=[self.training_data.X],
            # pyre-ignore[6]: `best_in_sample_point` currently expects a `TorchModel`
            # or a `NumpyModel` as `model` kwarg, but only uses them for `predict`
            # function, the signature for which is the same on this `Surrogate`.
            # TODO: When we move `botorch_modular` directory to OSS, we will extend
            # the annotation for `model` kwarg to accept `Surrogate` too.
            model=self,
            bounds=search_space_digest.bounds,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features,
            options=options,
        )
        if best_point_and_observed_value is None:
            raise ValueError("Could not obtain best in-sample point.")
        best_point, observed_value = best_point_and_observed_value
        return checked_cast(Tensor, best_point), observed_value

    def best_out_of_sample_point(
        self,
        search_space_digest: SearchSpaceDigest,
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        options: Optional[TConfig] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Finds the best predicted point and the corresponding value of the
        appropriate best point acquisition function.

        Returns:
            A tuple of the single optimized candidate and its acquisition
            function value.

        Raises:
            NotImplementedError: If ``fixed_features`` are specified.
        """
        if fixed_features:
            # When have fixed features, need `FixedFeatureAcquisitionFunction`
            # which has peculiar instantiation (wraps another acquisition fn.),
            # so need to figure out how to handle.
            # TODO (ref: https://fburl.com/diff/uneqb3n9)
            raise NotImplementedError("Fixed features not yet supported.")

        options = options or {}
        acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
            outcome_constraints=outcome_constraints,
            seed_inner=checked_cast_optional(int, options.get(Keys.SEED_INNER, None)),
            qmc=checked_cast(bool, options.get(Keys.QMC, True)),
        )

        # Avoiding circular import between `Surrogate` and `Acquisition`.
        from ax.models.torch.botorch_modular.acquisition import Acquisition

        acqf = Acquisition(  # TODO: For multi-fidelity, might need diff. class.
            surrogate=self,
            botorch_acqf_class=acqf_class,
            search_space_digest=search_space_digest,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features,
            options=acqf_options,
        )
        candidates, acqf_values = acqf.optimize(
            n=1,
            search_space_digest=search_space_digest,
            inequality_constraints=_to_inequality_constraints(
                linear_constraints=linear_constraints
            ),
            fixed_features=fixed_features,
        )
        return candidates[0], acqf_values[0]

    def update(
        self,
        training_data: TrainingData,
        search_space_digest: SearchSpaceDigest,
        metric_names: List[str],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        state_dict: Optional[Dict[str, Tensor]] = None,
        refit: bool = True,
    ) -> None:
        """Updates the surrogate model with new data.

        In the base ``Surrogate``, just calls ``fit`` after checking that this
        surrogate was not created via ``Surrogate.from_botorch`` (in which
        case the ``Model`` comes premade, constructed manually and then
        supplied to ``Surrogate``).

        NOTE: Expects `training_data` to be all available data,
        not just the new data since the last time the model was updated.

        Args:
            training_data: Surrogate training_data containing all the data
                the model should use for inference.
            search_space_digest: A SearchSpaceDigest object containing
                metadata on the features in the training data.
            metric_names: Names of each outcome Y in Ys.
            candidate_metadata: Model-produced metadata for candidates, in
                the order corresponding to the Xs.
            state_dict: Optional state dict to load.
            refit: Whether to re-optimize model parameters or just set the
                training data used for inference to new training data.

        Raises:
            NotImplementedError: If this surrogate was created via
                ``Surrogate.from_botorch``.
        """
        # NOTE: In the future, could have `incremental` kwarg, in which case
        # `training_data` could contain just the new data.
        if self._constructed_manually:
            raise NotImplementedError(
                "`update` not yet implemented for models that are "
                "constructed manually, but it is possible to create a new "
                "surrogate in the same way as the current manually constructed one, "
                "via `Surrogate.from_botorch`."
            )
        self.fit(
            training_data=training_data,
            search_space_digest=search_space_digest,
            metric_names=metric_names,
            candidate_metadata=candidate_metadata,
            state_dict=state_dict,
            refit=refit,
        )

    def pareto_frontier(self) -> Tuple[Tensor, Tensor]:
        """For multi-objective optimization, retrieve Pareto frontier instead
        of best point.

        Returns: A two-tuple of:
            - tensor of points in the feature space,
            - tensor of corresponding (multiple) outcomes.
        """
        raise NotImplementedError(
            "Pareto frontier not yet implemented."
        )  # pragma: no cover

    def compute_diagnostics(self) -> Dict[str, Any]:
        """Computes model diagnostics like cross-validation measure of fit, etc.

        The base ``Surrogate`` returns an empty dict.
        """
        return {}  # pragma: no cover

    def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
        """Serialize attributes of this surrogate, to be passed back to it
        as kwargs on reinstantiation.
        """
        return {
            "botorch_model_class": self.botorch_model_class,
            "mll_class": self.mll_class,
            "model_options": self.model_options,
            "mll_options": self.mll_options,
        }