#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import warnings
from collections.abc import Callable
from copy import deepcopy
from logging import Logger
from typing import Any, Optional
import numpy.typing as npt
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import DataRequiredError
from ax.models.torch.botorch_defaults import (
get_and_fit_model,
get_qLogNEI,
recommend_best_observed_point,
scipy_optimizer,
TAcqfConstructor,
)
from ax.models.torch.utils import (
_datasets_to_legacy_inputs,
_get_X_pending_and_observed,
_to_inequality_constraints,
normalize_indices,
predict_from_model,
subset_model,
)
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models import ModelList
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_ensemble
from torch import Tensor
from torch.nn import ModuleList # @manual
# Module-level logger obtained from Ax's `get_logger` for this module's name.
logger: Logger = get_logger(__name__)

# Signature of the `model_constructor` callable accepted by `BotorchModel`.
# The positional slots below follow the order given in the `BotorchModel`
# docstring; the callable returns a BoTorch `Model`.
# pyre-fixme[33]: Aliased annotation cannot contain `Any`.
TModelConstructor = Callable[
    [
        list[Tensor],  # Xs
        list[Tensor],  # Ys
        list[Tensor],  # Yvars
        list[int],  # task_features
        list[int],  # fidelity_features
        list[str],  # metric_names
        Optional[dict[str, Tensor]],  # state_dict
        Any,  # presumably stands in for the pass-through `**kwargs`
    ],
    Model,
]
# Signature of the `model_predictor` callable: (model, X, <bool>) -> two tensors
# (documented on `BotorchModel` as the posterior mean and covariance).
# The bool slot is presumably `use_posterior_predictive`, which
# `BotorchModel.cross_validate` passes by keyword -- TODO confirm.
TModelPredictor = Callable[[Model, Tensor, bool], tuple[Tensor, Tensor]]
# Signature of the `acqf_optimizer` callable accepted by `BotorchModel`.
# `BotorchModel.gen` unpacks the returned pair as
# `(candidates, expected_acquisition_value)`.
# pyre-fixme[33]: Aliased annotation cannot contain `Any`.
TOptimizer = Callable[
    [
        AcquisitionFunction,  # acq_function
        Tensor,  # bounds
        int,  # n
        Optional[list[tuple[Tensor, Tensor, float]]],  # inequality_constraints
        Optional[list[tuple[Tensor, Tensor, float]]],  # equality_constraints
        Optional[dict[int, float]],  # fixed_features
        Optional[Callable[[Tensor], Tensor]],  # rounding_func
        Any,  # presumably stands in for the pass-through `**kwargs`
    ],
    tuple[Tensor, Tensor],
]
# Signature of the `best_point_recommender` callable accepted by
# `BotorchModel`; it returns the recommended point as a tensor, or `None`.
TBestPointRecommender = Callable[
    [
        TorchModel,  # model
        list[tuple[float, float]],  # bounds
        Tensor,  # objective_weights
        Optional[tuple[Tensor, Tensor]],  # outcome_constraints
        Optional[tuple[Tensor, Tensor]],  # linear_constraints
        Optional[dict[int, float]],  # fixed_features
        Optional[TConfig],  # model_gen_options
        Optional[dict[int, float]],  # target_fidelities
    ],
    Optional[Tensor],
]
class BotorchModel(TorchModel):
    r"""
    Customizable botorch model.

    By default, this uses a noisy Log Expected Improvement (qLogNEI) acquisition
    function on top of a model made up of separate GPs, one for each outcome. This
    behavior can be modified by providing custom implementations of the following
    components:

    - a `model_constructor` that instantiates and fits a model on data
    - a `model_predictor` that predicts outcomes using the fitted model
    - a `acqf_constructor` that creates an acquisition function from a fitted model
    - a `acqf_optimizer` that optimizes the acquisition function
    - a `best_point_recommender` that recommends a current "best" point (i.e.,
      what the model recommends if the learning process ended now)

    Args:
        model_constructor: A callable that instantiates and fits a model on data,
            with signature as described below.
        model_predictor: A callable that predicts using the fitted model, with
            signature as described below.
        acqf_constructor: A callable that creates an acquisition function from a
            fitted model, with signature as described below.
        acqf_optimizer: A callable that optimizes the acquisition function, with
            signature as described below.
        best_point_recommender: A callable that recommends the best point, with
            signature as described below.
        refit_on_cv: If True, refit the model for each fold when performing
            cross-validation.
        warm_start_refitting: If True, start model refitting from previous
            model parameters in order to speed up the fitting process.
        use_input_warping: Forwarded to `model_constructor` as the keyword
            argument of the same name.
        use_loocv_pseudo_likelihood: Forwarded to `model_constructor` as the
            keyword argument of the same name.
        prior: An optional dictionary that contains the specification of GP model
            prior. When not None, it is forwarded to `model_constructor` as the
            `prior` keyword argument. Currently, the keys include:

            - covar_module_prior: prior on covariance matrix e.g.
              {"lengthscale_prior": GammaPrior(3.0, 6.0)}.
            - type: type of prior on task covariance matrix e.g.
              `LKJCovariancePrior`.
            - sd_prior: A scalar prior over nonnegative numbers, which is used
              for the default LKJCovariancePrior task_covar_prior.
            - eta: The eta parameter on the default LKJ task_covar_prior.

    Call signatures:

    ::

        model_constructor(
            Xs,
            Ys,
            Yvars,
            task_features,
            fidelity_features,
            metric_names,
            state_dict,
            **kwargs,
        ) -> model

    Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
    `task_features` identifies columns of Xs that should be modeled as a task,
    `fidelity_features` is a list of ints that specify the positions of fidelity
    parameters in 'Xs', `metric_names` provides the names of each `Y` in `Ys`,
    `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`.
    Optional kwargs are being passed through from the `BotorchModel` constructor.
    This callable is assumed to return a fitted BoTorch model that has the same
    dtype and lives on the same device as the input tensors.

    ::

        model_predictor(model, X, use_posterior_predictive) -> [mean, cov]

    Here `model` is a fitted botorch model, `X` is a tensor of candidate points,
    and `mean` and `cov` are the posterior mean and covariance, respectively.
    `use_posterior_predictive` is a boolean that is only passed (by keyword) from
    `cross_validate`; `predict` omits it, so the callable must provide a default.

    ::

        acqf_constructor(
            model,
            objective_weights,
            outcome_constraints,
            X_observed,
            X_pending,
            **kwargs,
        ) -> acq_function

    Here `model` is a botorch `Model`, `objective_weights` is a tensor of weights
    for the model outputs, `outcome_constraints` is a tuple of tensors describing
    the (linear) outcome constraints, `X_observed` are previously observed points,
    and `X_pending` are points whose evaluation is pending. `acq_function` is a
    BoTorch acquisition function crafted from these inputs. For additional
    details on the arguments, see `get_qLogNEI`.

    ::

        acqf_optimizer(
            acq_function,
            bounds,
            n,
            inequality_constraints,
            equality_constraints,
            fixed_features,
            rounding_func,
            **kwargs,
        ) -> candidates, expected_acquisition_value

    Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a tensor
    containing bounds on the parameters, `n` is the number of candidates to be
    generated, `inequality_constraints` are inequality constraints on parameter
    values, `fixed_features` specifies features that should be fixed during
    generation, and `rounding_func` is a callback that rounds an optimization
    result appropriately. `candidates` is a tensor of generated candidates and
    `expected_acquisition_value` is a tensor that `gen` reports in its metadata
    when it is non-empty. For additional details on the arguments, see
    `scipy_optimizer`.

    ::

        best_point_recommender(
            model,
            bounds,
            objective_weights,
            outcome_constraints,
            linear_constraints,
            fixed_features,
            model_gen_options,
            target_fidelities,
        ) -> candidates

    Here `model` is a TorchModel, `bounds` is a list of tuples containing bounds
    on the parameters, `objective_weights` is a tensor of weights for the model
    outputs, `outcome_constraints` is a tuple of tensors describing the (linear)
    outcome constraints, `linear_constraints` is a tuple of tensors describing
    constraints on the design, `fixed_features` specifies features that should be
    fixed during generation, `model_gen_options` is a config dictionary that can
    contain model-specific options, and `target_fidelities` is a map from fidelity
    feature column indices to their respective target fidelities, used for
    multi-fidelity optimization problems. % TODO: refer to an example.
    """

    dtype: torch.dtype | None
    device: torch.device | None
    Xs: list[Tensor]
    Ys: list[Tensor]
    Yvars: list[Tensor]
    _model: Model | None
    _search_space_digest: SearchSpaceDigest | None = None

    def __init__(
        self,
        model_constructor: TModelConstructor = get_and_fit_model,
        model_predictor: TModelPredictor = predict_from_model,
        acqf_constructor: TAcqfConstructor = get_qLogNEI,
        # pyre-fixme[9]: acqf_optimizer declared/used type mismatch
        acqf_optimizer: TOptimizer = scipy_optimizer,
        best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
        refit_on_cv: bool = False,
        warm_start_refitting: bool = True,
        use_input_warping: bool = False,
        use_loocv_pseudo_likelihood: bool = False,
        prior: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the pluggable components and options; see the class docstring.

        Emits a `DeprecationWarning` on every instantiation. No model is built
        here; `fit` must be called before the model can be used.
        """
        warnings.warn(
            "The legacy `BotorchModel` and its subclasses, including the current "
            f"class `{self.__class__.__name__}`, are slated for deprecation. "
            "These models will not be supported going forward and may be "
            "fully removed in a future release. Please consider using the "
            "Modular BoTorch Model (MBM) setup (ax/models/torch/botorch_modular) "
            "instead. If you run into a use case that is not supported by MBM, "
            "please raise this with an issue at https://github.com/facebook/Ax",
            DeprecationWarning,
            stacklevel=2,
        )
        self.model_constructor = model_constructor
        self.model_predictor = model_predictor
        self.acqf_constructor = acqf_constructor
        self.acqf_optimizer = acqf_optimizer
        self.best_point_recommender = best_point_recommender
        # Extra keyword arguments are forwarded verbatim to `model_constructor`.
        # pyre-fixme[4]: Attribute must be annotated.
        self._kwargs = kwargs
        self.refit_on_cv = refit_on_cv
        self.warm_start_refitting = warm_start_refitting
        self.use_input_warping = use_input_warping
        self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood
        self.prior = prior
        # State below is populated by `fit`.
        self._model: Model | None = None
        self.Xs = []
        self.Ys = []
        self.Yvars = []
        self.dtype = None
        self.device = None
        self.task_features: list[int] = []
        self.fidelity_features: list[int] = []
        self.metric_names: list[str] = []

    # NOTE: the docstring is provided via `copy_doc(TorchModel.fit)`.
    @copy_doc(TorchModel.fit)
    def fit(
        self,
        datasets: list[SupervisedDataset],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: list[list[TCandidateMetadata]] | None = None,
    ) -> None:
        if not datasets:
            raise DataRequiredError("BotorchModel.fit requires non-empty data sets.")
        self.Xs, self.Ys, self.Yvars = _datasets_to_legacy_inputs(datasets=datasets)
        # Flatten the per-dataset outcome names into a single list (linear time,
        # unlike repeated list concatenation via `sum(..., [])`).
        self.metric_names = [name for ds in datasets for name in ds.outcome_names]
        # Store search space info for later use (e.g. during generation)
        self._search_space_digest = search_space_digest
        first_X = self.Xs[0]
        self.dtype = first_X.dtype
        self.device = first_X.device
        num_features = first_X.size(-1)
        self.task_features = normalize_indices(
            search_space_digest.task_features, d=num_features
        )
        self.fidelity_features = normalize_indices(
            search_space_digest.fidelity_features, d=num_features
        )
        # Only pass `prior` when one was configured, so that constructors that
        # do not accept the keyword keep working.
        extra_kwargs = {} if self.prior is None else {"prior": self.prior}
        self._model = self.model_constructor(  # pyre-ignore [28]
            Xs=self.Xs,
            Ys=self.Ys,
            Yvars=self.Yvars,
            task_features=self.task_features,
            fidelity_features=self.fidelity_features,
            metric_names=self.metric_names,
            use_input_warping=self.use_input_warping,
            use_loocv_pseudo_likelihood=self.use_loocv_pseudo_likelihood,
            **extra_kwargs,
            **self._kwargs,
        )

    # NOTE: the docstring is provided via `copy_doc(TorchModel.predict)`.
    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> tuple[Tensor, Tensor]:
        # `self.model` raises RuntimeError if the model has not been fitted.
        return self.model_predictor(model=self.model, X=X)  # pyre-ignore [28]

    # NOTE: the docstring is provided via `copy_doc(TorchModel.gen)`.
    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        options = torch_opt_config.model_gen_options or {}
        acf_options = options.get(Keys.ACQF_KWARGS, {})
        optimizer_options = options.get(Keys.OPTIMIZER_KWARGS, {})

        if search_space_digest.fidelity_features:
            raise NotImplementedError(
                "Base BotorchModel does not support fidelity_features."
            )
        X_pending, X_observed = _get_X_pending_and_observed(
            Xs=self.Xs,
            objective_weights=torch_opt_config.objective_weights,
            bounds=search_space_digest.bounds,
            pending_observations=torch_opt_config.pending_observations,
            outcome_constraints=torch_opt_config.outcome_constraints,
            linear_constraints=torch_opt_config.linear_constraints,
            fixed_features=torch_opt_config.fixed_features,
            fit_out_of_design=torch_opt_config.fit_out_of_design,
        )

        model = self.model

        # subset model only to the outcomes we need for the optimization
        if options.get(Keys.SUBSET_MODEL, True):
            subset_model_results = subset_model(
                model=model,
                objective_weights=torch_opt_config.objective_weights,
                outcome_constraints=torch_opt_config.outcome_constraints,
            )
            model = subset_model_results.model
            objective_weights = subset_model_results.objective_weights
            outcome_constraints = subset_model_results.outcome_constraints
        else:
            objective_weights = torch_opt_config.objective_weights
            outcome_constraints = torch_opt_config.outcome_constraints

        # The digest stores one (lower, upper) pair per parameter; transpose so
        # that the first dimension indexes lower/upper instead.
        bounds_ = torch.tensor(
            search_space_digest.bounds, dtype=self.dtype, device=self.device
        )
        bounds_ = bounds_.transpose(0, 1)

        botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)

        from botorch.exceptions.errors import UnsupportedError

        # pyre-fixme[53]: Captured variable `X_observed` is not annotated.
        # pyre-fixme[53]: Captured variable `X_pending` is not annotated.
        # pyre-fixme[53]: Captured variable `acf_options` is not annotated.
        # pyre-fixme[53]: Captured variable `botorch_rounding_func` is not annotated.
        # pyre-fixme[53]: Captured variable `bounds_` is not annotated.
        # pyre-fixme[53]: Captured variable `model` is not annotated.
        # pyre-fixme[53]: Captured variable `objective_weights` is not annotated.
        # pyre-fixme[53]: Captured variable `optimizer_options` is not annotated.
        # pyre-fixme[53]: Captured variable `outcome_constraints` is not annotated.
        def make_and_optimize_acqf(override_qmc: bool = False) -> tuple[Tensor, Tensor]:
            """Build the acquisition function and optimize it.

            When `override_qmc` is True, `qmc=False` is additionally passed to
            the acquisition function constructor.
            """
            add_kwargs = {"qmc": False} if override_qmc else {}
            acquisition_function = checked_cast(
                AcquisitionFunction,
                self.acqf_constructor(
                    model=model,
                    objective_weights=objective_weights,
                    outcome_constraints=outcome_constraints,
                    X_observed=X_observed,
                    X_pending=X_pending,
                    **acf_options,
                    **add_kwargs,
                ),
            )
            # pyre-ignore: [28]
            candidates, expected_acquisition_value = self.acqf_optimizer(
                acq_function=acquisition_function,
                bounds=bounds_,
                n=n,
                inequality_constraints=_to_inequality_constraints(
                    linear_constraints=torch_opt_config.linear_constraints
                ),
                fixed_features=torch_opt_config.fixed_features,
                rounding_func=botorch_rounding_func,
                **optimizer_options,
            )
            return candidates, expected_acquisition_value

        try:
            candidates, expected_acquisition_value = make_and_optimize_acqf()
        except UnsupportedError as e:  # untested
            if "SobolQMCSampler only supports dimensions" not in str(e):
                # Unrelated failure: re-raise with the original traceback.
                raise
            # dimension too large for Sobol, let's use IID
            candidates, expected_acquisition_value = make_and_optimize_acqf(
                override_qmc=True
            )

        gen_metadata = {}
        if expected_acquisition_value.numel() > 0:
            gen_metadata["expected_acquisition_value"] = (
                expected_acquisition_value.tolist()
            )

        return TorchGenResults(
            points=candidates.detach().cpu(),
            weights=torch.ones(n, dtype=self.dtype),
            gen_metadata=gen_metadata,
        )

    # NOTE: the docstring is provided via `copy_doc(TorchModel.best_point)`.
    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> Tensor | None:
        if torch_opt_config.is_moo:
            raise NotImplementedError(
                "Best observed point is incompatible with MOO problems."
            )
        # Keep only the target values that belong to fidelity features.
        target_fidelities = {
            k: v
            for k, v in search_space_digest.target_values.items()
            if k in search_space_digest.fidelity_features
        }
        return self.best_point_recommender(  # pyre-ignore [28]
            model=self,
            bounds=search_space_digest.bounds,
            objective_weights=torch_opt_config.objective_weights,
            outcome_constraints=torch_opt_config.outcome_constraints,
            linear_constraints=torch_opt_config.linear_constraints,
            fixed_features=torch_opt_config.fixed_features,
            model_gen_options=torch_opt_config.model_gen_options,
            target_fidelities=target_fidelities,
        )

    # NOTE: the docstring is provided via `copy_doc(TorchModel.cross_validate)`.
    @copy_doc(TorchModel.cross_validate)
    def cross_validate(  # pyre-ignore [14]: `search_space_digest` arg not needed here
        self,
        datasets: list[SupervisedDataset],
        X_test: Tensor,
        use_posterior_predictive: bool = False,
        **kwargs: Any,
    ) -> tuple[Tensor, Tensor]:
        if self._model is None:
            raise RuntimeError("Cannot cross-validate model that has not been fitted.")
        # When not refitting, reuse a copy of the fitted model's parameters.
        if self.refit_on_cv:
            state_dict = None
        else:
            state_dict = deepcopy(self.model.state_dict())
        Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
        # Forward `prior` exactly as `fit` does, so that the cross-validation
        # model is constructed with the same specification as the fitted model.
        extra_kwargs = {} if self.prior is None else {"prior": self.prior}
        model = self.model_constructor(  # pyre-ignore: [28]
            Xs=Xs,
            Ys=Ys,
            Yvars=Yvars,
            task_features=self.task_features,
            state_dict=state_dict,
            fidelity_features=self.fidelity_features,
            metric_names=self.metric_names,
            refit_model=self.refit_on_cv,
            use_input_warping=self.use_input_warping,
            use_loocv_pseudo_likelihood=self.use_loocv_pseudo_likelihood,
            **extra_kwargs,
            **self._kwargs,
        )
        # pyre-ignore: [28]
        return self.model_predictor(
            model=model, X=X_test, use_posterior_predictive=use_posterior_predictive
        )

    def feature_importances(self) -> npt.NDArray:
        """Return feature importances computed from the fitted model.

        Delegates to `get_feature_importances_from_botorch_model`, which raises
        `RuntimeError` if no model has been fitted yet.
        """
        return get_feature_importances_from_botorch_model(model=self._model)

    @property
    def search_space_digest(self) -> SearchSpaceDigest:
        """The search space digest stored by `fit`.

        Raises:
            RuntimeError: If `fit` has not been called yet.
        """
        if self._search_space_digest is None:
            raise RuntimeError(
                "`search_space_digest` is not initialized. Please fit the model first."
            )
        return self._search_space_digest

    @search_space_digest.setter
    def search_space_digest(self, value: SearchSpaceDigest) -> None:
        # The digest may only be set through `fit`.
        raise RuntimeError("Setting search_space_digest manually is disallowed.")

    @property
    def model(self) -> Model:
        """The fitted BoTorch model.

        Raises:
            RuntimeError: If no model has been set (e.g. `fit` was not called).
        """
        if self._model is None:
            raise RuntimeError(
                "`model` is not initialized. Please fit the model first."
            )
        return self._model

    @model.setter
    def model(self, model: Model) -> None:
        self._model = model  # there are a few places that set model directly
def get_rounding_func(
    rounding_func: Callable[[Tensor], Tensor] | None,
) -> Callable[[Tensor], Tensor] | None:
    """Wrap a per-point rounding function so that it handles batched inputs.

    Args:
        rounding_func: A callable that rounds a single point, i.e. a tensor
            holding the last dimension of the input, or None.

    Returns:
        None if `rounding_func` is None. Otherwise a callable that applies
        `rounding_func` to every point of a tensor with arbitrary leading
        batch dimensions and returns a tensor of the same shape as its input.
    """
    if rounding_func is None:
        return None

    # make sure rounding_func is properly applied to q- and t-batches
    def botorch_rounding_func(X: Tensor) -> Tensor:
        batch_shape, d = X.shape[:-1], X.shape[-1]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g.
        # tensors produced by `permute`/`transpose`.
        points = X.reshape(-1, d)
        X_round = torch.stack(
            [rounding_func(x) for x in points]  # pyre-ignore: [16]
        )
        return X_round.reshape(*batch_shape, d)

    return botorch_rounding_func
def get_feature_importances_from_botorch_model(
    model: Model | ModuleList | None,
) -> npt.NDArray:
    """Get feature importances from a BoTorch model or a list of BoTorch models.

    The importances are the inverse kernel lengthscales of each (sub-)model,
    normalized so that they sum to one over the features.

    Args:
        model: BoTorch model to get feature importances from. If it is a
            `ModelList`, each of its sub-models is processed separately.

    Returns:
        The feature importances as a numpy array where each row sums to 1.

    Raises:
        RuntimeError: If `model` is None.
        NotImplementedError: If lengthscales cannot be read from a model's
            `covar_module` (or its `base_kernel`), or if their last dimension
            does not match the last dimension of the model's training inputs.
    """
    if model is None:
        raise RuntimeError(
            "Cannot calculate feature_importances without a fitted model. "
            "Call `fit` first."
        )
    models = model.models if isinstance(model, ModelList) else [model]

    per_model_lengthscales = []
    for m in models:
        try:
            # this can be a ModelList of a SAAS and STGP, so this is a necessary way
            # to get the lengthscale
            if hasattr(m.covar_module, "base_kernel"):
                ls = m.covar_module.base_kernel.lengthscale
            else:
                ls = m.covar_module.lengthscale
        except AttributeError:
            ls = None
        if ls is None or ls.shape[-1] != m.train_inputs[0].shape[-1]:
            # TODO: We could potentially set the feature importances to NaN in this
            # case, but this require knowing the batch dimension of this model.
            # Consider supporting in the future.
            raise NotImplementedError(
                "Failed to extract lengthscales from `m.covar_module` "
                "and `m.covar_module.base_kernel`"
            )
        if ls.ndim == 2:
            # Add a leading dimension so that all entries can be concatenated
            # along dim 0 below.
            ls = ls.unsqueeze(0)
        if is_ensemble(m):  # Take the median over the model batch dimension
            ls = torch.quantile(ls, q=0.5, dim=0, keepdim=True)
        per_model_lengthscales.append(ls)

    lengthscales = torch.cat(per_model_lengthscales, dim=0)
    feature_importances = (1 / lengthscales).detach().cpu()  # pyre-ignore
    # Make sure the sum of feature importances is 1.0 for each metric
    feature_importances /= feature_importances.sum(dim=-1, keepdim=True)
    return feature_importances.numpy()