Source code for ax.models.torch.botorch_modular.surrogate

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import inspect
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from logging import Logger
from typing import Any

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.input_constructors.covar_modules import (
    covar_module_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.input_transforms import (
    input_transform_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import (
    outcome_transform_argparse,
)
from ax.models.torch.botorch_modular.utils import (
    choose_model_class,
    convert_to_block_design,
    fit_botorch_model,
    ModelConfig,
    subset_state_dict,
    use_model_list,
)
from ax.models.torch.utils import (
    _to_inequality_constraints,
    pick_best_out_of_sample_point_acqf_class,
    predict_from_model,
)
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import (
    _argparse_type_encoder,
    checked_cast,
    checked_cast_optional,
)
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import (
    ChainedInputTransform,
    InputPerturbation,
    InputTransform,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from pyre_extensions import none_throws
from torch import Tensor
from torch.nn import Module

# Error text used when training data is requested before the surrogate is fit.
NOT_YET_FIT_MSG = (
    "Underlying BoTorch `Model` has not yet received its "
    "training_data. Please fit the model first."
)

# Module-level logger for this file.
logger: Logger = get_logger(__name__)


def _extract_model_kwargs(
    search_space_digest: SearchSpaceDigest,
) -> dict[str, list[int] | int]:
    """
    Extracts keyword arguments that are passed to the `construct_inputs`
    method of a BoTorch `Model` class.

    Args:
        search_space_digest: A `SearchSpaceDigest`.

    Returns:
        A dict of fidelity features, categorical features, and, if present, task
        features.
    """
    fidelity_features = search_space_digest.fidelity_features
    task_features = search_space_digest.task_features
    if len(fidelity_features) > 0 and len(task_features) > 0:
        raise NotImplementedError(
            "Multi-Fidelity GP models with task_features are "
            "currently not supported."
        )
    if len(task_features) > 1:
        raise NotImplementedError("Multiple task features are not supported.")

    kwargs: dict[str, list[int] | int] = {}
    if len(search_space_digest.categorical_features) > 0:
        kwargs["categorical_features"] = search_space_digest.categorical_features
    if len(fidelity_features) > 0:
        kwargs["fidelity_features"] = fidelity_features
    if len(task_features) == 1:
        kwargs["task_feature"] = task_features[0]
    return kwargs


def _make_botorch_input_transform(
    input_transform_classes: list[type[InputTransform]],
    input_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> InputTransform | None:
    """
    Makes a BoTorch input transform from the provided input classes and options.

    Each class is instantiated with the kwargs returned by
    ``input_transform_argparse``. That call receives a deep copy of
    ``input_transform_options[cls.__name__]``, or ``{}`` if there is no entry.
    If ``search_space_digest.robust_digest`` is set, ``InputPerturbation`` is
    prepended to the classes. Multiple transforms are combined into a
    ``ChainedInputTransform`` keyed ``tf0``, ``tf1``, ...

    Args:
        input_transform_classes: List of ``InputTransform`` subclasses.
        input_transform_options: Per-class kwargs, keyed by class name.
        dataset: Training data passed to ``input_transform_argparse``.
        search_space_digest: Search space digest passed to
            ``input_transform_argparse``.

    Returns:
        A single transform, a ``ChainedInputTransform``, or ``None`` when there
        are no transforms to apply.

    Raises:
        UserInputError: If ``input_transform_classes`` is not a list of
            ``InputTransform`` subclasses.
    """
    # Check `isinstance(c, type)` first: `issubclass` raises a bare TypeError for
    # non-class entries (e.g. an already-instantiated transform), which would
    # otherwise mask this more helpful error.
    if not (
        isinstance(input_transform_classes, list)
        and all(
            isinstance(c, type) and issubclass(c, InputTransform)
            for c in input_transform_classes
        )
    ):
        raise UserInputError("Expected a list of input transforms.")
    if search_space_digest.robust_digest is not None:
        input_transform_classes = [InputPerturbation] + input_transform_classes
    if len(input_transform_classes) == 0:
        return None

    input_transform_kwargs = [
        input_transform_argparse(
            transform_class,
            dataset=dataset,
            search_space_digest=search_space_digest,
            input_transform_options=deepcopy(  # In case of in-place modifications.
                input_transform_options.get(transform_class.__name__, {})
            ),
        )
        for transform_class in input_transform_classes
    ]

    input_transforms = [
        # pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
        transform_class(**single_input_transform_kwargs)
        for transform_class, single_input_transform_kwargs in zip(
            input_transform_classes, input_transform_kwargs
        )
    ]

    if len(input_transforms) == 1:
        return input_transforms[0]
    return ChainedInputTransform(
        **{f"tf{i}": tf for i, tf in enumerate(input_transforms)}
    )


def _make_botorch_outcome_transform(
    outcome_transform_classes: list[type[OutcomeTransform]],
    outcome_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
) -> OutcomeTransform | None:
    """
    Makes a BoTorch outcome transform from the provided classes and options.

    Each class is instantiated with the kwargs returned by
    ``outcome_transform_argparse``. That call receives a deep copy of
    ``outcome_transform_options[cls.__name__]``, or ``{}`` if there is no
    entry. Multiple transforms are combined into a ``ChainedOutcomeTransform``
    keyed ``otf0``, ``otf1``, ...

    Args:
        outcome_transform_classes: List of ``OutcomeTransform`` subclasses.
        outcome_transform_options: Per-class kwargs, keyed by class name.
        dataset: Training data passed to ``outcome_transform_argparse``.

    Returns:
        A single transform, a ``ChainedOutcomeTransform``, or ``None`` when
        ``outcome_transform_classes`` is empty.

    Raises:
        UserInputError: If ``outcome_transform_classes`` is not a list of
            ``OutcomeTransform`` subclasses.
    """
    # Check `isinstance(c, type)` first: `issubclass` raises a bare TypeError for
    # non-class entries (e.g. an already-instantiated transform), which would
    # otherwise mask this more helpful error.
    if not (
        isinstance(outcome_transform_classes, list)
        and all(
            isinstance(c, type) and issubclass(c, OutcomeTransform)
            for c in outcome_transform_classes
        )
    ):
        raise UserInputError("Expected a list of outcome transforms.")
    if len(outcome_transform_classes) == 0:
        return None

    outcome_transform_kwargs = [
        outcome_transform_argparse(
            transform_class,
            outcome_transform_options=deepcopy(  # In case of in-place modifications.
                outcome_transform_options.get(transform_class.__name__, {})
            ),
            dataset=dataset,
        )
        for transform_class in outcome_transform_classes
    ]

    outcome_transforms = [
        # pyre-fixme[45]: Cannot instantiate abstract class `OutcomeTransform`.
        transform_class(**single_outcome_transform_kwargs)
        for transform_class, single_outcome_transform_kwargs in zip(
            outcome_transform_classes, outcome_transform_kwargs
        )
    ]

    if len(outcome_transforms) == 1:
        return outcome_transforms[0]
    return ChainedOutcomeTransform(
        **{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
    )


def _construct_submodules(
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    botorch_model_class: type[Model],
) -> dict[str, Module | None]:
    """Build the optional BoTorch-model submodules requested by ``model_config``.

    Entries are keyed by the model-constructor argument they are meant for:

    - ``covar_module``: when ``model_config.covar_module_class`` is set.
    - ``likelihood``: when ``model_config.likelihood_class`` is set.
    - ``input_transform``: when ``model_config.input_transform_classes`` is set,
      or when ``search_space_digest.robust_digest`` is set.
    - ``outcome_transform``: when ``model_config.outcome_transform_classes`` is
      set. Otherwise it is set to ``None`` if the model constructor accepts
      that argument.

    Raises:
        UserInputError: If a submodule is requested but the constructor of
            ``botorch_model_class`` does not list the corresponding argument.
    """
    accepted_args: list[str] = inspect.getfullargspec(botorch_model_class).args

    def _check_supported(arg_name: str) -> None:
        # Fail with a readable error instead of a TypeError at construction time.
        if arg_name in accepted_args:
            return
        raise UserInputError(
            f"The BoTorch model class {botorch_model_class.__name__} does not "
            f"support the input {arg_name}."
        )

    submodules: dict[str, Module | None] = {}

    # Covariance module. Bind to a local so that pyre narrows away `None`.
    covar_class = model_config.covar_module_class
    if covar_class is not None:
        _check_supported("covar_module")
        covar_kwargs = covar_module_argparse(
            covar_class,
            dataset=dataset,
            botorch_model_class=botorch_model_class,
            **deepcopy(model_config.covar_module_options),
        )
        # pyre-ignore [45]: Cannot instantiate abstract class `Kernel`.
        submodules["covar_module"] = covar_class(**covar_kwargs)

    # Likelihood.
    likelihood_class = model_config.likelihood_class
    if likelihood_class is not None:
        _check_supported("likelihood")
        likelihood_kwargs = deepcopy(model_config.likelihood_options)
        # pyre-ignore [45]: Cannot instantiate abstract class `Likelihood`.
        submodules["likelihood"] = likelihood_class(**likelihood_kwargs)

    # Input transform. A robust digest alone also triggers this, because
    # `_make_botorch_input_transform` prepends `InputPerturbation` in that case.
    input_classes = model_config.input_transform_classes
    if input_classes is not None or search_space_digest.robust_digest is not None:
        _check_supported("input_transform")
        submodules["input_transform"] = _make_botorch_input_transform(
            input_transform_classes=input_classes or [],
            input_transform_options=model_config.input_transform_options or {},
            dataset=dataset,
            search_space_digest=search_space_digest,
        )

    # Outcome transform.
    outcome_classes = model_config.outcome_transform_classes
    if outcome_classes is None:
        if "outcome_transform" in accepted_args:
            # This is a temporary solution until all BoTorch models use
            # `Standardize` by default, see TODO [T197435440].
            # After this, we should update `Surrogate` to use `DEFAULT`
            # (https://fburl.com/code/22f4397e) for both of these args. This will
            # allow users to explicitly disable the default transforms by passing
            # in `None`.
            submodules["outcome_transform"] = None
    else:
        _check_supported("outcome_transform")
        submodules["outcome_transform"] = _make_botorch_outcome_transform(
            outcome_transform_classes=outcome_classes,
            outcome_transform_options=model_config.outcome_transform_options or {},
            dataset=dataset,
        )

    return submodules


def _raise_deprecation_warning(
    is_surrogate: bool = False,
    **kwargs: Any,
) -> bool:
    """Raise deprecation warnings for deprecated arguments.

    Args:
        is_surrogate: A boolean indicating whether the warning is called from
            Surrogate.

    Returns:
        A boolean indicating whether any deprecation warnings were raised.
    """
    msg = "{k} is deprecated and will be removed in a future version. "
    if is_surrogate:
        msg += "Please specify {k} via `surrogate_spec.model_configs`."
    else:
        msg += "Please specify {k} via `model_configs`."
    warnings_raised = False
    default_is_dict = {"botorch_model_kwargs", "mll_kwargs"}
    for k, v in kwargs.items():
        should_raise = False
        if k in default_is_dict:
            if v != {}:
                should_raise = True
        elif (v is not None and k != "mll_class") or (
            k == "mll_class" and v is not ExactMarginalLogLikelihood
        ):
            should_raise = True
        if should_raise:
            warnings.warn(
                msg.format(k=k),
                DeprecationWarning,
                stacklevel=3,
            )
            warnings_raised = True
    return warnings_raised


def get_model_config_from_deprecated_args(
    botorch_model_class: type[Model] | None,
    model_options: dict[str, Any] | None,
    mll_class: type[MarginalLogLikelihood] | None,
    mll_options: dict[str, Any] | None,
    outcome_transform_classes: list[type[OutcomeTransform]] | None,
    outcome_transform_options: dict[str, dict[str, Any]] | None,
    input_transform_classes: list[type[InputTransform]] | None,
    input_transform_options: dict[str, dict[str, Any]] | None,
    covar_module_class: type[Kernel] | None,
    covar_module_options: dict[str, Any] | None,
    likelihood_class: type[Likelihood] | None,
    likelihood_options: dict[str, Any] | None,
) -> ModelConfig:
    """Construct a ``ModelConfig`` from the deprecated per-argument inputs.

    ``model_options`` and ``mll_options`` are shallow-copied, with ``None``
    replaced by ``{}``. Every other argument that is ``None`` is left out of
    the ``ModelConfig`` constructor call (presumably so that its defaults
    apply).
    """
    named_values: tuple[tuple[str, Any], ...] = (
        ("botorch_model_class", botorch_model_class),
        ("model_options", (model_options or {}).copy()),
        ("mll_class", mll_class),
        ("mll_options", (mll_options or {}).copy()),
        ("outcome_transform_classes", outcome_transform_classes),
        ("outcome_transform_options", outcome_transform_options),
        ("input_transform_classes", input_transform_classes),
        ("input_transform_options", input_transform_options),
        ("covar_module_class", covar_module_class),
        ("covar_module_options", covar_module_options),
        ("likelihood_class", likelihood_class),
        ("likelihood_options", likelihood_options),
    )
    config_kwargs = {name: value for name, value in named_values if value is not None}
    # pyre-fixme [6]: Incompatible parameter type [6]: In call
    # `ModelConfig.__init__`, for 1st positional argument, expected
    # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any],
    # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]],
    # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood,
    # Model]], Type[Likelihood], Type[Kernel]]`.
    return ModelConfig(**config_kwargs)
@dataclass(frozen=True)
class SurrogateSpec:
    """
    Fields in the SurrogateSpec dataclass correspond to arguments in
    ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
    outcomes the Surrogate is responsible for modeling.
    When ``BotorchModel.fit`` is called, these fields will be used to construct the
    requisite Surrogate objects.
    If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.

    The fields ``botorch_model_class`` through ``outcome_transform_options`` are
    deprecated. If ``model_configs`` is empty, ``__post_init__`` folds them into
    a single ``ModelConfig``. Setting them together with a non-empty
    ``model_configs`` raises a ``UserInputError``. At most one model config
    (per metric) is supported.
    """

    botorch_model_class: type[Model] | None = None
    botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)
    mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
    mll_kwargs: dict[str, Any] = field(default_factory=dict)
    covar_module_class: type[Kernel] | None = None
    covar_module_kwargs: dict[str, Any] | None = None
    likelihood_class: type[Likelihood] | None = None
    likelihood_kwargs: dict[str, Any] | None = None
    input_transform_classes: list[type[InputTransform]] | None = None
    input_transform_options: dict[str, dict[str, Any]] | None = None
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None
    outcome_transform_options: dict[str, dict[str, Any]] | None = None
    allow_batched_models: bool = True
    model_configs: list[ModelConfig] = field(default_factory=list)
    metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
    outcomes: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # The keyword names below are this dataclass's field names (e.g.
        # `covar_module_kwargs`, not `covar_module_options`). That way each
        # DeprecationWarning names the field the user actually set.
        warnings_raised = _raise_deprecation_warning(
            is_surrogate=False,
            botorch_model_class=self.botorch_model_class,
            botorch_model_kwargs=self.botorch_model_kwargs,
            mll_class=self.mll_class,
            mll_kwargs=self.mll_kwargs,
            outcome_transform_classes=self.outcome_transform_classes,
            outcome_transform_options=self.outcome_transform_options,
            input_transform_classes=self.input_transform_classes,
            input_transform_options=self.input_transform_options,
            covar_module_class=self.covar_module_class,
            covar_module_kwargs=self.covar_module_kwargs,
            likelihood_class=self.likelihood_class,
            likelihood_kwargs=self.likelihood_kwargs,
        )
        if len(self.model_configs) == 0:
            model_config = get_model_config_from_deprecated_args(
                botorch_model_class=self.botorch_model_class,
                model_options=self.botorch_model_kwargs,
                mll_class=self.mll_class,
                mll_options=self.mll_kwargs,
                outcome_transform_classes=self.outcome_transform_classes,
                outcome_transform_options=self.outcome_transform_options,
                input_transform_classes=self.input_transform_classes,
                input_transform_options=self.input_transform_options,
                covar_module_class=self.covar_module_class,
                covar_module_options=self.covar_module_kwargs,
                likelihood_class=self.likelihood_class,
                likelihood_options=self.likelihood_kwargs,
            )
            # Re-initialize with only the non-deprecated arguments. This resets
            # the deprecated fields to their defaults and runs `__post_init__`
            # again on the new values.
            self.__init__(
                model_configs=[model_config],
                metric_to_model_configs=self.metric_to_model_configs,
                allow_batched_models=self.allow_batched_models,
                outcomes=self.outcomes,
            )
        elif warnings_raised:
            raise UserInputError(
                "model_configs and deprecated arguments were both specified. "
                "Please use model_configs and remove deprecated arguments."
            )
        if len(self.model_configs) > 1 or any(
            len(configs) > 1 for configs in self.metric_to_model_configs.values()
        ):
            raise NotImplementedError("Only one model config per metric is supported.")
class Surrogate(Base):
    """
    **All classes in 'botorch_modular' directory are under construction,
    incomplete, and should be treated as alpha versions only.**

    Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchModel``
    and is not meant to be used outside of it.

    Args:
        surrogate_spec: A ``SurrogateSpec`` describing the model(s) to fit, via
            its ``model_configs`` and ``metric_to_model_configs`` fields. If
            None, a spec with a single ``ModelConfig`` is built from the
            deprecated arguments below. Passing both ``surrogate_spec`` and
            any deprecated argument raises a ``UserInputError``.
        botorch_model_class: ``Model`` class to be used as the underlying BoTorch
            model. If None is provided, a model class (either one for all
            outcomes or a ModelList with separate models for each outcome) will
            be selected automatically based on the datasets at ``fit`` time.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        model_options: Dictionary of options / kwargs for the BoTorch
            ``Model`` constructed during ``Surrogate.fit``.
            Note that the corresponding attribute will later be updated to
            include any additional kwargs passed into ``BoTorchModel.fit``.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        mll_options: Dictionary of options / kwargs for the MLL.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        outcome_transform_classes: List of BoTorch outcome transforms classes.
            Passed down to the BoTorch ``Model``. Multiple outcome transforms
            can be chained together using ``ChainedOutcomeTransform``.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        outcome_transform_options: Outcome transform classes kwargs. The keys
            are class string names and the values are dictionaries of outcome
            transform kwargs. For example,
            `
            outcome_transform_classes = [Standardize]
            outcome_transform_options = {
                "Standardize": {"m": 1},
            }
            `
            For more options see `botorch/models/transforms/outcome.py`.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        input_transform_classes: List of BoTorch input transforms classes.
            Passed down to the BoTorch ``Model``. Multiple input transforms
            will be chained together using ``ChainedInputTransform``.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        input_transform_options: Input transform classes kwargs. The keys are
            class string names and the values are dictionaries of input
            transform kwargs. For example,
            `
            input_transform_classes = [Normalize, Round]
            input_transform_options = {
                "Normalize": {"d": 3},
                "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
            }
            `
            For more input options see `botorch/models/transforms/input.py`.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        covar_module_class: Covariance module class. This gets initialized
            after parsing the ``covar_module_options`` in
            ``covar_module_argparse``, and gets passed to the model constructor
            as ``covar_module``.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        covar_module_options: Covariance module kwargs.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_options`` and gets passed to the model constructor.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        likelihood_options: Likelihood options.
            This argument is deprecated in favor of
            ``surrogate_spec.model_configs``.
        allow_batched_models: Set to true to fit the models in a batch if
            supported. Set to false to fit individual models to each metric in
            a loop. Only used when ``surrogate_spec`` is None.
    """

    def __init__(
        self,
        surrogate_spec: SurrogateSpec | None = None,
        botorch_model_class: type[Model] | None = None,
        model_options: dict[str, Any] | None = None,
        mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
        mll_options: dict[str, Any] | None = None,
        outcome_transform_classes: list[type[OutcomeTransform]] | None = None,
        outcome_transform_options: dict[str, dict[str, Any]] | None = None,
        input_transform_classes: list[type[InputTransform]] | None = None,
        input_transform_options: dict[str, dict[str, Any]] | None = None,
        covar_module_class: type[Kernel] | None = None,
        covar_module_options: dict[str, Any] | None = None,
        likelihood_class: type[Likelihood] | None = None,
        likelihood_options: dict[str, Any] | None = None,
        allow_batched_models: bool = True,
    ) -> None:
        # Emits a DeprecationWarning for each deprecated argument that was set.
        warnings_raised = _raise_deprecation_warning(
            is_surrogate=True,
            botorch_model_class=botorch_model_class,
            model_options=model_options,
            mll_class=mll_class,
            mll_options=mll_options,
            outcome_transform_classes=outcome_transform_classes,
            outcome_transform_options=outcome_transform_options,
            input_transform_classes=input_transform_classes,
            input_transform_options=input_transform_options,
            covar_module_class=covar_module_class,
            covar_module_options=covar_module_options,
            likelihood_class=likelihood_class,
            likelihood_options=likelihood_options,
        )
        if surrogate_spec is None:
            # Create a single-config surrogate spec from the deprecated arguments.
            model_config = get_model_config_from_deprecated_args(
                botorch_model_class=botorch_model_class,
                model_options=model_options,
                mll_class=mll_class,
                mll_options=mll_options,
                outcome_transform_classes=outcome_transform_classes,
                outcome_transform_options=outcome_transform_options,
                input_transform_classes=input_transform_classes,
                input_transform_options=input_transform_options,
                covar_module_class=covar_module_class,
                covar_module_options=covar_module_options,
                likelihood_class=likelihood_class,
                likelihood_options=likelihood_options,
            )
            surrogate_spec = SurrogateSpec(
                model_configs=[model_config], allow_batched_models=allow_batched_models
            )
        elif warnings_raised:
            raise UserInputError(
                "model_configs and deprecated arguments were both specified. "
                "Please use model_configs and remove deprecated arguments."
            )
        self.surrogate_spec: SurrogateSpec = surrogate_spec
        # Store the last dataset used to fit the model for a given metric(s).
        # If the new dataset is identical, we will skip model fitting for that
        # metric. The keys are `tuple(dataset.outcome_names)`, so they may be
        # tuples of any length.
        self._last_datasets: dict[tuple[str, ...], SupervisedDataset] = {}
        # Store a reference from a tuple of metric names to the BoTorch Model
        # corresponding to those metrics. In most cases this will be a one-tuple,
        # though we need n-tuples for LCE-M models. This will be used to skip
        # model construction & fitting if the datasets are identical.
        self._submodels: dict[tuple[str, ...], Model] = {}
        # Store a reference to search space digest used while fitting the cached
        # models. We will re-fit the models if the search space digest changes.
        self._last_search_space_digest: SearchSpaceDigest | None = None

        # These are later updated during model fitting.
        self._training_data: list[SupervisedDataset] | None = None
        self._outcomes: list[str] | None = None
        self._model: Model | None = None

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} surrogate_spec={self.surrogate_spec}>"

    @property
    def model(self) -> Model:
        """The fitted BoTorch ``Model``.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self._model is None:
            raise ValueError(
                "BoTorch `Model` has not yet been constructed, please fit the "
                "surrogate first (done via `BoTorchModel.fit`)."
            )
        return self._model

    @property
    def training_data(self) -> list[SupervisedDataset]:
        """The datasets used in the last call to ``fit``.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self._training_data is None:
            raise ValueError(NOT_YET_FIT_MSG)
        return self._training_data

    @property
    def Xs(self) -> list[Tensor]:
        """One training-input tensor per outcome, in ``training_data`` order.

        A dataset with ``k`` outcome columns (``dataset.Y.shape[-1] == k``)
        contributes its inputs ``k`` times.
        """
        # Handles multi-output models. TODO: Improve this!
        Xs: list[Tensor] = []
        for dataset in self.training_data:
            if isinstance(dataset, RankingDataset):
                # directly accessing the d-dim X tensor values
                # instead of the augmented 2*d-dim dataset.X from RankingDataset
                Xi = checked_cast(SliceContainer, dataset._X).values
            else:
                Xi = dataset.X
            Xs.extend([Xi] * dataset.Y.shape[-1])
        return Xs

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the first training dataset's ``X``."""
        return self.training_data[0].X.dtype

    @property
    def device(self) -> torch.device:
        """The device of the first training dataset's ``X``."""
        return self.training_data[0].X.device
[docs] def clone_reset(self) -> Surrogate: return self.__class__(**self._serialize_attributes_as_kwargs())
def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, model_config: ModelConfig, default_botorch_model_class: type[Model], state_dict: OrderedDict[str, Tensor] | None, refit: bool, ) -> Model: """Constructs the underlying BoTorch ``Model`` using the training data. If the dataset and model class are identical to those used while training the cached sub-model, we skip model fitting and return the cached model. Args: dataset: Training data for the model (for one outcome for the default `Surrogate`, with the exception of batched multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. model_config: The model_config. default_botorch_model_class: The default ``Model`` class to be used as the underlying BoTorch model, if the model_config does not specify one. state_dict: Optional state dict to load. This should be subsetted for the current submodel being constructed. refit: Whether to re-optimize model parameters. """ outcome_names = tuple(dataset.outcome_names) botorch_model_class = ( model_config.botorch_model_class or default_botorch_model_class ) if self._should_reuse_last_model( dataset=dataset, botorch_model_class=botorch_model_class ): return self._submodels[outcome_names] formatted_model_inputs = submodel_input_constructor( botorch_model_class, # Do not pass as kwarg since this is used to dispatch. 
model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=self, ) # pyre-ignore [45] model = botorch_model_class(**formatted_model_inputs) if state_dict is not None: model.load_state_dict(state_dict) if state_dict is None or refit: fit_botorch_model( model=model, mll_class=model_config.mll_class, mll_options=model_config.mll_options, ) self._submodels[outcome_names] = model self._last_datasets[outcome_names] = dataset return model def _should_reuse_last_model( self, dataset: SupervisedDataset, botorch_model_class: type[Model] ) -> bool: """Checks whether the given dataset and model class match the last dataset and model class used to train the cached sub-model. """ outcome_names = tuple(dataset.outcome_names) if ( outcome_names in self._submodels and dataset == self._last_datasets[outcome_names] ): last_model = self._submodels[outcome_names] if type(last_model) is not botorch_model_class: logger.info( f"The model class for outcome(s) {dataset.outcome_names} " f"changed from {type(last_model)} to {botorch_model_class}. " "Will refit the model." ) else: return True return False
[docs] def fit( self, datasets: Sequence[SupervisedDataset], search_space_digest: SearchSpaceDigest, candidate_metadata: list[list[TCandidateMetadata]] | None = None, state_dict: OrderedDict[str, Tensor] | None = None, refit: bool = True, ) -> None: """Fits the underlying BoTorch ``Model`` to ``m`` outcomes. NOTE: ``state_dict`` and ``refit`` keyword arguments control how the undelying BoTorch ``Model`` will be fit: whether its parameters will be reoptimized and whether it will be warm-started from a given state. There are three possibilities: * ``fit(state_dict=None)``: fit model from scratch (optimize model parameters and set its training data used for inference), * ``fit(state_dict=some_state_dict, refit=True)``: warm-start refit with a state dict of parameters (still re-optimize model parameters and set the training data), * ``fit(state_dict=some_state_dict, refit=False)``: load model parameters without refitting, but set new training data (used in cross-validation, for example). Args: datasets: A list of ``SupervisedDataset`` containers, each corresponding to the data of one metric (outcome), to be passed to ``Model.construct_inputs`` in BoTorch. search_space_digest: A ``SearchSpaceDigest`` object containing metadata on the features in the datasets. candidate_metadata: Model-produced metadata for candidates, in the order corresponding to the Xs. state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters. """ self._discard_cached_model_and_data_if_search_space_digest_changed( search_space_digest=search_space_digest ) # To determine whether to use ModelList under the hood, we need to check for # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. 
if ( len(self.surrogate_spec.model_configs) == 1 and self.surrogate_spec.model_configs[0].botorch_model_class is None ): default_botorch_model_class = choose_model_class( datasets=datasets, search_space_digest=search_space_digest ) else: default_botorch_model_class = self.surrogate_spec.model_configs[ 0 ].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, botorch_model_class=none_throws(default_botorch_model_class), model_configs=self.surrogate_spec.model_configs, allow_batched_models=self.surrogate_spec.allow_batched_models, metric_to_model_configs=self.surrogate_spec.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: datasets = convert_to_block_design(datasets=datasets, force=True) self._training_data = list(datasets) # So that it can be modified if needed. models = [] outcome_names = [] for i, dataset in enumerate(datasets): submodel_state_dict = None if state_dict is not None: if should_use_model_list: submodel_state_dict = subset_state_dict( state_dict=state_dict, submodel_index=i ) else: submodel_state_dict = state_dict model_config = None if len(self.surrogate_spec.metric_to_model_configs) > 0: # if metric_to_model_configs is not empty, then # we are using a model list and each dataset # should have only one outcome. if len(dataset.outcome_names) > 1: raise ValueError( "Each dataset should have only one outcome when " "metric_to_model_configs is specified." 
) model_config_list = self.surrogate_spec.metric_to_model_configs.get( dataset.outcome_names[0] ) # TODO: add support for automated model selection if model_config_list is not None: model_config = model_config_list[0] if model_config is None: model_config = self.surrogate_spec.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, model_config=model_config, default_botorch_model_class=none_throws(default_botorch_model_class), state_dict=submodel_state_dict, refit=refit, ) models.append(model) outcome_names.extend(dataset.outcome_names) if should_use_model_list: self._model = ModelListGP(*models) else: self._model = models[0] self._outcomes = outcome_names # In the order of input datasets
def _discard_cached_model_and_data_if_search_space_digest_changed(
    self, search_space_digest: SearchSpaceDigest
) -> None:
    """Invalidate cached submodels and datasets when the search space changes.

    The incoming digest is compared with the digest recorded by the previous
    call (stored in ``self._last_search_space_digest``). If a previous digest
    exists and differs from the new one, an info message is logged and both
    ``self._submodels`` and ``self._last_datasets`` are reset to empty dicts.
    In every case the new digest is recorded for the next comparison.

    Args:
        search_space_digest: The search space digest for the current call.
    """
    previous_digest = self._last_search_space_digest
    # A missing previous digest means there is nothing cached to compare
    # against, so nothing is discarded on the first call.
    digest_changed = (
        previous_digest is not None and search_space_digest != previous_digest
    )
    if digest_changed:
        logger.info(
            "Discarding all previously trained models due to a change "
            "in the search space digest."
        )
        self._submodels = {}
        self._last_datasets = {}
    self._last_search_space_digest = search_space_digest
def predict(
    self, X: Tensor, use_posterior_predictive: bool = False
) -> tuple[Tensor, Tensor]:
    """Predict outcomes at the given inputs using ``self.model``.

    Args:
        X: An ``n x d`` tensor of input parameters.
        use_posterior_predictive: If True, the predictions come from the
            posterior predictive distribution (i.e. they include observation
            noise).

    Returns:
        A two-tuple of the predicted posterior mean (an ``n x o``-dim tensor)
        and the predicted posterior covariance (an ``n x o x o``-dim tensor).
    """
    fitted_model = self.model
    mean_and_covariance = predict_from_model(
        model=fitted_model,
        X=X,
        use_posterior_predictive=use_posterior_predictive,
    )
    return mean_and_covariance
def best_in_sample_point(
    self,
    search_space_digest: SearchSpaceDigest,
    torch_opt_config: TorchOptConfig,
    options: TConfig | None = None,
) -> tuple[Tensor, float]:
    """Find the best observed point and its corresponding observed value.

    Args:
        search_space_digest: A `SearchSpaceDigest`; only its `bounds` are used.
        torch_opt_config: A `TorchOptConfig` supplying the objective weights,
            outcome constraints, linear constraints, fixed features and risk
            measure.
        options: Optional config, forwarded unchanged to the
            `ax.models.model_utils.best_in_sample_point` helper.

    Returns:
        A two-tuple of the best in-sample point (cast to `self.dtype` and
        moved to CPU) and its observed value.

    Raises:
        NotImplementedError: If `torch_opt_config.is_moo` is True.
        ValueError: If the helper returns None.
    """
    if torch_opt_config.is_moo:
        raise NotImplementedError(
            "Best observed point is incompatible with MOO problems."
        )
    # The helper shares this method's name; importing it under an alias keeps
    # the call unambiguous regardless of where this def is placed.
    from ax.models.model_utils import best_in_sample_point as _find_best_observed

    found = _find_best_observed(
        Xs=self.Xs,
        model=self,
        bounds=search_space_digest.bounds,
        objective_weights=torch_opt_config.objective_weights,
        outcome_constraints=torch_opt_config.outcome_constraints,
        linear_constraints=torch_opt_config.linear_constraints,
        fixed_features=torch_opt_config.fixed_features,
        risk_measure=torch_opt_config.risk_measure,
        options=options,
    )
    if found is None:
        raise ValueError("Could not obtain best in-sample point.")
    point, observed_value = found
    # pyre-fixme[16]: Item `ndarray` of `Union[ndarray[typing.Any,
    # typing.Any], Tensor]` has no attribute `to`.
    cpu_point = point.to(dtype=self.dtype, device=torch.device("cpu"))
    return cpu_point, observed_value
def best_out_of_sample_point(
    self,
    search_space_digest: SearchSpaceDigest,
    torch_opt_config: TorchOptConfig,
    options: TConfig | None = None,
) -> tuple[Tensor, Tensor]:
    """Find the best predicted point and its best-point acquisition value.

    Args:
        search_space_digest: A `SearchSpaceDigest`.
        torch_opt_config: A `TorchOptConfig`. A non-empty `fixed_features`
            is not supported.
        options: Optional. If present, `seed_inner` (default None) and `qmc`
            (default True) are read from `options`; any other keys are
            ignored.

    Returns:
        A two-tuple (`candidate`, `acqf_value`), where `candidate` is a 1d
        Tensor of the best predicted point and `acqf_value` is a scalar (0d)
        Tensor of the acquisition function value at the best point.

    Raises:
        NotImplementedError: If `torch_opt_config.fixed_features` is truthy.
    """
    if torch_opt_config.fixed_features:
        # Fixed features would need a `FixedFeatureAcquisitionFunction`, which
        # wraps another acquisition function and is instantiated differently;
        # handling it here has not been worked out yet.
        # TODO (ref: https://fburl.com/diff/uneqb3n9)
        raise NotImplementedError("Fixed features not yet supported.")

    parsed_options = {} if not options else options
    seed_inner = checked_cast_optional(int, parsed_options.get(Keys.SEED_INNER, None))
    use_qmc = checked_cast(bool, parsed_options.get(Keys.QMC, True))
    acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
        outcome_constraints=torch_opt_config.outcome_constraints,
        seed_inner=seed_inner,
        qmc=use_qmc,
        risk_measure=torch_opt_config.risk_measure,
    )

    # Imported here to avoid a circular import between `Surrogate` and
    # `Acquisition`.
    from ax.models.torch.botorch_modular.acquisition import Acquisition

    acquisition = Acquisition(
        surrogate=self,
        botorch_acqf_class=acqf_class,
        search_space_digest=search_space_digest,
        torch_opt_config=torch_opt_config,
        options=acqf_options,
    )
    inequality_constraints = _to_inequality_constraints(
        linear_constraints=torch_opt_config.linear_constraints
    )
    candidates, acqf_value, _ = acquisition.optimize(
        n=1,
        search_space_digest=search_space_digest,
        inequality_constraints=inequality_constraints,
        fixed_features=torch_opt_config.fixed_features,
    )
    best_candidate = candidates[0]
    return best_candidate, acqf_value
def pareto_frontier(self) -> tuple[Tensor, Tensor]:
    """Retrieve the Pareto frontier for multi-objective optimization.

    Intended as the multi-objective counterpart of a best point. It would
    return a two-tuple of a tensor of points in the feature space and a
    tensor of the corresponding (multiple) outcomes.

    This is not implemented yet, so calling it always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "Pareto frontier not yet implemented."
    raise NotImplementedError(message)
def compute_diagnostics(self) -> dict[str, Any]:
    """Compute model diagnostics, such as cross-validation measures of fit.

    This implementation computes no diagnostics and returns a new empty
    dict on every call.
    """
    diagnostics = dict()
    return diagnostics
def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
    """Collect the kwargs needed to re-instantiate this surrogate.

    Only ``surrogate_spec`` is serialized. It is meant to be passed back as a
    keyword argument on reinstantiation.
    """
    return dict(surrogate_spec=self.surrogate_spec)

@property
def outcomes(self) -> list[str]:
    """Names of the modeled outcomes.

    Raises:
        RuntimeError: If ``self._outcomes`` is still None (the error message
            asks the caller to run ``fit`` first).
    """
    names = self._outcomes
    if names is None:
        raise RuntimeError("outcomes not initialized. Please call `fit` first.")
    return names

@outcomes.setter
def outcomes(self, value: list[str]) -> None:
    """Reject direct assignment: this setter always raises ``RuntimeError``."""
    raise RuntimeError("Setting outcomes manually is disallowed.")
submodel_input_constructor = Dispatcher(
    name="submodel_input_constructor", encoder=_argparse_type_encoder
)


@submodel_input_constructor.register(Model)
def _submodel_input_constructor_base(
    botorch_model_class: type[Model],
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    surrogate: Surrogate,
) -> dict[str, Any]:
    """Build the keyword arguments used to instantiate a BoTorch model.

    This is the default constructor, registered for ``Model``.

    Args:
        botorch_model_class: The BoTorch model class to instantiate.
        model_config: The model config. Its ``model_options`` are forwarded
            to ``botorch_model_class.construct_inputs``.
        dataset: The training data for the model.
        search_space_digest: Search space digest used to set up model
            arguments.
        surrogate: A reference to the surrogate that created the model.
            Constructors may use it to obtain additional arguments that are
            not readily available. It is unused here.

    Returns:
        A dictionary of inputs for constructing the model.
    """
    search_space_kwargs = _extract_model_kwargs(
        search_space_digest=search_space_digest
    )
    model_inputs: dict[str, Any] = botorch_model_class.construct_inputs(
        training_data=dataset,
        **model_config.model_options,
        **search_space_kwargs,
    )
    # Entries produced by `_construct_submodules` override any same-named
    # entries returned by `construct_inputs`.
    model_inputs.update(
        _construct_submodules(
            model_config=model_config,
            dataset=dataset,
            # Needed when constructing the input transforms.
            search_space_digest=search_space_digest,
            # Used to check for supported arguments and by the covar module
            # input constructors.
            botorch_model_class=botorch_model_class,
        )
    )
    return model_inputs


@submodel_input_constructor.register(MultiTaskGP)
def _submodel_input_constructor_mtgp(
    botorch_model_class: type[Model],
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    surrogate: Surrogate,
) -> dict[str, Any]:
    """Build model inputs for ``MultiTaskGP``, pinning a single output task.

    Starts from the default constructor's inputs. When a ``task_feature`` is
    present and ``output_tasks`` is not set, ``output_tasks`` is set to the
    target value of the task feature taken from
    ``search_space_digest.target_values``.

    Args:
        botorch_model_class: The BoTorch model class to instantiate.
        model_config: The model config.
        dataset: The training data for the model; it must have exactly one
            outcome.
        search_space_digest: Search space digest; its ``target_values`` may
            supply the target task.
        surrogate: A reference to the surrogate that created the model.

    Returns:
        A dictionary of inputs for constructing the model.

    Raises:
        NotImplementedError: If ``dataset`` has more than one outcome.
        UserInputError: If a task feature is present but neither
            ``output_tasks`` nor a target value for that feature is available.
    """
    if len(dataset.outcome_names) > 1:
        raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
    model_inputs = _submodel_input_constructor_base(
        botorch_model_class=botorch_model_class,
        model_config=model_config,
        dataset=dataset,
        search_space_digest=search_space_digest,
        surrogate=surrogate,
    )
    task_feature = model_inputs.get("task_feature")
    if task_feature is None or model_inputs.get("output_tasks") is not None:
        return model_inputs

    # Restrict the output tasks so that `model.num_outputs == 1`, since this
    # model only models a single outcome.
    target_values = search_space_digest.target_values
    target_value = (
        None if target_values is None else target_values.get(task_feature)
    )
    if target_value is None:
        raise UserInputError(
            "output_tasks or target task value must be provided for MultiTaskGP."
        )
    model_inputs["output_tasks"] = [int(target_value)]
    return model_inputs