#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import inspect
from copy import deepcopy
from logging import Logger
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.input_constructors.covar_modules import (
covar_module_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.input_transforms import (
input_transform_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
convert_to_block_design,
fit_botorch_model,
subset_state_dict,
use_model_list,
)
from ax.models.torch.utils import (
_to_inequality_constraints,
pick_best_out_of_sample_point_acqf_class,
predict_from_model,
)
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import (
_argparse_type_encoder,
checked_cast,
checked_cast_optional,
)
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import (
ChainedInputTransform,
InputPerturbation,
InputTransform,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor
NOT_YET_FIT_MSG = (
"Underlying BoTorch `Model` has not yet received its training_data. "
"Please fit the model first."
)
logger: Logger = get_logger(__name__)
def _extract_model_kwargs(
search_space_digest: SearchSpaceDigest,
) -> Dict[str, Union[List[int], int]]:
"""
Extracts keyword arguments that are passed to the `construct_inputs`
method of a BoTorch `Model` class.
Args:
search_space_digest: A `SearchSpaceDigest`.
Returns:
A dict of fidelity features, categorical features, and, if present, task
features.
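    Example:
        A sketch with hypothetical digest values (only ``task_features`` is
        set, so only ``task_feature`` appears in the output)::

            digest = SearchSpaceDigest(
                feature_names=["x0", "x1"],
                bounds=[(0.0, 1.0), (0.0, 1.0)],
                task_features=[1],
            )
            _extract_model_kwargs(digest)  # -> {"task_feature": 1}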
"""
fidelity_features = search_space_digest.fidelity_features
task_features = search_space_digest.task_features
if len(fidelity_features) > 0 and len(task_features) > 0:
raise NotImplementedError(
"Multi-Fidelity GP models with task_features are "
"currently not supported."
)
    # TODO: Allow each metric to have different task_features or fidelity_features.
    # TODO: This requires an upstream change in the modelbridge.
if len(task_features) > 1:
raise NotImplementedError("Multiple task features are not supported.")
kwargs: Dict[str, Union[List[int], int]] = {}
if len(search_space_digest.categorical_features) > 0:
kwargs["categorical_features"] = search_space_digest.categorical_features
if len(fidelity_features) > 0:
kwargs["fidelity_features"] = fidelity_features
if len(task_features) == 1:
kwargs["task_feature"] = task_features[0]
return kwargs
class Surrogate(Base):
"""
    **All classes in the 'botorch_modular' directory are under
    construction, incomplete, and should be treated as alpha
    versions only.**

    Ax wrapper for a BoTorch ``Model``. It is a subcomponent of
    ``BoTorchModel`` and is not meant to be used outside of it.
Args:
        botorch_model_class: ``Model`` class to be used as the underlying
            BoTorch model. If None is provided, a model class (either one
            model for all outcomes or a ``ModelList`` with separate models
            for each outcome) will be selected automatically, based on the
            datasets, at `construct` time.
model_options: Dictionary of options / kwargs for the BoTorch
``Model`` constructed during ``Surrogate.fit``.
Note that the corresponding attribute will later be updated to include any
additional kwargs passed into ``BoTorchModel.fit``.
mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
mll_options: Dictionary of options / kwargs for the MLL.
outcome_transform_classes: List of BoTorch outcome transforms classes. Passed
down to the BoTorch ``Model``. Multiple outcome transforms can be chained
together using ``ChainedOutcomeTransform``.
outcome_transform_options: Outcome transform classes kwargs. The keys are
class string names and the values are dictionaries of outcome transform
kwargs. For example,
            `
            outcome_transform_classes = [Standardize]
            outcome_transform_options = {
                "Standardize": {"m": 1},
            }
            `
For more options see `botorch/models/transforms/outcome.py`.
input_transform_classes: List of BoTorch input transforms classes.
Passed down to the BoTorch ``Model``. Multiple input transforms
will be chained together using ``ChainedInputTransform``.
input_transform_options: Input transform classes kwargs. The keys are
class string names and the values are dictionaries of input transform
kwargs. For example,
`
input_transform_classes = [Normalize, Round]
input_transform_options = {
"Normalize": {"d": 3},
"Round": {"integer_indices": [0], "categorical_features": {1: 2}},
}
`
For more input options see `botorch/models/transforms/input.py`.
covar_module_class: Covariance module class. This gets initialized after
parsing the ``covar_module_options`` in ``covar_module_argparse``,
and gets passed to the model constructor as ``covar_module``.
covar_module_options: Covariance module kwargs.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_options`` and gets passed to the model constructor.
        likelihood_options: Dictionary of options / kwargs for the likelihood.
allow_batched_models: Set to true to fit the models in a batch if supported.
Set to false to fit individual models to each metric in a loop.
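    Example:
        A minimal construction sketch (``SingleTaskGP`` is a stand-in for
        any BoTorch ``Model`` subclass with a ``construct_inputs``
        classmethod)::

            surrogate = Surrogate(
                botorch_model_class=SingleTaskGP,
                mll_class=ExactMarginalLogLikelihood,
            )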
"""
def __init__(
self,
botorch_model_class: Optional[Type[Model]] = None,
model_options: Optional[Dict[str, Any]] = None,
mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
mll_options: Optional[Dict[str, Any]] = None,
outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None,
outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None,
input_transform_classes: Optional[List[Type[InputTransform]]] = None,
input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None,
covar_module_class: Optional[Type[Kernel]] = None,
covar_module_options: Optional[Dict[str, Any]] = None,
likelihood_class: Optional[Type[Likelihood]] = None,
likelihood_options: Optional[Dict[str, Any]] = None,
allow_batched_models: bool = True,
) -> None:
self.botorch_model_class = botorch_model_class
# Copying model options to avoid mutating the original dict.
# We later update it with any additional kwargs passed into `BoTorchModel.fit`.
self.model_options: Dict[str, Any] = (model_options or {}).copy()
self.mll_class = mll_class
self.mll_options: Dict[str, Any] = mll_options or {}
self.outcome_transform_classes = outcome_transform_classes
self.outcome_transform_options: Dict[str, Any] = outcome_transform_options or {}
self.input_transform_classes = input_transform_classes
self.input_transform_options: Dict[str, Any] = input_transform_options or {}
self.covar_module_class = covar_module_class
self.covar_module_options: Dict[str, Any] = covar_module_options or {}
self.likelihood_class = likelihood_class
self.likelihood_options: Dict[str, Any] = likelihood_options or {}
self.allow_batched_models = allow_batched_models
        # Store the last dataset used to fit the model for a given set of
        # metrics. If the new dataset is identical, we will skip model fitting
        # for those metrics. The keys are `tuple(dataset.outcome_names)`.
        self._last_datasets: Dict[Tuple[str, ...], SupervisedDataset] = {}
# Store a reference from a tuple of metric names to the BoTorch Model
# corresponding to those metrics. In most cases this will be a one-tuple,
# though we need n-tuples for LCE-M models. This will be used to skip model
# construction & fitting if the datasets are identical.
        self._submodels: Dict[Tuple[str, ...], Model] = {}
# Store a reference to search space digest used while fitting the cached models.
# We will re-fit the models if the search space digest changes.
self._last_search_space_digest: Optional[SearchSpaceDigest] = None
# These are later updated during model fitting.
self._training_data: Optional[List[SupervisedDataset]] = None
self._outcomes: Optional[List[str]] = None
self._model: Optional[Model] = None
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__}"
f" botorch_model_class={self.botorch_model_class} "
f"mll_class={self.mll_class} "
f"outcome_transform_classes={self.outcome_transform_classes} "
f"input_transform_classes={self.input_transform_classes} "
)
@property
def model(self) -> Model:
if self._model is None:
raise ValueError(
"BoTorch `Model` has not yet been constructed, please fit the "
"surrogate first (done via `BoTorchModel.fit`)."
)
return self._model
@property
def training_data(self) -> List[SupervisedDataset]:
if self._training_data is None:
raise ValueError(NOT_YET_FIT_MSG)
return self._training_data
@property
def Xs(self) -> List[Tensor]:
# Handles multi-output models. TODO: Improve this!
training_data = self.training_data
Xs = []
for dataset in training_data:
if self.botorch_model_class == PairwiseGP and isinstance(
dataset, RankingDataset
):
# directly accessing the d-dim X tensor values
# instead of the augmented 2*d-dim dataset.X from RankingDataset
Xi = checked_cast(SliceContainer, dataset._X).values
else:
Xi = dataset.X
for _ in range(dataset.Y.shape[-1]):
Xs.append(Xi)
return Xs
@property
def dtype(self) -> torch.dtype:
return self.training_data[0].X.dtype
@property
def device(self) -> torch.device:
return self.training_data[0].X.device
    def clone_reset(self) -> Surrogate:
return self.__class__(**self._serialize_attributes_as_kwargs())
def _construct_model(
self,
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
botorch_model_class: Type[Model],
state_dict: Optional[OrderedDict[str, Tensor]],
refit: bool,
) -> Model:
"""Constructs the underlying BoTorch ``Model`` using the training data.
If the dataset and model class are identical to those used while training
the cached sub-model, we skip model fitting and return the cached model.
Args:
            dataset: Training data for the model (for one outcome for
                the default `Surrogate`, with the exception of the batched
                multi-output case, where training data is formatted with just
                one X and concatenated Ys).
search_space_digest: Search space digest used to set up model arguments.
botorch_model_class: ``Model`` class to be used as the underlying
BoTorch model.
state_dict: Optional state dict to load. This should be subsetted for
the current submodel being constructed.
refit: Whether to re-optimize model parameters.
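        Returns:
            The constructed BoTorch ``Model``, fitted unless a ``state_dict``
            was provided with ``refit=False`` or a cached sub-model was reused.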
"""
outcome_names = tuple(dataset.outcome_names)
if self._should_reuse_last_model(
dataset=dataset, botorch_model_class=botorch_model_class
):
return self._submodels[outcome_names]
formatted_model_inputs = submodel_input_constructor(
botorch_model_class, # Do not pass as kwarg since this is used to dispatch.
dataset=dataset,
search_space_digest=search_space_digest,
surrogate=self,
)
# pyre-ignore [45]
model = botorch_model_class(**formatted_model_inputs)
if state_dict is not None:
model.load_state_dict(state_dict)
if state_dict is None or refit:
fit_botorch_model(
model=model, mll_class=self.mll_class, mll_options=self.mll_options
)
self._submodels[outcome_names] = model
self._last_datasets[outcome_names] = dataset
return model
def _should_reuse_last_model(
self, dataset: SupervisedDataset, botorch_model_class: Type[Model]
) -> bool:
"""Checks whether the given dataset and model class match the last
dataset and model class used to train the cached sub-model.
"""
outcome_names = tuple(dataset.outcome_names)
if (
outcome_names in self._submodels
and dataset == self._last_datasets[outcome_names]
):
last_model = self._submodels[outcome_names]
if type(last_model) is not botorch_model_class:
logger.info(
f"The model class for outcome(s) {dataset.outcome_names} "
f"changed from {type(last_model)} to {botorch_model_class}. "
"Will refit the model."
)
else:
return True
return False
def _set_formatted_inputs(
self,
formatted_model_inputs: Dict[str, Any],
# pyre-ignore [2] The proper hint for the second arg is Union[None,
# Type[Kernel], Type[Likelihood], List[Type[OutcomeTransform]],
# List[Type[InputTransform]]]. Keeping it as Any saves us from a
# bunch of checked_cast calls within the for loop.
inputs: List[Tuple[str, Any, Dict[str, Any]]],
dataset: SupervisedDataset,
botorch_model_class_args: List[str],
search_space_digest: SearchSpaceDigest,
) -> None:
for input_name, input_class, input_options in inputs:
if input_class is None:
continue
if input_name not in botorch_model_class_args:
# TODO: We currently only pass in `covar_module` and `likelihood`
# if they are inputs to the BoTorch model. This interface will need
# to be expanded to a ModelFactory, see D22457664, to accommodate
# different models in the future.
raise UserInputError(
f"The BoTorch model class {self.botorch_model_class} does not "
f"support the input {input_name}."
)
input_options = deepcopy(input_options) or {}
if input_name == "covar_module":
covar_module_with_defaults = covar_module_argparse(
input_class,
dataset=dataset,
botorch_model_class=self.botorch_model_class,
**input_options,
)
formatted_model_inputs[input_name] = input_class(
**covar_module_with_defaults
)
elif input_name == "input_transform":
formatted_model_inputs[input_name] = self._make_botorch_input_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
search_space_digest=search_space_digest,
)
elif input_name == "outcome_transform":
formatted_model_inputs[
input_name
] = self._make_botorch_outcome_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
)
else:
formatted_model_inputs[input_name] = input_class(**input_options)
def _make_botorch_input_transform(
self,
input_classes: List[Type[InputTransform]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
input_options: Dict[str, Dict[str, Any]],
) -> Optional[InputTransform]:
"""
Makes a BoTorch input transform from the provided input classes and options.
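        For example, the following sketch (``Normalize`` and ``Round`` are
        BoTorch input transforms; omitted kwargs fall back to
        ``input_transform_argparse`` defaults) would return a
        ``ChainedInputTransform`` wrapping both transforms::

            transform = self._make_botorch_input_transform(
                input_classes=[Normalize, Round],
                dataset=dataset,
                search_space_digest=search_space_digest,
                input_options={"Normalize": {"d": 3}},
            )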
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, InputTransform) for c in input_classes)
):
raise UserInputError("Expected a list of input transforms.")
if len(input_classes) == 0:
return None
input_transform_kwargs = [
input_transform_argparse(
single_input_class,
dataset=dataset,
search_space_digest=search_space_digest,
input_transform_options=input_options.get(
single_input_class.__name__, {}
),
)
for single_input_class in input_classes
]
input_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
single_input_class(**single_input_transform_kwargs)
for single_input_class, single_input_transform_kwargs in zip(
input_classes, input_transform_kwargs
)
]
input_instance = (
ChainedInputTransform(
**{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))}
)
if len(input_transforms) > 1
else input_transforms[0]
)
return input_instance
def _make_botorch_outcome_transform(
self,
input_classes: List[Type[OutcomeTransform]],
input_options: Dict[str, Dict[str, Any]],
dataset: SupervisedDataset,
) -> Optional[OutcomeTransform]:
"""
Makes a BoTorch outcome transform from the provided classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, OutcomeTransform) for c in input_classes)
):
raise UserInputError("Expected a list of outcome transforms.")
if len(input_classes) == 0:
return None
outcome_transform_kwargs = [
outcome_transform_argparse(
input_class,
outcome_transform_options=input_options.get(input_class.__name__, {}),
dataset=dataset,
)
for input_class in input_classes
]
outcome_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `OutcomeTransform`.
input_class(**single_outcome_transform_kwargs)
for input_class, single_outcome_transform_kwargs in zip(
input_classes, outcome_transform_kwargs
)
]
outcome_transform_instance = (
ChainedOutcomeTransform(
**{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
)
if len(outcome_transforms) > 1
else outcome_transforms[0]
)
return outcome_transform_instance
    def fit(
self,
datasets: List[SupervisedDataset],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
state_dict: Optional[OrderedDict[str, Tensor]] = None,
refit: bool = True,
) -> None:
"""Fits the underlying BoTorch ``Model`` to ``m`` outcomes.
        NOTE: The ``state_dict`` and ``refit`` keyword arguments control how
        the underlying BoTorch ``Model`` will be fit: whether its parameters
        will be re-optimized and whether it will be warm-started from a given
        state.
There are three possibilities:
* ``fit(state_dict=None)``: fit model from scratch (optimize model
parameters and set its training data used for inference),
* ``fit(state_dict=some_state_dict, refit=True)``: warm-start refit
with a state dict of parameters (still re-optimize model parameters
and set the training data),
* ``fit(state_dict=some_state_dict, refit=False)``: load model parameters
without refitting, but set new training data (used in cross-validation,
for example).
Args:
datasets: A list of ``SupervisedDataset`` containers, each
corresponding to the data of one metric (outcome), to be passed
to ``Model.construct_inputs`` in BoTorch.
search_space_digest: A ``SearchSpaceDigest`` object containing
metadata on the features in the datasets.
candidate_metadata: Model-produced metadata for candidates, in
the order corresponding to the Xs.
state_dict: Optional state dict to load.
refit: Whether to re-optimize model parameters.
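        Example:
            A warm-start sketch (``datasets``, ``new_datasets``, and
            ``digest`` are assumed to be prepared by the caller)::

                surrogate.fit(datasets=datasets, search_space_digest=digest)
                sd = surrogate.model.state_dict()
                # Re-fit on updated data, warm-started from `sd`.
                surrogate.fit(
                    datasets=new_datasets,
                    search_space_digest=digest,
                    state_dict=sd,
                    refit=True,
                )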
"""
self._discard_cached_model_and_data_if_search_space_digest_changed(
search_space_digest=search_space_digest
)
        # To determine whether to use a ModelList under the hood, we need to
        # check for the batched multi-output case, so we first see which model
        # would be chosen given the Yvars and the properties of the data.
botorch_model_class = self.botorch_model_class or choose_model_class(
datasets=datasets, search_space_digest=search_space_digest
)
should_use_model_list = use_model_list(
datasets=datasets,
botorch_model_class=botorch_model_class,
allow_batched_models=self.allow_batched_models,
)
if not should_use_model_list and len(datasets) > 1:
datasets = convert_to_block_design(datasets=datasets, force=True)
self._training_data = datasets
models = []
outcome_names = []
for i, dataset in enumerate(datasets):
submodel_state_dict = None
if state_dict is not None:
if should_use_model_list:
submodel_state_dict = subset_state_dict(
state_dict=state_dict, submodel_index=i
)
else:
submodel_state_dict = state_dict
model = self._construct_model(
dataset=dataset,
search_space_digest=search_space_digest,
botorch_model_class=botorch_model_class,
state_dict=submodel_state_dict,
refit=refit,
)
models.append(model)
outcome_names.extend(dataset.outcome_names)
if should_use_model_list:
self._model = ModelListGP(*models)
else:
self._model = models[0]
self._outcomes = outcome_names # In the order of input datasets
def _discard_cached_model_and_data_if_search_space_digest_changed(
self, search_space_digest: SearchSpaceDigest
) -> None:
"""Checks whether the search space digest has changed since the last call
to `fit`. If it has, discards cached model and datasets. Also updates
`self._last_search_space_digest` for future checks.
"""
if (
self._last_search_space_digest is not None
and search_space_digest != self._last_search_space_digest
):
logger.info(
"Discarding all previously trained models due to a change "
"in the search space digest."
)
self._submodels = {}
self._last_datasets = {}
self._last_search_space_digest = search_space_digest
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
"""Predicts outcomes given an input tensor.
Args:
X: A ``n x d`` tensor of input parameters.
Returns:
            Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
            Tensor: The predicted posterior covariance as an ``n x o x o``-dim tensor.
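        Example:
            A sketch for a surrogate fit on ``o`` outcomes over a 3-dim
            search space (``n = 5`` hypothetical points)::

                X = torch.rand(5, 3, dtype=surrogate.dtype)
                mean, cov = surrogate.predict(X)  # (5, o) and (5, o, o)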
"""
return predict_from_model(model=self.model, X=X)
    def best_in_sample_point(
self,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
options: Optional[TConfig] = None,
) -> Tuple[Tensor, float]:
"""Finds the best observed point and the corresponding observed outcome
values.
"""
if torch_opt_config.is_moo:
raise NotImplementedError(
"Best observed point is incompatible with MOO problems."
)
best_point_and_observed_value = best_in_sample_point(
Xs=self.Xs,
model=self,
bounds=search_space_digest.bounds,
objective_weights=torch_opt_config.objective_weights,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
risk_measure=torch_opt_config.risk_measure,
options=options,
)
if best_point_and_observed_value is None:
raise ValueError("Could not obtain best in-sample point.")
best_point, observed_value = best_point_and_observed_value
return (
best_point.to(dtype=self.dtype, device=torch.device("cpu")),
observed_value,
)
    def best_out_of_sample_point(
self,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
options: Optional[TConfig] = None,
) -> Tuple[Tensor, Tensor]:
"""Finds the best predicted point and the corresponding value of the
appropriate best point acquisition function.
"""
if torch_opt_config.fixed_features:
            # When there are fixed features, we need the
            # `FixedFeatureAcquisitionFunction`, which has a peculiar
            # instantiation (it wraps another acquisition function), so we
            # still need to figure out how to handle it here.
            # TODO (ref: https://fburl.com/diff/uneqb3n9)
raise NotImplementedError("Fixed features not yet supported.")
options = options or {}
acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
outcome_constraints=torch_opt_config.outcome_constraints,
seed_inner=checked_cast_optional(int, options.get(Keys.SEED_INNER, None)),
qmc=checked_cast(bool, options.get(Keys.QMC, True)),
risk_measure=torch_opt_config.risk_measure,
)
# Avoiding circular import between `Surrogate` and `Acquisition`.
from ax.models.torch.botorch_modular.acquisition import Acquisition
acqf = Acquisition( # TODO: For multi-fidelity, might need diff. class.
surrogates={"self": self},
botorch_acqf_class=acqf_class,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
options=acqf_options,
)
candidates, acqf_values = acqf.optimize(
n=1,
search_space_digest=search_space_digest,
inequality_constraints=_to_inequality_constraints(
linear_constraints=torch_opt_config.linear_constraints
),
fixed_features=torch_opt_config.fixed_features,
)
return candidates[0], acqf_values[0]
    def pareto_frontier(self) -> Tuple[Tensor, Tensor]:
"""For multi-objective optimization, retrieve Pareto frontier instead
of best point.
Returns: A two-tuple of:
- tensor of points in the feature space,
- tensor of corresponding (multiple) outcomes.
"""
raise NotImplementedError("Pareto frontier not yet implemented.")
    def compute_diagnostics(self) -> Dict[str, Any]:
"""Computes model diagnostics like cross-validation measure of fit, etc."""
return {}
def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
"""Serialize attributes of this surrogate, to be passed back to it
as kwargs on reinstantiation.
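        This is what ``clone_reset`` uses; reinstantiation amounts to
        ``Surrogate(**self._serialize_attributes_as_kwargs())``.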
"""
return {
"botorch_model_class": self.botorch_model_class,
"model_options": self.model_options,
"mll_class": self.mll_class,
"mll_options": self.mll_options,
"outcome_transform_classes": self.outcome_transform_classes,
"outcome_transform_options": self.outcome_transform_options,
"input_transform_classes": self.input_transform_classes,
"input_transform_options": self.input_transform_options,
"covar_module_class": self.covar_module_class,
"covar_module_options": self.covar_module_options,
"likelihood_class": self.likelihood_class,
"likelihood_options": self.likelihood_options,
"allow_batched_models": self.allow_batched_models,
}
def _extract_construct_input_transform_args(
self, search_space_digest: SearchSpaceDigest
) -> Tuple[Optional[List[Type[InputTransform]]], Dict[str, Dict[str, Any]]]:
"""
Extracts input transform classes and input transform options that will
be used in `self._set_formatted_inputs` and ultimately passed to
BoTorch.
Args:
search_space_digest: A `SearchSpaceDigest`.
Returns:
A tuple containing
- Either `None` or a list of input transform classes,
- A dictionary of input transform options.
"""
# Construct input perturbation if doing robust optimization.
# NOTE: Doing this here rather than in `_set_formatted_inputs` to make sure
# we use the same perturbations for each sub-model.
if (robust_digest := search_space_digest.robust_digest) is not None:
submodel_input_transform_options = {
"InputPerturbation": input_transform_argparse(
InputTransform,
search_space_digest=SearchSpaceDigest(
feature_names=[], bounds=[], robust_digest=robust_digest
),
)
}
submodel_input_transform_classes: List[Type[InputTransform]] = [
InputPerturbation
]
if self.input_transform_classes is not None:
# TODO: Support mixing with user supplied transforms.
raise NotImplementedError(
"User supplied input transforms are not supported "
"in robust optimization."
)
else:
submodel_input_transform_classes = self.input_transform_classes
submodel_input_transform_options = self.input_transform_options
return (
submodel_input_transform_classes,
submodel_input_transform_options,
)
@property
def outcomes(self) -> List[str]:
if self._outcomes is None:
raise RuntimeError("outcomes not initialized. Please call `fit` first.")
return self._outcomes
@outcomes.setter
def outcomes(self, value: List[str]) -> None:
raise RuntimeError("Setting outcomes manually is disallowed.")
submodel_input_constructor = Dispatcher(
name="submodel_input_constructor", encoder=_argparse_type_encoder
)
@submodel_input_constructor.register(Model)
def _submodel_input_constructor_base(
botorch_model_class: Type[Model],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
surrogate: Surrogate,
) -> Dict[str, Any]:
"""Construct the inputs required to initialize a BoTorch model.
Args:
botorch_model_class: The BoTorch model class to instantiate.
dataset: The training data for the model.
search_space_digest: Search space digest used to set up model arguments.
surrogate: A reference to the surrogate that created the model.
This can be used by the constructor to obtain any additional
arguments that are not readily available.
Returns:
A dictionary of inputs for constructing the model.
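    Example:
        A dispatch sketch (``SingleTaskGP`` is a hypothetical choice; the
        positional class argument is what the ``Dispatcher`` keys on)::

            inputs = submodel_input_constructor(
                SingleTaskGP,
                dataset=dataset,
                search_space_digest=digest,
                surrogate=surrogate,
            )
            model = SingleTaskGP(**inputs)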
"""
model_kwargs_from_ss = _extract_model_kwargs(
search_space_digest=search_space_digest
)
(
input_transform_classes,
input_transform_options,
) = surrogate._extract_construct_input_transform_args(
search_space_digest=search_space_digest
)
formatted_model_inputs = botorch_model_class.construct_inputs(
training_data=dataset,
**surrogate.model_options,
**model_kwargs_from_ss,
)
botorch_model_class_args = inspect.getfullargspec(botorch_model_class).args
surrogate._set_formatted_inputs(
formatted_model_inputs=formatted_model_inputs,
inputs=[
(
"covar_module",
surrogate.covar_module_class,
surrogate.covar_module_options,
),
("likelihood", surrogate.likelihood_class, surrogate.likelihood_options),
(
"outcome_transform",
surrogate.outcome_transform_classes,
surrogate.outcome_transform_options,
),
(
"input_transform",
input_transform_classes,
deepcopy(input_transform_options),
),
],
dataset=dataset,
# This is used when constructing the input transforms.
search_space_digest=search_space_digest,
# This is used to check if the arguments are supported.
botorch_model_class_args=botorch_model_class_args,
)
return formatted_model_inputs