#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import inspect
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from logging import Logger
from typing import Any
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.input_constructors.covar_modules import (
covar_module_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.input_transforms import (
input_transform_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
convert_to_block_design,
fit_botorch_model,
ModelConfig,
subset_state_dict,
use_model_list,
)
from ax.models.torch.utils import (
_to_inequality_constraints,
pick_best_out_of_sample_point_acqf_class,
predict_from_model,
)
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import (
_argparse_type_encoder,
checked_cast,
checked_cast_optional,
)
from ax.utils.stats.model_fit_stats import (
DIAGNOSTIC_FN_DIRECTIONS,
DIAGNOSTIC_FNS,
ModelFitMetricDirection,
RANK_CORRELATION,
)
from botorch.exceptions.errors import ModelFittingError
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import (
ChainedInputTransform,
InputPerturbation,
InputTransform,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
from torch.nn import Module
# Error message raised (as a `ValueError`) by `Surrogate.training_data` when the
# surrogate's training data has not been set yet.
NOT_YET_FIT_MSG = (
"Underlying BoTorch `Model` has not yet received its training_data. "
"Please fit the model first."
)
# Module-level logger, obtained from `get_logger` using this module's `__name__`.
logger: Logger = get_logger(__name__)
def _extract_model_kwargs(
    search_space_digest: SearchSpaceDigest,
) -> dict[str, list[int] | int]:
    """Collect feature-index kwargs for a BoTorch ``Model.construct_inputs`` call.

    Args:
        search_space_digest: A ``SearchSpaceDigest`` whose ``fidelity_features``,
            ``task_features`` and ``categorical_features`` are inspected.

    Returns:
        A dict that contains, only for the non-empty feature groups,
        ``"categorical_features"`` and ``"fidelity_features"`` (the lists taken
        from the digest) and ``"task_feature"`` (the single task feature).

    Raises:
        NotImplementedError: If both fidelity and task features are present, or
            if more than one task feature is present.
    """
    fidelities = search_space_digest.fidelity_features
    tasks = search_space_digest.task_features
    has_fidelities = len(fidelities) > 0
    num_tasks = len(tasks)

    if has_fidelities and num_tasks > 0:
        raise NotImplementedError(
            "Multi-Fidelity GP models with task_features are "
            "currently not supported."
        )
    if num_tasks > 1:
        raise NotImplementedError("Multiple task features are not supported.")

    extracted: dict[str, list[int] | int] = {}
    categoricals = search_space_digest.categorical_features
    if len(categoricals) > 0:
        extracted["categorical_features"] = categoricals
    if has_fidelities:
        extracted["fidelity_features"] = fidelities
    if num_tasks == 1:
        # The model expects a single index here rather than a list.
        extracted["task_feature"] = tasks[0]
    return extracted
def _make_botorch_input_transform(
    input_transform_classes: list[type[InputTransform]],
    input_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> InputTransform | None:
    """Build a BoTorch input transform from transform classes and their options.

    Args:
        input_transform_classes: List of ``InputTransform`` subclasses.
        input_transform_options: Maps a transform class name to the options
            passed to ``input_transform_argparse`` for that class.
        dataset: Dataset forwarded to ``input_transform_argparse``.
        search_space_digest: Digest forwarded to ``input_transform_argparse``.
            If its ``robust_digest`` is set, ``InputPerturbation`` is put in
            front of the given classes.

    Returns:
        ``None`` if there is nothing to construct, the single transform if there
        is exactly one, else a ``ChainedInputTransform`` keyed ``tf0, tf1, ...``.

    Raises:
        UserInputError: If ``input_transform_classes`` is not a list of
            ``InputTransform`` subclasses.
    """
    is_transform_list = isinstance(input_transform_classes, list) and all(
        issubclass(cls, InputTransform) for cls in input_transform_classes
    )
    if not is_transform_list:
        raise UserInputError("Expected a list of input transforms.")

    classes = input_transform_classes
    if search_space_digest.robust_digest is not None:
        classes = [InputPerturbation, *classes]
    if len(classes) == 0:
        return None

    # Resolve the constructor kwargs of every transform before instantiating any.
    per_class_kwargs = []
    for cls in classes:
        # Deep-copy the options in case the argparse modifies them in place.
        options = deepcopy(input_transform_options.get(cls.__name__, {}))
        per_class_kwargs.append(
            input_transform_argparse(
                cls,
                dataset=dataset,
                search_space_digest=search_space_digest,
                input_transform_options=options,
            )
        )

    instances = [
        # pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
        cls(**ctor_kwargs)
        for cls, ctor_kwargs in zip(classes, per_class_kwargs)
    ]
    if len(instances) == 1:
        return instances[0]
    return ChainedInputTransform(
        **{f"tf{idx}": transform for idx, transform in enumerate(instances)}
    )
def _make_botorch_outcome_transform(
    outcome_transform_classes: list[type[OutcomeTransform]],
    outcome_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
) -> OutcomeTransform | None:
    """Build a BoTorch outcome transform from transform classes and their options.

    Args:
        outcome_transform_classes: List of ``OutcomeTransform`` subclasses.
        outcome_transform_options: Maps a transform class name to the options
            passed to ``outcome_transform_argparse`` for that class.
        dataset: Dataset forwarded to ``outcome_transform_argparse``.

    Returns:
        ``None`` for an empty class list, the single transform if there is
        exactly one, else a ``ChainedOutcomeTransform`` keyed ``otf0, otf1, ...``.

    Raises:
        UserInputError: If ``outcome_transform_classes`` is not a list of
            ``OutcomeTransform`` subclasses.
    """
    is_transform_list = isinstance(outcome_transform_classes, list) and all(
        issubclass(cls, OutcomeTransform) for cls in outcome_transform_classes
    )
    if not is_transform_list:
        raise UserInputError("Expected a list of outcome transforms.")
    if len(outcome_transform_classes) == 0:
        return None

    # Resolve the constructor kwargs of every transform before instantiating any.
    per_class_kwargs = []
    for cls in outcome_transform_classes:
        # Deep-copy the options in case the argparse modifies them in place.
        options = deepcopy(outcome_transform_options.get(cls.__name__, {}))
        per_class_kwargs.append(
            outcome_transform_argparse(
                cls,
                outcome_transform_options=options,
                dataset=dataset,
            )
        )

    instances = [
        # pyre-fixme[45]: Cannot instantiate abstract class `OutcomeTransform`.
        cls(**ctor_kwargs)
        for cls, ctor_kwargs in zip(outcome_transform_classes, per_class_kwargs)
    ]
    if len(instances) == 1:
        return instances[0]
    return ChainedOutcomeTransform(
        **{f"otf{idx}": transform for idx, transform in enumerate(instances)}
    )
def _construct_submodules(
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    botorch_model_class: type[Model],
) -> dict[str, Module | None]:
    """Constructs the submodules for the BoTorch model from the inputs
    extracted from the ``ModelConfig``. If the corresponding inputs are
    specified, the `covar_module`, `likelihood`, `input_transform`, and
    `outcome_transform` submodules are constructed.

    Args:
        model_config: Configuration providing the submodule classes and options.
        dataset: Dataset forwarded to the argparse helpers.
        search_space_digest: Search space digest; a non-``None``
            ``robust_digest`` forces construction of an input transform.
        botorch_model_class: Model class whose constructor signature is
            inspected to check that each requested submodule is accepted.

    Returns:
        A dict mapping constructor argument names to the constructed submodules.
        ``"outcome_transform"`` maps to ``None`` when the model class accepts
        that argument but no outcome transform classes were configured.

    Raises:
        UserInputError: If a submodule is requested that the model class
            constructor does not accept.
    """
    argspec = inspect.getfullargspec(botorch_model_class)
    # `args` alone omits keyword-only parameters. The model is instantiated with
    # keyword arguments (see `Surrogate._construct_model`), so a keyword-only
    # `covar_module`, `likelihood`, etc. is supported just as well.
    botorch_model_class_args: list[str] = [*argspec.args, *argspec.kwonlyargs]

    def _error_if_arg_not_supported(arg_name: str) -> None:
        if arg_name not in botorch_model_class_args:
            raise UserInputError(
                f"The BoTorch model class {botorch_model_class.__name__} does not "
                f"support the input {arg_name}."
            )

    submodules: dict[str, Module | None] = {}
    # NOTE: Using the walrus operator here and below helps pyre.
    if (covar_class := model_config.covar_module_class) is not None:
        _error_if_arg_not_supported("covar_module")
        covar_module_kwargs = covar_module_argparse(
            covar_class,
            dataset=dataset,
            botorch_model_class=botorch_model_class,
            **deepcopy(model_config.covar_module_options),
        )
        # pyre-ignore [45]: Cannot instantiate abstract class `Kernel`.
        submodules["covar_module"] = covar_class(**covar_module_kwargs)

    if (likelihood_class := model_config.likelihood_class) is not None:
        _error_if_arg_not_supported("likelihood")
        # pyre-ignore [45]: Cannot instantiate abstract class `Likelihood`.
        submodules["likelihood"] = likelihood_class(
            **deepcopy(model_config.likelihood_options)
        )

    if (
        input_transform_classes := model_config.input_transform_classes
    ) is not None or search_space_digest.robust_digest is not None:
        _error_if_arg_not_supported("input_transform")
        submodules["input_transform"] = _make_botorch_input_transform(
            input_transform_classes=input_transform_classes or [],
            input_transform_options=model_config.input_transform_options or {},
            dataset=dataset,
            search_space_digest=search_space_digest,
        )

    if (
        outcome_transform_classes := model_config.outcome_transform_classes
    ) is not None:
        _error_if_arg_not_supported("outcome_transform")
        submodules["outcome_transform"] = _make_botorch_outcome_transform(
            outcome_transform_classes=outcome_transform_classes,
            outcome_transform_options=model_config.outcome_transform_options or {},
            dataset=dataset,
        )
    elif "outcome_transform" in botorch_model_class_args:
        # This is a temporary solution until all BoTorch models use
        # `Standardize` by default, see TODO [T197435440].
        # After this, we should update `Surrogate` to use `DEFAULT`
        # (https://fburl.com/code/22f4397e) for both of these args. This will
        # allow users to explicitly disable the default transforms by passing
        # in `None`.
        submodules["outcome_transform"] = None
    return submodules
def _raise_deprecation_warning(
    is_surrogate: bool = False,
    **kwargs: Any,
) -> bool:
    """Emit a ``DeprecationWarning`` for every deprecated argument that was set.

    An argument counts as set when it differs from its default: ``{}`` for
    ``botorch_model_kwargs`` and ``mll_kwargs``, ``ExactMarginalLogLikelihood``
    for ``mll_class``, and ``None`` for everything else.

    Args:
        is_surrogate: Whether the call originates from ``Surrogate``; this only
            changes the location recommended in the warning message.
        **kwargs: The deprecated arguments, keyed by argument name.

    Returns:
        ``True`` if at least one warning was emitted, ``False`` otherwise.
    """
    replacement = "surrogate_spec.model_configs" if is_surrogate else "model_configs"
    template = (
        "{k} is deprecated and will be removed in a future version. "
        + f"Please specify {{k}} via `{replacement}`."
    )
    dict_defaulted = {"botorch_model_kwargs", "mll_kwargs"}

    any_warning = False
    for name, value in kwargs.items():
        if name in dict_defaulted:
            was_specified = value != {}
        elif name == "mll_class":
            was_specified = value is not ExactMarginalLogLikelihood
        else:
            was_specified = value is not None
        if not was_specified:
            continue
        warnings.warn(
            template.format(k=name),
            DeprecationWarning,
            stacklevel=3,
        )
        any_warning = True
    return any_warning
[docs]
def get_model_config_from_deprecated_args(
    botorch_model_class: type[Model] | None,
    model_options: dict[str, Any] | None,
    mll_class: type[MarginalLogLikelihood] | None,
    mll_options: dict[str, Any] | None,
    outcome_transform_classes: list[type[OutcomeTransform]] | None,
    outcome_transform_options: dict[str, dict[str, Any]] | None,
    input_transform_classes: list[type[InputTransform]] | None,
    input_transform_options: dict[str, dict[str, Any]] | None,
    covar_module_class: type[Kernel] | None,
    covar_module_options: dict[str, Any] | None,
    likelihood_class: type[Likelihood] | None,
    likelihood_options: dict[str, Any] | None,
) -> ModelConfig:
    """Construct a ``ModelConfig`` from deprecated arguments.

    ``model_options`` and ``mll_options`` are always forwarded as shallow copies
    (an empty dict when ``None``). Every other argument is forwarded only when
    it is not ``None``, so ``ModelConfig`` falls back to its own default for it.

    Returns:
        A ``ModelConfig`` named ``"from deprecated args"``.
    """
    # Always present: copied so the config does not alias the caller's dicts.
    config_kwargs: dict[str, Any] = {
        "model_options": dict(model_options) if model_options else {},
        "mll_options": dict(mll_options) if mll_options else {},
    }
    optional_kwargs: dict[str, Any] = {
        "botorch_model_class": botorch_model_class,
        "mll_class": mll_class,
        "outcome_transform_classes": outcome_transform_classes,
        "outcome_transform_options": outcome_transform_options,
        "input_transform_classes": input_transform_classes,
        "input_transform_options": input_transform_options,
        "covar_module_class": covar_module_class,
        "covar_module_options": covar_module_options,
        "likelihood_class": likelihood_class,
        "likelihood_options": likelihood_options,
    }
    for key, value in optional_kwargs.items():
        if value is not None:
            config_kwargs[key] = value
    return ModelConfig(**config_kwargs, name="from deprecated args")
[docs]
@dataclass(frozen=True)
class SurrogateSpec:
    """
    Specification of the model(s) a ``Surrogate`` should fit.

    The deprecated fields below mirror the deprecated arguments of
    ``Surrogate.__init__`` (some under ``*_kwargs`` rather than ``*_options``
    names). ``outcomes`` is used to specify which outcomes the Surrogate is
    responsible for modeling.

    When ``BotorchModel.fit`` is called, these fields will be used to construct the
    requisite Surrogate objects.

    If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.

    If ``model_configs`` is empty, a single ``ModelConfig`` is built from the
    deprecated fields and the spec is re-initialized with only the
    non-deprecated fields. Specifying both ``model_configs`` and any deprecated
    field raises a ``UserInputError``.

    Args:
        botorch_model_class: ``Model`` class to be used as the underlying
            BoTorch model. If None is provided, a model class (either one for
            all outcomes or a ModelList with separate models for each outcome)
            will be selected automatically based on the datasets at `construct`
            time. This argument is deprecated in favor of model_configs.
        botorch_model_kwargs: Dictionary of options / kwargs for the BoTorch
            ``Model`` constructed during ``Surrogate.fit``.
            Note that the corresponding attribute will later be updated to include
            any additional kwargs passed into ``BoTorchModel.fit``.
            This argument is deprecated in favor of model_configs.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
            This argument is deprecated in favor of model_configs.
        mll_kwargs: Dictionary of options / kwargs for the MLL. This argument is
            deprecated in favor of model_configs.
        covar_module_class: Covariance module class. This gets initialized after
            parsing the ``covar_module_kwargs`` in ``covar_module_argparse``,
            and gets passed to the model constructor as ``covar_module``.
            This argument is deprecated in favor of model_configs.
        covar_module_kwargs: Covariance module kwargs. This argument is deprecated
            in favor of model_configs.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_kwargs`` and gets passed to the model constructor.
            This argument is deprecated in favor of model_configs.
        likelihood_kwargs: Likelihood options. This argument is deprecated in
            favor of model_configs.
        input_transform_classes: List of BoTorch input transforms classes.
            Passed down to the BoTorch ``Model``. Multiple input transforms
            will be chained together using ``ChainedInputTransform``.
            This argument is deprecated in favor of model_configs.
        input_transform_options: Input transform classes kwargs. The keys are
            class string names and the values are dictionaries of input transform
            kwargs. For example,
            `
            input_transform_classes = [Normalize, Round]
            input_transform_options = {
                "Normalize": {"d": 3},
                "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
            }
            `
            For more input options see `botorch/models/transforms/input.py`.
            This argument is deprecated in favor of model_configs.
        outcome_transform_classes: List of BoTorch outcome transforms classes.
            Passed down to the BoTorch ``Model``. Multiple outcome transforms can
            be chained together using ``ChainedOutcomeTransform``. This argument
            is deprecated in favor of model_configs.
        outcome_transform_options: Outcome transform classes kwargs. The keys are
            class string names and the values are dictionaries of outcome transform
            kwargs. For example,
            `
            outcome_transform_classes = [Standardize]
            outcome_transform_options = {
                "Standardize": {"m": 1},
            }
            `
            For more options see `botorch/models/transforms/outcome.py`. This
            argument is deprecated in favor of model_configs.
        allow_batched_models: Set to true to fit the models in a batch if
            supported. Set to false to fit individual models to each metric in a
            loop.
        model_configs: List of model configs. Each model config is a specification
            of a model. These should be used in favor of the above deprecated
            arguments.
        metric_to_model_configs: Dictionary mapping metric names to a list of model
            configs for that metric.
        eval_criterion: The name of the evaluation criteria to use. These are
            defined in ``ax.utils.stats.model_fit_stats``. Defaults to rank
            correlation.
        outcomes: List of outcomes names.
        use_posterior_predictive: Whether to use posterior predictive in
            cross-validation.
    """

    botorch_model_class: type[Model] | None = None
    botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)
    mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
    mll_kwargs: dict[str, Any] = field(default_factory=dict)
    covar_module_class: type[Kernel] | None = None
    covar_module_kwargs: dict[str, Any] | None = None
    likelihood_class: type[Likelihood] | None = None
    likelihood_kwargs: dict[str, Any] | None = None
    input_transform_classes: list[type[InputTransform]] | None = None
    input_transform_options: dict[str, dict[str, Any]] | None = None
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None
    outcome_transform_options: dict[str, dict[str, Any]] | None = None
    allow_batched_models: bool = True
    model_configs: list[ModelConfig] = field(default_factory=list)
    metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
    eval_criterion: str = RANK_CORRELATION
    outcomes: list[str] = field(default_factory=list)
    use_posterior_predictive: bool = False

    def __post_init__(self) -> None:
        # Warn under the names of the fields the user actually set on this
        # dataclass (`covar_module_kwargs` / `likelihood_kwargs`), not under the
        # differently named `ModelConfig` / `Surrogate` arguments.
        warnings_raised = _raise_deprecation_warning(
            is_surrogate=False,
            botorch_model_class=self.botorch_model_class,
            botorch_model_kwargs=self.botorch_model_kwargs,
            mll_class=self.mll_class,
            mll_kwargs=self.mll_kwargs,
            outcome_transform_classes=self.outcome_transform_classes,
            outcome_transform_options=self.outcome_transform_options,
            input_transform_classes=self.input_transform_classes,
            input_transform_options=self.input_transform_options,
            covar_module_class=self.covar_module_class,
            covar_module_kwargs=self.covar_module_kwargs,
            likelihood_class=self.likelihood_class,
            likelihood_kwargs=self.likelihood_kwargs,
        )
        if len(self.model_configs) == 0:
            model_config = get_model_config_from_deprecated_args(
                botorch_model_class=self.botorch_model_class,
                model_options=self.botorch_model_kwargs,
                mll_class=self.mll_class,
                mll_options=self.mll_kwargs,
                outcome_transform_classes=self.outcome_transform_classes,
                outcome_transform_options=self.outcome_transform_options,
                input_transform_classes=self.input_transform_classes,
                input_transform_options=self.input_transform_options,
                covar_module_class=self.covar_module_class,
                covar_module_options=self.covar_module_kwargs,
                likelihood_class=self.likelihood_class,
                likelihood_options=self.likelihood_kwargs,
            )
            # re-initialize with the non-deprecated arguments.
            # This is possible on a frozen dataclass because the generated
            # `__init__` assigns fields via `object.__setattr__`. It runs
            # `__post_init__` a second time, which then finds `model_configs`
            # non-empty and all deprecated fields back at their defaults.
            self.__init__(
                allow_batched_models=self.allow_batched_models,
                model_configs=[model_config],
                metric_to_model_configs=self.metric_to_model_configs,
                eval_criterion=self.eval_criterion,
                outcomes=self.outcomes,
                use_posterior_predictive=self.use_posterior_predictive,
            )
        elif warnings_raised:
            raise UserInputError(
                "model_configs and deprecated arguments were both specified. "
                "Please use model_configs and remove deprecated arguments."
            )
[docs]
class Surrogate(Base):
"""
**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
versions only.**
Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchModel``
and is not meant to be used outside of it.
Args:
botorch_model_class: ``Model`` class to be used as the underlying
BoTorch model. If None is provided, a model class (either one for all
outcomes or a ModelList with separate models for each outcome) will be
selected automatically based on the datasets at `construct` time.
This argument is deprecated in favor of model_configs.
model_options: Dictionary of options / kwargs for the BoTorch
``Model`` constructed during ``Surrogate.fit``.
Note that the corresponding attribute will later be updated to include any
additional kwargs passed into ``BoTorchModel.fit``.
This argument is deprecated in favor of model_configs.
mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
This argument is deprecated in favor of model_configs.
mll_options: Dictionary of options / kwargs for the MLL. This argument is
deprecated in favor of model_configs.
outcome_transform_classes: List of BoTorch outcome transforms classes. Passed
down to the BoTorch ``Model``. Multiple outcome transforms can be chained
together using ``ChainedOutcomeTransform``. This argument is deprecated in
favor of model_configs.
outcome_transform_options: Outcome transform classes kwargs. The keys are
class string names and the values are dictionaries of outcome transform
kwargs. For example,
`
outcome_transform_classes = [Standardize]
outcome_transform_options = {
"Standardize": {"m": 1},
}
`
For more options see `botorch/models/transforms/outcome.py`. This argument
is deprecated in favor of model_configs.
input_transform_classes: List of BoTorch input transforms classes.
Passed down to the BoTorch ``Model``. Multiple input transforms
will be chained together using ``ChainedInputTransform``.
This argument is deprecated in favor of model_configs.
input_transform_options: Input transform classes kwargs. The keys are
class string names and the values are dictionaries of input transform
kwargs. For example,
`
input_transform_classes = [Normalize, Round]
input_transform_options = {
"Normalize": {"d": 3},
"Round": {"integer_indices": [0], "categorical_features": {1: 2}},
}
`
For more input options see `botorch/models/transforms/input.py`.
This argument is deprecated in favor of model_configs.
covar_module_class: Covariance module class. This gets initialized after
parsing the ``covar_module_options`` in ``covar_module_argparse``,
and gets passed to the model constructor as ``covar_module``.
This argument is deprecated in favor of model_configs.
covar_module_options: Covariance module kwargs. This argument is deprecated
in favor of model_configs.
likelihood_class: ``Likelihood`` class. This gets initialized with
``likelihood_options`` and gets passed to the model constructor.
This argument is deprecated in favor of model_configs.
likelihood_options: Likelihood options. This argument is deprecated in favor
of model_configs.
allow_batched_models: Set to true to fit the models in a batch if supported.
Set to false to fit individual models to each metric in a loop.
refit_on_cv: Whether to refit the model on the cross-validation folds.
metric_to_best_model_config: Dictionary mapping a tuple of metric names
to the best model config. This is only used by BotorchModel.cross_validate
and for logging what model was used.
"""
def __init__(
    self,
    surrogate_spec: SurrogateSpec | None = None,
    botorch_model_class: type[Model] | None = None,
    model_options: dict[str, Any] | None = None,
    mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
    mll_options: dict[str, Any] | None = None,
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None,
    outcome_transform_options: dict[str, dict[str, Any]] | None = None,
    input_transform_classes: list[type[InputTransform]] | None = None,
    input_transform_options: dict[str, dict[str, Any]] | None = None,
    covar_module_class: type[Kernel] | None = None,
    covar_module_options: dict[str, Any] | None = None,
    likelihood_class: type[Likelihood] | None = None,
    likelihood_options: dict[str, Any] | None = None,
    allow_batched_models: bool = True,
    refit_on_cv: bool = False,
    metric_to_best_model_config: dict[tuple[str], ModelConfig] | None = None,
) -> None:
    """Set up the surrogate from a ``SurrogateSpec`` or from deprecated arguments.

    If ``surrogate_spec`` is ``None``, a spec holding one ``ModelConfig`` is built
    from the deprecated arguments and ``allow_batched_models``. If a spec is
    given together with any deprecated argument, a ``UserInputError`` is raised.

    Raises:
        UserInputError: If ``surrogate_spec`` and deprecated arguments are both
            specified.
    """
    # Collected once and forwarded to both helpers below. The insertion order
    # determines the order in which deprecation warnings are emitted.
    deprecated_args: dict[str, Any] = {
        "botorch_model_class": botorch_model_class,
        "model_options": model_options,
        "mll_class": mll_class,
        "mll_options": mll_options,
        "outcome_transform_classes": outcome_transform_classes,
        "outcome_transform_options": outcome_transform_options,
        "input_transform_classes": input_transform_classes,
        "input_transform_options": input_transform_options,
        "covar_module_class": covar_module_class,
        "covar_module_options": covar_module_options,
        "likelihood_class": likelihood_class,
        "likelihood_options": likelihood_options,
    }
    warnings_raised = _raise_deprecation_warning(
        is_surrogate=True, **deprecated_args
    )

    if surrogate_spec is not None:
        if warnings_raised:
            raise UserInputError(
                "model_configs and deprecated arguments were both specified. "
                "Please use model_configs and remove deprecated arguments."
            )
    else:
        # No spec given: create one from the deprecated arguments.
        model_config = get_model_config_from_deprecated_args(**deprecated_args)
        surrogate_spec = SurrogateSpec(
            model_configs=[model_config], allow_batched_models=allow_batched_models
        )
    self.surrogate_spec: SurrogateSpec = surrogate_spec

    # Last dataset used to fit the model for a given metric(s), keyed by
    # `tuple(dataset.outcome_names)`. If a new dataset is identical, model
    # fitting is skipped for that metric.
    self._last_datasets: dict[tuple[str], SupervisedDataset] = {}
    # BoTorch Model fitted for each tuple of metric names. In most cases the key
    # is a one-tuple, though n-tuples are needed for LCE-M models. Used to skip
    # model construction & fitting if the datasets are identical.
    self._submodels: dict[tuple[str], Model] = {}
    self.metric_to_best_model_config: dict[tuple[str], ModelConfig] = (
        metric_to_best_model_config or {}
    )
    # Search space digest used while fitting the cached models. The models are
    # re-fit if the search space digest changes.
    self._last_search_space_digest: SearchSpaceDigest | None = None

    # These are later updated during model fitting.
    self._training_data: list[SupervisedDataset] | None = None
    self._outcomes: list[str] | None = None
    self._model: Model | None = None
    self.refit_on_cv = refit_on_cv
def __repr__(self) -> str:
    """Return ``<ClassName surrogate_spec=...>`` for this surrogate."""
    class_name = self.__class__.__name__
    return f"<{class_name} surrogate_spec={self.surrogate_spec}>"
@property
def model(self) -> Model:
    """The constructed BoTorch ``Model``.

    Raises:
        ValueError: If no model has been constructed yet.
    """
    constructed = self._model
    if constructed is not None:
        return constructed
    raise ValueError(
        "BoTorch `Model` has not yet been constructed, please fit the "
        "surrogate first (done via `BoTorchModel.fit`)."
    )
@property
def training_data(self) -> list[SupervisedDataset]:
    """The datasets stored on this surrogate as training data.

    Raises:
        ValueError: With ``NOT_YET_FIT_MSG`` if no training data has been set.
    """
    stored = self._training_data
    if stored is not None:
        return stored
    raise ValueError(NOT_YET_FIT_MSG)
@property
def Xs(self) -> list[Tensor]:
    """Feature tensors of the training data, one list entry per outcome.

    For each dataset, its ``X`` is repeated ``dataset.Y.shape[-1]`` times.
    """
    # Handles multi-output models. TODO: Improve this!
    per_outcome_X = []
    for dataset in self.training_data:
        if isinstance(dataset, RankingDataset):
            # directly accessing the d-dim X tensor values
            # instead of the augmented 2*d-dim dataset.X from RankingDataset
            features = checked_cast(SliceContainer, dataset._X).values
        else:
            features = dataset.X
        num_outcomes = dataset.Y.shape[-1]
        # The same tensor object is referenced for every outcome (no copies).
        per_outcome_X.extend([features] * num_outcomes)
    return per_outcome_X
@property
def dtype(self) -> torch.dtype:
    """Dtype of ``X`` in the first training dataset."""
    first_dataset = self.training_data[0]
    return first_dataset.X.dtype
@property
def device(self) -> torch.device:
    """Device of ``X`` in the first training dataset."""
    first_dataset = self.training_data[0]
    return first_dataset.X.device
[docs]
def clone_reset(self) -> Surrogate:
    """Return a new instance of this class built from the kwargs produced by
    ``_serialize_attributes_as_kwargs``.
    """
    init_kwargs = self._serialize_attributes_as_kwargs()
    return self.__class__(**init_kwargs)
def _construct_model(
    self,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    model_config: ModelConfig,
    default_botorch_model_class: type[Model],
    state_dict: OrderedDict[str, Tensor] | None,
    refit: bool,
) -> Model:
    """Construct (and possibly fit) the BoTorch ``Model`` for one dataset.

    If ``_should_reuse_last_model`` accepts the dataset, the cached sub-model
    for its outcome names is returned as is, without construction or fitting.

    Args:
        dataset: Training data for the model (for one outcome for
            the default `Surrogate`, with the exception of batched
            multi-output case, where training data is formatted with just
            one X and concatenated Ys).
        search_space_digest: Search space digest used to set up model arguments.
        model_config: The model config providing the model class (optional),
            the MLL class and the MLL options.
        default_botorch_model_class: The ``Model`` class to use if the
            model_config does not specify one.
        state_dict: Optional state dict to load. This should be subsetted for
            the current submodel being constructed.
        refit: Whether to re-optimize model parameters after loading
            ``state_dict``. Without a state dict the model is always fit.

    Returns:
        The cached or newly constructed model.
    """
    if self._should_reuse_last_model(dataset=dataset):
        return self._submodels[tuple(dataset.outcome_names)]

    model_cls = model_config.botorch_model_class or default_botorch_model_class
    constructor_kwargs = submodel_input_constructor(
        model_cls,  # Do not pass as kwarg since this is used to dispatch.
        model_config=model_config,
        dataset=dataset,
        search_space_digest=search_space_digest,
        surrogate=self,
    )
    # pyre-ignore [45]
    model = model_cls(**constructor_kwargs)

    warm_started = state_dict is not None
    if warm_started:
        model.load_state_dict(state_dict)
    if refit or not warm_started:
        fit_botorch_model(
            model=model,
            mll_class=model_config.mll_class,
            mll_options=model_config.mll_options,
        )
    return model
def _should_reuse_last_model(
    self,
    dataset: SupervisedDataset,
) -> bool:
    """Check whether a cached sub-model exists for the dataset's outcome names
    and the dataset equals the one last stored for those outcomes.

    NOTE: Only the dataset is compared here; the model class is not.
    """
    key = tuple(dataset.outcome_names)
    if key not in self._submodels:
        return False
    return dataset == self._last_datasets[key]
[docs]
def fit(
    self,
    datasets: Sequence[SupervisedDataset],
    search_space_digest: SearchSpaceDigest,
    candidate_metadata: list[list[TCandidateMetadata]] | None = None,
    state_dict: OrderedDict[str, Tensor] | None = None,
    refit: bool = True,
) -> None:
    """Fits the underlying BoTorch ``Model`` to ``m`` outcomes.

    NOTE: ``state_dict`` and ``refit`` keyword arguments control how the
    underlying BoTorch ``Model`` will be fit: whether its parameters will
    be reoptimized and whether it will be warm-started from a given state.

    There are three possibilities:

    * ``fit(state_dict=None)``: fit model from scratch (optimize model
      parameters and set its training data used for inference),
    * ``fit(state_dict=some_state_dict, refit=True)``: warm-start refit
      with a state dict of parameters (still re-optimize model parameters
      and set the training data),
    * ``fit(state_dict=some_state_dict, refit=False)``: load model parameters
      without refitting, but set new training data (used in cross-validation,
      for example).

    Args:
        datasets: A list of ``SupervisedDataset`` containers, each
            corresponding to the data of one metric (outcome), to be passed
            to ``Model.construct_inputs`` in BoTorch.
        search_space_digest: A ``SearchSpaceDigest`` object containing
            metadata on the features in the datasets.
        candidate_metadata: Model-produced metadata for candidates, in
            the order corresponding to the Xs.
        state_dict: Optional state dict to load.
        refit: Whether to re-optimize model parameters.

    Raises:
        UnsupportedError: If model selection over several model configs is
            needed for a dataset that contains more than one outcome.
    """
    self._discard_cached_model_and_data_if_search_space_digest_changed(
        search_space_digest=search_space_digest
    )
    spec = self.surrogate_spec

    # To determine whether to use ModelList under the hood, we need to check for
    # the batched multi-output case, so we first see which model would be chosen
    # given the Yvars and the properties of data.
    default_model_cls = choose_model_class(
        datasets=datasets, search_space_digest=search_space_digest
    )
    as_model_list = use_model_list(
        datasets=datasets,
        botorch_model_class=default_model_cls,
        model_configs=spec.model_configs,
        allow_batched_models=spec.allow_batched_models,
        metric_to_model_configs=spec.metric_to_model_configs,
    )
    if len(datasets) > 1 and not as_model_list:
        datasets = convert_to_block_design(datasets=datasets, force=True)
    self._training_data = list(datasets)  # So that it can be modified if needed.

    fitted_models = []
    fitted_outcomes = []
    for index, dataset in enumerate(datasets):
        # Pick the part of the state dict that belongs to this sub-model.
        if state_dict is None:
            sub_state_dict = None
        elif as_model_list:
            sub_state_dict = subset_state_dict(
                state_dict=state_dict, submodel_index=index
            )
        else:
            sub_state_dict = state_dict

        # Metric-specific configs (looked up by the first outcome name) take
        # precedence over the spec-wide ones.
        candidate_configs = spec.metric_to_model_configs.get(
            dataset.outcome_names[0], spec.model_configs
        )
        key = tuple(dataset.outcome_names)
        previous_best = self.metric_to_best_model_config.get(key)

        if len(candidate_configs) == 1 or (
            previous_best is not None and not refit
        ):
            # Case 1: There is either 1 model config, or we don't want to refit
            # and we know what the previous best model was
            chosen_config = previous_best or candidate_configs[0]
            model = self._construct_model(
                dataset=dataset,
                search_space_digest=search_space_digest,
                model_config=chosen_config,
                default_botorch_model_class=default_model_cls,
                state_dict=sub_state_dict,
                refit=refit,
            )
        else:
            # Case 2: There is more than 1 model config and we want to refit
            # or don't know what the previous best model was
            if len(dataset.outcome_names) > 1:
                raise UnsupportedError(
                    "Multiple model configs are not supported with datasets that"
                    " contain multiple outcomes. Each dataset must contain only "
                    "one outcome."
                )
            model, chosen_config = self.model_selection(
                dataset=dataset,
                model_configs=candidate_configs,
                default_botorch_model_class=default_model_cls,
                search_space_digest=search_space_digest,
                candidate_metadata=candidate_metadata,
            )

        fitted_models.append(model)
        fitted_outcomes.extend(dataset.outcome_names)
        # store best model config, model, and dataset
        self.metric_to_best_model_config[key] = none_throws(chosen_config)
        self._submodels[key] = model
        self._last_datasets[key] = dataset

    self._model = ModelListGP(*fitted_models) if as_model_list else fitted_models[0]
    self._outcomes = fitted_outcomes  # In the order of input datasets
[docs]
def model_selection(
    self,
    dataset: SupervisedDataset,
    model_configs: list[ModelConfig],
    default_botorch_model_class: type[Model],
    search_space_digest: SearchSpaceDigest,
    candidate_metadata: list[list[TCandidateMetadata]] | None = None,
) -> tuple[Model, ModelConfig]:
    """Perform model selection over a list of model configs.

    For each model config, a model is fit to the full dataset and then scored
    with ``self.cross_validate`` (leave-one-out) using the SurrogateSpec's
    ``eval_criterion``. The config with the best score is returned together
    with the model that was fit to all of the data. This method does not
    store the result itself; the caller is responsible for recording the
    best config.

    Configs that raise ``ModelFittingError`` are skipped. Configs whose
    evaluation metric is NaN are never preferred over a config with a
    comparable metric; a NaN-scored model is only returned (with a warning)
    when no config produced a comparable metric.

    Args:
        dataset: Training data for the model
        model_configs: The model_configs.
        default_botorch_model_class: The default ``Model`` class to be used as
            the default, if no botorch_model_class is specified in the
            model_config.
        search_space_digest: Search space digest.
        candidate_metadata: Model-produced metadata for candidates. Not used
            by this method.

    Returns:
        A two element tuple containing:
            - The best model according to the eval_criterion.
            - The ModelConfig for the best model.

    Raises:
        UnsupportedError: If ``dataset`` is a ``MultiTaskDataset`` with
            heterogeneous features.
        AxError: If none of the model configs could be fit.
    """
    if (
        isinstance(dataset, MultiTaskDataset)
        and assert_is_instance(dataset, MultiTaskDataset).has_heterogeneous_features
    ):
        raise UnsupportedError(
            "Model selection is not supported for datasets with heterogeneous "
            "features."
        )
    # loop over model configs, fit model for each config, perform LOOCV, select
    # best model according to specified criterion
    maximize = (
        DIAGNOSTIC_FN_DIRECTIONS[self.surrogate_spec.eval_criterion]
        == ModelFitMetricDirection.MAXIMIZE
    )
    best_eval_metric = float("-inf") if maximize else float("inf")
    best_model: Model | None = None
    best_model_config: ModelConfig | None = None
    # First successfully fit (model, config) whose metric was NaN; only used
    # if no config yields a comparable metric.
    nan_metric_fallback: tuple[Model, ModelConfig] | None = None
    for model_config in model_configs:
        # fit model to all data
        try:
            model = self._construct_model(
                dataset=dataset,
                search_space_digest=search_space_digest,
                model_config=model_config,
                default_botorch_model_class=none_throws(
                    default_botorch_model_class
                ),
                state_dict=None,
                refit=True,
            )
            state_dict = model.state_dict()
            # perform LOOCV
            eval_metric = self.cross_validate(
                dataset=dataset,
                search_space_digest=search_space_digest,
                model_config=model_config,
                default_botorch_model_class=none_throws(
                    default_botorch_model_class
                ),
                # pyre-fixme [6]: In call `Surrogate.cross_validate`, for argument
                # `state_dict`, expected `Optional[OrderedDict[str, Tensor]]` but
                # got `Dict[str, typing.Any]`.
                state_dict=state_dict,
            )
        except ModelFittingError as e:
            logger.info(
                f"Model {model_config} failed to fit with error {e}. Skipping."
            )
            continue
        if np.isnan(eval_metric):
            # Every comparison against NaN is False, so NaN must be handled
            # explicitly: otherwise it would count as an improvement when
            # maximizing and poison all subsequent comparisons.
            logger.info(
                f"Model {model_config} produced a NaN evaluation metric. "
                "Deprioritizing."
            )
            if nan_metric_fallback is None:
                nan_metric_fallback = (model, model_config)
            continue
        # Ties go to the later config when maximizing and to the earlier
        # config when minimizing (same as the previous XOR-based check).
        if maximize:
            is_improvement = eval_metric >= best_eval_metric
        else:
            is_improvement = eval_metric < best_eval_metric
        if is_improvement:
            best_eval_metric = eval_metric
            best_model = model
            best_model_config = model_config
    if best_model is not None and best_model_config is not None:
        return best_model, best_model_config
    if nan_metric_fallback is not None:
        logger.warning(
            "No model config produced a comparable evaluation metric. Falling "
            "back to the first model that was fit successfully."
        )
        return nan_metric_fallback
    raise AxError(
        "No model configs were able to fit the data. Please check your "
        "model configs and/or data."
    )
[docs]
def cross_validate(
    self,
    dataset: SupervisedDataset,
    model_config: ModelConfig,
    default_botorch_model_class: type[Model],
    search_space_digest: SearchSpaceDigest,
    state_dict: OrderedDict[str, Tensor] | None = None,
) -> float:
    """Leave-one-out cross-validation for a single outcome.

    For every row of the (target) training inputs, a model is constructed on
    the remaining rows, the held-out row is predicted, and the collected
    predictions are scored with the SurrogateSpec's ``eval_criterion``.

    Args:
        dataset: Training data for the model (for one outcome for
            the default `Surrogate`, with the exception of batched
            multi-output case, where training data is formatted with just
            one X and concatenated Ys).
        model_config: The model_config.
        default_botorch_model_class: The default ``Model`` class to be used as the
            underlying BoTorch model, if the model_config does not specify one.
        search_space_digest: Search space digest used to set up model arguments.
        state_dict: Optional state dict to load.

    Returns:
        The eval criterion value for the given model config.
    """
    if isinstance(dataset, MultiTaskDataset):
        # only evaluate model on target task
        target_dataset = dataset.datasets[dataset.target_outcome_name]
        X, Y = target_dataset.X, target_dataset.Y
    else:
        X, Y = dataset.X, dataset.Y
    num_points = X.shape[-2]
    train_mask = torch.ones(X.shape[0], dtype=torch.bool, device=X.device)
    pred_Y = np.zeros(num_points)
    pred_Yvar = np.zeros(num_points)
    # TODO: add hyperparameter to set the number of folds
    for i in range(num_points):
        # Temporarily exclude point i from the training data.
        train_mask[i] = 0
        # split data into train and test
        train_dataset = dataset.clone(mask=train_mask)
        # Note: for MT models, `output_tasks` is set on the model, so we
        # don't need to add the task feature here if missing
        test_X = X[i : i + 1]
        # fit model to all but one data point
        # TODO: consider batchifying
        loo_model = self._construct_model(
            dataset=train_dataset,
            search_space_digest=search_space_digest,
            model_config=model_config,
            default_botorch_model_class=none_throws(default_botorch_model_class),
            # pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type
            # but it is actually an OrderedDict[str, Tensor].
            state_dict=state_dict,
            refit=self.refit_on_cv,
        )
        # evaluate model
        with torch.no_grad():
            posterior = loo_model.posterior(
                test_X,
                observation_noise=self.surrogate_spec.use_posterior_predictive,
            )
        # TODO: support non-GPyTorch posteriors
        posterior = assert_is_instance(posterior, GPyTorchPosterior)
        # `.item()` copies the single predicted value to the host as a Python
        # float. This works for tensors on any device and avoids assigning a
        # size-1 array into a scalar slot (deprecated in NumPy >= 1.25). It
        # still raises if the posterior holds more than one element.
        pred_Y[i] = posterior.mean.item()
        pred_Yvar[i] = posterior.variance.item()
        # Restore the mask for the next fold.
        train_mask[i] = 1
    # evaluate model fit metric
    diag_fn = DIAGNOSTIC_FNS[none_throws(self.surrogate_spec.eval_criterion)]
    # pyre-ignore [28]: Unexpected keyword argument `y_obs` to anonymous call.
    return diag_fn(
        # Detach and move to CPU so the conversion works for tensors that
        # live on an accelerator or track gradients.
        y_obs=Y.detach().view(-1).cpu().numpy(),
        y_pred=pred_Y,
        # `se_pred` is the predictive standard error, i.e. the square root of
        # the posterior variance collected above.
        se_pred=np.sqrt(pred_Yvar),
    )
def _discard_cached_model_and_data_if_search_space_digest_changed(
self, search_space_digest: SearchSpaceDigest
) -> None:
"""Checks whether the search space digest has changed since the last call
to `fit`. If it has, discards cached model and datasets. Also updates
`self._last_search_space_digest` for future checks.
"""
if (
self._last_search_space_digest is not None
and search_space_digest != self._last_search_space_digest
):
logger.info(
"Discarding all previously trained models due to a change "
"in the search space digest."
)
self._submodels = {}
self._last_datasets = {}
self.metric_to_best_model_config = {}
self._last_search_space_digest = search_space_digest
[docs]
def predict(
    self, X: Tensor, use_posterior_predictive: bool = False
) -> tuple[Tensor, Tensor]:
    """Predict outcomes at the given inputs using the fitted model.

    Delegates to ``predict_from_model`` with ``self.model``.

    Args:
        X: A ``n x d`` tensor of input parameters.
        use_posterior_predictive: Whether the predictions should come from
            the posterior predictive (i.e. include observation noise).

    Returns:
        Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
        Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor.
    """
    mean_and_cov = predict_from_model(
        model=self.model,
        X=X,
        use_posterior_predictive=use_posterior_predictive,
    )
    return mean_and_cov
[docs]
def best_in_sample_point(
    self,
    search_space_digest: SearchSpaceDigest,
    torch_opt_config: TorchOptConfig,
    options: TConfig | None = None,
) -> tuple[Tensor, float]:
    """Find the best observed point and its observed outcome value.

    Args:
        search_space_digest: Search space digest; its ``bounds`` are forwarded
            to the module-level ``best_in_sample_point`` helper.
        torch_opt_config: Optimization config supplying the objective weights,
            constraints, fixed features and risk measure.
        options: Optional options forwarded to the helper.

    Returns:
        A two-tuple of the best point (cast to ``self.dtype`` and moved to the
        CPU) and the corresponding observed value.

    Raises:
        NotImplementedError: If the problem is multi-objective.
        ValueError: If the helper does not return a best point.
    """
    if torch_opt_config.is_moo:
        raise NotImplementedError(
            "Best observed point is incompatible with MOO problems."
        )
    # Inside the method body this name resolves to the imported module-level
    # function, not to this method.
    result = best_in_sample_point(
        Xs=self.Xs,
        model=self,
        bounds=search_space_digest.bounds,
        objective_weights=torch_opt_config.objective_weights,
        outcome_constraints=torch_opt_config.outcome_constraints,
        linear_constraints=torch_opt_config.linear_constraints,
        fixed_features=torch_opt_config.fixed_features,
        risk_measure=torch_opt_config.risk_measure,
        options=options,
    )
    if result is None:
        raise ValueError("Could not obtain best in-sample point.")
    best_point, observed_value = result
    # pyre-fixme[16]: Item `ndarray` of `Union[ndarray[typing.Any,
    #  typing.Any], Tensor]` has no attribute `to`.
    point_on_cpu = best_point.to(dtype=self.dtype, device=torch.device("cpu"))
    return point_on_cpu, observed_value
[docs]
def best_out_of_sample_point(
    self,
    search_space_digest: SearchSpaceDigest,
    torch_opt_config: TorchOptConfig,
    options: TConfig | None = None,
) -> tuple[Tensor, Tensor]:
    """Find the best predicted point by optimizing a best-point acquisition
    function.

    Args:
        search_space_digest: A `SearchSpaceDigest`.
        torch_opt_config: A `TorchOptConfig`; non-None `fixed_features` is
            not supported.
        options: Optional. If present, `seed_inner` (default None) and `qmc`
            (default True) will be parsed from `options`; any other keys
            will be ignored.

    Returns:
        A two-tuple (`candidate`, `acqf_value`), where `candidate` is a 1d
        Tensor of the best predicted point and `acqf_value` is a scalar (0d)
        Tensor of the acquisition function value at the best point.

    Raises:
        NotImplementedError: If ``torch_opt_config.fixed_features`` is set.
    """
    if torch_opt_config.fixed_features:
        # When have fixed features, need `FixedFeatureAcquisitionFunction`
        # which has peculiar instantiation (wraps another acquisition fn.),
        # so need to figure out how to handle.
        # TODO (ref: https://fburl.com/diff/uneqb3n9)
        raise NotImplementedError("Fixed features not yet supported.")

    opts = options or {}
    seed_inner = checked_cast_optional(int, opts.get(Keys.SEED_INNER, None))
    use_qmc = checked_cast(bool, opts.get(Keys.QMC, True))
    acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
        outcome_constraints=torch_opt_config.outcome_constraints,
        seed_inner=seed_inner,
        qmc=use_qmc,
        risk_measure=torch_opt_config.risk_measure,
    )

    # Avoiding circular import between `Surrogate` and `Acquisition`.
    from ax.models.torch.botorch_modular.acquisition import Acquisition

    acquisition = Acquisition(
        surrogate=self,
        botorch_acqf_class=acqf_class,
        search_space_digest=search_space_digest,
        torch_opt_config=torch_opt_config,
        options=acqf_options,
    )
    inequality_constraints = _to_inequality_constraints(
        linear_constraints=torch_opt_config.linear_constraints
    )
    candidates, acqf_value, _ = acquisition.optimize(
        n=1,
        search_space_digest=search_space_digest,
        inequality_constraints=inequality_constraints,
        fixed_features=torch_opt_config.fixed_features,
    )
    # Only one candidate was requested; return it without the batch dimension.
    return candidates[0], acqf_value
[docs]
def pareto_frontier(self) -> tuple[Tensor, Tensor]:
    """Retrieve the Pareto frontier for multi-objective optimization.

    Intended as the multi-objective counterpart of the best-point methods.
    Currently a placeholder that always raises.

    Returns: A two-tuple of:
        - tensor of points in the feature space,
        - tensor of corresponding (multiple) outcomes.

    Raises:
        NotImplementedError: Always.
    """
    message = "Pareto frontier not yet implemented."
    raise NotImplementedError(message)
[docs]
def compute_diagnostics(self) -> dict[str, Any]:
    """Compute model diagnostics (e.g. cross-validation measures of fit).

    Currently no diagnostics are computed; a new empty dict is returned on
    every call.
    """
    diagnostics: dict[str, Any] = {}
    return diagnostics
def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
"""Serialize attributes of this surrogate, to be passed back to it
as kwargs on reinstantiation.
"""
return {
"surrogate_spec": self.surrogate_spec,
"refit_on_cv": self.refit_on_cv,
"metric_to_best_model_config": self.metric_to_best_model_config,
}
@property
def outcomes(self) -> list[str]:
    """Outcome names stored on the surrogate.

    Raises:
        RuntimeError: If ``self._outcomes`` has not been set yet.
    """
    stored_outcomes = self._outcomes
    if stored_outcomes is not None:
        return stored_outcomes
    raise RuntimeError("outcomes not initialized. Please call `fit` first.")


@outcomes.setter
def outcomes(self, value: list[str]) -> None:
    """Reject any attempt to assign ``outcomes`` directly.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError("Setting outcomes manually is disallowed.")
# Dispatcher on which the input constructors for BoTorch model classes are
# registered. The functions decorated with `@submodel_input_constructor.register`
# in this module attach implementations for `Model` and `MultiTaskGP`.
# `_argparse_type_encoder` is supplied as the dispatcher's encoder.
# NOTE(review): presumably lookup is by model class with fallback to the
# `Model` registration for unregistered subclasses -- confirm against the
# `Dispatcher` implementation.
submodel_input_constructor = Dispatcher(
name="submodel_input_constructor", encoder=_argparse_type_encoder
)
@submodel_input_constructor.register(Model)
def _submodel_input_constructor_base(
    botorch_model_class: type[Model],
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    surrogate: Surrogate,
) -> dict[str, Any]:
    """Build the keyword arguments used to initialize a BoTorch model.

    The inputs are assembled in two steps: the model class's own
    ``construct_inputs`` is called with the dataset, the config's
    ``model_options`` and the kwargs extracted from the search space digest;
    the result is then updated with the submodules built by
    ``_construct_submodules``.

    Args:
        botorch_model_class: The BoTorch model class to instantiate.
        model_config: The model config.
        dataset: The training data for the model.
        search_space_digest: Search space digest used to set up model arguments.
        surrogate: A reference to the surrogate that created the model.
            This can be used by the constructor to obtain any additional
            arguments that are not readily available. Unused here.

    Returns:
        A dictionary of inputs for constructing the model.
    """
    digest_kwargs = _extract_model_kwargs(search_space_digest=search_space_digest)
    model_inputs: dict[str, Any] = botorch_model_class.construct_inputs(
        training_data=dataset,
        **model_config.model_options,
        **digest_kwargs,
    )
    model_inputs.update(
        _construct_submodules(
            model_config=model_config,
            dataset=dataset,
            # This is used when constructing the input transforms.
            search_space_digest=search_space_digest,
            # Used to check for supported arguments and in covar module
            # input constructors.
            botorch_model_class=botorch_model_class,
        )
    )
    return model_inputs
@submodel_input_constructor.register(MultiTaskGP)
def _submodel_input_constructor_mtgp(
    botorch_model_class: type[Model],
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    surrogate: Surrogate,
) -> dict[str, Any]:
    """Build the keyword arguments used to initialize a ``MultiTaskGP``.

    Starts from the inputs produced by ``_submodel_input_constructor_base``.
    If those contain a ``task_feature`` but no ``output_tasks``, the output
    task is set to the target value that the search space digest lists for
    that task feature.

    Args:
        botorch_model_class: The BoTorch model class to instantiate.
        model_config: The model config.
        dataset: The training data for the model; must have a single outcome.
        search_space_digest: Search space digest used to set up model
            arguments and to look up the target task value.
        surrogate: A reference to the surrogate that created the model.

    Returns:
        A dictionary of inputs for constructing the model.

    Raises:
        NotImplementedError: If the dataset has more than one outcome.
        UserInputError: If a task feature is present but neither
            ``output_tasks`` nor a target value for it is available.
    """
    if len(dataset.outcome_names) > 1:
        raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
    model_inputs = _submodel_input_constructor_base(
        botorch_model_class=botorch_model_class,
        model_config=model_config,
        dataset=dataset,
        search_space_digest=search_space_digest,
        surrogate=surrogate,
    )
    task_feature = model_inputs.get("task_feature")
    # Nothing to resolve without a task feature or when the output tasks
    # were already specified.
    if task_feature is None or model_inputs.get("output_tasks") is not None:
        return model_inputs
    target_values = search_space_digest.target_values
    target_value = (
        None if target_values is None else target_values.get(task_feature)
    )
    if target_value is None:
        raise UserInputError(
            "output_tasks or target task value must be provided for MultiTaskGP."
        )
    # specify output tasks so that model.num_outputs = 1
    # since the model only models a single outcome
    model_inputs["output_tasks"] = [int(target_value)]
    return model_inputs