#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
from ax.core.types import TConfig, TGenMetadata
from ax.models.torch.botorch_defaults import (
get_and_fit_model,
get_NEI,
predict_from_model,
recommend_best_observed_point,
scipy_optimizer,
)
from ax.models.torch.utils import (
_get_X_pending_and_observed,
normalize_indices,
subset_model,
)
from ax.models.torch_base import TorchModel
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.model import Model
from torch import Tensor
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Type aliases for the pluggable callables accepted by `BotorchModel`.
# The per-slot names in the trailing comments mirror the "Call signatures"
# section of the `BotorchModel` class docstring; `BotorchModel` itself invokes
# these callables with keyword arguments.

# `model_constructor`: instantiates and fits a BoTorch `Model` on data.
TModelConstructor = Callable[
[
List[Tensor],  # Xs
List[Tensor],  # Ys
List[Tensor],  # Yvars
List[int],  # task_features
List[int],  # fidelity_features
List[str],  # metric_names
Optional[Dict[str, Tensor]],  # state_dict
Any,  # **kwargs passed through from the BotorchModel constructor
],
Model,
]
# `model_predictor`: (model, X) -> (mean, cov).
TModelPredictor = Callable[[Model, Tensor], Tuple[Tensor, Tensor]]
# `acqf_constructor`: builds an acquisition function from a fitted model.
TAcqfConstructor = Callable[
[
Model,  # model
Tensor,  # objective_weights
Optional[Tuple[Tensor, Tensor]],  # outcome_constraints
Optional[Tensor],  # X_observed
Optional[Tensor],  # X_pending
Any,  # **kwargs
],
AcquisitionFunction,
]
# `acqf_optimizer`: optimizes the acquisition function. `BotorchModel.gen`
# unpacks the returned pair as (candidates, expected_acquisition_value).
TOptimizer = Callable[
[
AcquisitionFunction,  # acq_function
Tensor,  # bounds
int,  # n
Optional[List[Tuple[Tensor, Tensor, float]]],  # inequality_constraints
Optional[Dict[int, float]],  # fixed_features
Optional[Callable[[Tensor], Tensor]],  # rounding_func
Any,  # **kwargs
],
Tuple[Tensor, Tensor],
]
# `best_point_recommender`: recommends the current "best" point, or None.
TBestPointRecommender = Callable[
[
TorchModel,  # model
List[Tuple[float, float]],  # bounds
Tensor,  # objective_weights
Optional[Tuple[Tensor, Tensor]],  # outcome_constraints
Optional[Tuple[Tensor, Tensor]],  # linear_constraints
Optional[Dict[int, float]],  # fixed_features
Optional[TConfig],  # model_gen_options
Optional[Dict[int, float]],  # target_fidelities
],
Optional[Tensor],
]
class BotorchModel(TorchModel):
    r"""
    Customizable botorch model.

    By default, this uses a noisy Expected Improvement acquisition function on
    top of a model made up of separate GPs, one for each outcome. This behavior
    can be modified by providing custom implementations of the following
    components:

    - a `model_constructor` that instantiates and fits a model on data
    - a `model_predictor` that predicts outcomes using the fitted model
    - a `acqf_constructor` that creates an acquisition function from a fitted model
    - a `acqf_optimizer` that optimizes the acquisition function
    - a `best_point_recommender` that recommends a current "best" point (i.e.,
        what the model recommends if the learning process ended now)

    Args:
        model_constructor: A callable that instantiates and fits a model on data,
            with signature as described below.
        model_predictor: A callable that predicts using the fitted model, with
            signature as described below.
        acqf_constructor: A callable that creates an acquisition function from a
            fitted model, with signature as described below.
        acqf_optimizer: A callable that optimizes the acquisition function, with
            signature as described below.
        best_point_recommender: A callable that recommends the best point, with
            signature as described below.
        refit_on_cv: If True, refit the model for each fold when performing
            cross-validation.
        refit_on_update: If True, refit the model after updating the training
            data using the `update` method.
        warm_start_refitting: If True, start model refitting from previous
            model parameters in order to speed up the fitting process.


    Call signatures:

    ::

        model_constructor(
            Xs,
            Ys,
            Yvars,
            task_features,
            fidelity_features,
            metric_names,
            state_dict,
            **kwargs,
        ) -> model

    Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
    `task_features` identifies columns of Xs that should be modeled as a task,
    `fidelity_features` is a list of ints that specify the positions of fidelity
    parameters in 'Xs', `metric_names` provides the names of each `Y` in `Ys`,
    `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`.
    Optional kwargs are being passed through from the `BotorchModel` constructor.
    This callable is assumed to return a fitted BoTorch model that has the same
    dtype and lives on the same device as the input tensors.

    ::

        model_predictor(model, X) -> [mean, cov]

    Here `model` is a fitted botorch model, `X` is a tensor of candidate points,
    and `mean` and `cov` are the posterior mean and covariance, respectively.

    ::

        acqf_constructor(
            model,
            objective_weights,
            outcome_constraints,
            X_observed,
            X_pending,
            **kwargs,
        ) -> acq_function

    Here `model` is a botorch `Model`, `objective_weights` is a tensor of weights
    for the model outputs, `outcome_constraints` is a tuple of tensors describing
    the (linear) outcome constraints, `X_observed` are previously observed points,
    and `X_pending` are points whose evaluation is pending. `acq_function` is a
    BoTorch acquisition function crafted from these inputs. For additional
    details on the arguments, see `get_NEI`.

    ::

        acqf_optimizer(
            acq_function,
            bounds,
            n,
            inequality_constraints,
            fixed_features,
            rounding_func,
            **kwargs,
        ) -> candidates, expected_acquisition_value

    Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a tensor
    containing bounds on the parameters, `n` is the number of candidates to be
    generated, `inequality_constraints` are inequality constraints on parameter
    values, `fixed_features` specifies features that should be fixed during
    generation, and `rounding_func` is a callback that rounds an optimization
    result appropriately. `candidates` is a tensor of generated candidates and
    `expected_acquisition_value` is a second tensor that `gen` reports in its
    generation metadata under that name.
    For additional details on the arguments, see `scipy_optimizer`.

    ::

        best_point_recommender(
            model,
            bounds,
            objective_weights,
            outcome_constraints,
            linear_constraints,
            fixed_features,
            model_gen_options,
            target_fidelities,
        ) -> candidates

    Here `model` is a TorchModel, `bounds` is a list of tuples containing bounds
    on the parameters, `objective_weights` is a tensor of weights for the model
    outputs, `outcome_constraints` is a tuple of tensors describing the (linear)
    outcome constraints, `linear_constraints` is a tuple of tensors describing
    constraints on the design, `fixed_features` specifies features that should be
    fixed during generation, `model_gen_options` is a config dictionary that can
    contain model-specific options, and `target_fidelities` is a map from fidelity
    feature column indices to their respective target fidelities, used for
    multi-fidelity optimization problems. % TODO: refer to an example.
    """

    dtype: Optional[torch.dtype]
    device: Optional[torch.device]
    Xs: List[Tensor]
    Ys: List[Tensor]
    Yvars: List[Tensor]

    def __init__(
        self,
        model_constructor: TModelConstructor = get_and_fit_model,
        model_predictor: TModelPredictor = predict_from_model,
        # pyre-fixme[9]: acqf_constructor has type `Callable[[Model, Tensor,
        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor], Any],
        #  AcquisitionFunction]`; used as `Callable[[Model, Tensor,
        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor],
        #  **(Any)], AcquisitionFunction]`.
        acqf_constructor: TAcqfConstructor = get_NEI,
        # pyre-fixme[9]: acqf_optimizer has type `Callable[[AcquisitionFunction,
        #  Tensor, int, Optional[Dict[int, float]], Optional[Callable[[Tensor],
        #  Tensor]], Any], Tensor]`; used as `Callable[[AcquisitionFunction, Tensor,
        #  int, Optional[Dict[int, float]], Optional[Callable[[Tensor], Tensor]],
        #  **(Any)], Tensor]`.
        acqf_optimizer: TOptimizer = scipy_optimizer,
        best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
        refit_on_cv: bool = False,
        refit_on_update: bool = True,
        warm_start_refitting: bool = True,
        **kwargs: Any,
    ) -> None:
        """Store the pluggable components and initialize empty model state.

        Extra ``kwargs`` are kept and forwarded to ``model_constructor`` on
        every call to ``fit``, ``cross_validate`` and ``update``.
        """
        self.model_constructor = model_constructor
        self.model_predictor = model_predictor
        self.acqf_constructor = acqf_constructor
        self.acqf_optimizer = acqf_optimizer
        self.best_point_recommender = best_point_recommender
        self._kwargs = kwargs
        self.refit_on_cv = refit_on_cv
        self.refit_on_update = refit_on_update
        self.warm_start_refitting = warm_start_refitting
        # Everything below is populated by `fit` (and refreshed by `update`).
        self.model: Optional[Model] = None
        self.Xs = []
        self.Ys = []
        self.Yvars = []
        self.dtype = None
        self.device = None
        self.task_features: List[int] = []
        self.fidelity_features: List[int] = []
        self.metric_names: List[str] = []

    # NOTE: methods decorated with `copy_doc` take their docstring from the
    # corresponding `TorchModel` method, so they are documented with comments
    # only and deliberately carry no docstring of their own.

    @copy_doc(TorchModel.fit)
    def fit(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        bounds: List[Tuple[float, float]],
        task_features: List[int],
        feature_names: List[str],
        metric_names: List[str],
        fidelity_features: List[int],
    ) -> None:
        # dtype/device of the first outcome's inputs are used for every tensor
        # this class creates later (e.g. the bounds tensor in `gen`).
        self.dtype = Xs[0].dtype
        self.device = Xs[0].device
        self.Xs = Xs
        self.Ys = Ys
        self.Yvars = Yvars
        # ensure indices are non-negative
        n_features = Xs[0].size(-1)
        self.task_features = normalize_indices(task_features, d=n_features)
        self.fidelity_features = normalize_indices(fidelity_features, d=n_features)
        self.metric_names = metric_names
        self.model = self.model_constructor(  # pyre-ignore [28]
            Xs=Xs,
            Ys=Ys,
            Yvars=Yvars,
            task_features=self.task_features,
            fidelity_features=self.fidelity_features,
            metric_names=self.metric_names,
            **self._kwargs,
        )

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        # Fail with the same explicit error as `cross_validate`/`update`
        # instead of handing `None` to the predictor.
        if self.model is None:
            raise RuntimeError("Cannot predict from model that has not been fitted")
        return self.model_predictor(model=self.model, X=X)  # pyre-ignore [28]

    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        pending_observations: Optional[List[Tensor]] = None,
        model_gen_options: Optional[TConfig] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Tuple[Tensor, Tensor, TGenMetadata]:
        options = model_gen_options or {}
        acf_options = options.get("acquisition_function_kwargs", {})
        optimizer_options = options.get("optimizer_kwargs", {})
        if target_fidelities:
            raise NotImplementedError(
                "target_fidelities not implemented for base BotorchModel"
            )
        # Fail with the same explicit error as `cross_validate`/`update`
        # instead of passing an unfitted (None) model downstream.
        if self.model is None:
            raise RuntimeError(
                "Cannot generate candidates from model that has not been fitted"
            )
        X_pending, X_observed = _get_X_pending_and_observed(
            Xs=self.Xs,
            pending_observations=pending_observations,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            bounds=bounds,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features,
        )
        model = self.model
        # subset model only to the outcomes we need for the optimization
        if options.get("subset_model", True):
            model, objective_weights, outcome_constraints = subset_model(
                model=model,  # pyre-ignore [6]
                objective_weights=objective_weights,
                outcome_constraints=outcome_constraints,
            )
        # `bounds` arrives as a list of (lower, upper) pairs, one per feature;
        # the transpose turns it into a 2 x d tensor for the optimizer.
        bounds_ = torch.tensor(bounds, dtype=self.dtype, device=self.device)
        bounds_ = bounds_.transpose(0, 1)
        if linear_constraints is not None:
            A, b = linear_constraints
            inequality_constraints = []
            # One (indices, coefficients, rhs) triple per constraint row, keeping
            # only the non-zero coefficients. Both sides are negated --
            # presumably to turn `A x <= b` into the `>=` form the optimizer
            # expects; confirm against `scipy_optimizer`.
            for i in range(A.shape[0]):
                indices = A[i, :].nonzero().view(-1)
                coefficients = -A[i, indices]
                rhs = -b[i, 0]
                inequality_constraints.append((indices, coefficients, rhs))
        else:
            inequality_constraints = None
        acquisition_function = self.acqf_constructor(  # pyre-ignore: [28]
            model=model,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            X_observed=X_observed,
            X_pending=X_pending,
            **acf_options,
        )
        botorch_rounding_func = get_rounding_func(rounding_func)
        # pyre-ignore: [28]
        candidates, expected_acquisition_value = self.acqf_optimizer(
            acq_function=checked_cast(AcquisitionFunction, acquisition_function),
            bounds=bounds_,
            n=n,
            inequality_constraints=inequality_constraints,
            fixed_features=fixed_features,
            rounding_func=botorch_rounding_func,
            **optimizer_options,
        )
        # Candidates are returned on CPU with uniform weights of one.
        return (
            candidates.detach().cpu(),
            torch.ones(n, dtype=self.dtype),
            {"expected_acquisition_value": expected_acquisition_value.tolist()},
        )

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        model_gen_options: Optional[TConfig] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Optional[Tensor]:
        # Pure delegation: the recommender receives this TorchModel itself,
        # not the underlying BoTorch model.
        return self.best_point_recommender(  # pyre-ignore [28]
            model=self,
            bounds=bounds,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features,
            model_gen_options=model_gen_options,
            target_fidelities=target_fidelities,
        )

    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        Xs_train: List[Tensor],
        Ys_train: List[Tensor],
        Yvars_train: List[Tensor],
        X_test: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        if self.model is None:
            raise RuntimeError("Cannot cross-validate model that has not been fitted")
        if self.refit_on_cv:
            # No state dict is handed to the constructor when refitting per fold.
            state_dict = None
        else:
            # Hand the constructor a copy of the current model's state.
            state_dict = deepcopy(self.model.state_dict())  # pyre-ignore: [16]
        # The fold model is local; `self.model` is left untouched.
        model = self.model_constructor(  # pyre-ignore: [28]
            Xs=Xs_train,
            Ys=Ys_train,
            Yvars=Yvars_train,
            task_features=self.task_features,
            state_dict=state_dict,
            fidelity_features=self.fidelity_features,
            metric_names=self.metric_names,
            **self._kwargs,
        )
        return self.model_predictor(model=model, X=X_test)  # pyre-ignore: [28]

    @copy_doc(TorchModel.update)
    def update(self, Xs: List[Tensor], Ys: List[Tensor], Yvars: List[Tensor]) -> None:
        if self.model is None:
            raise RuntimeError("Cannot update model that has not been fitted")
        self.Xs = Xs
        self.Ys = Ys
        self.Yvars = Yvars
        # The previous state is dropped only for a cold refit; otherwise a copy
        # of the current state dict is passed on (warm start, or no refit).
        if self.refit_on_update and not self.warm_start_refitting:
            state_dict = None  # pragma: no cover
        else:
            state_dict = deepcopy(self.model.state_dict())  # pyre-ignore: [16]
        self.model = self.model_constructor(  # pyre-ignore: [28]
            Xs=self.Xs,
            Ys=self.Ys,
            Yvars=self.Yvars,
            task_features=self.task_features,
            state_dict=state_dict,
            fidelity_features=self.fidelity_features,
            metric_names=self.metric_names,
            refit_model=self.refit_on_update,
            **self._kwargs,
        )

    def feature_importances(self) -> np.ndarray:
        """Return the inverse lengthscales of the fitted model's base kernel.

        Reads ``self.model.covar_module.base_kernel.lengthscale``, so the
        fitted model must expose that attribute chain.

        Returns:
            A numpy array holding ``1 / lengthscale``, detached and on CPU.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self.model is None:
            raise RuntimeError(
                "Cannot calculate feature_importances without a fitted model"
            )
        ls = self.model.covar_module.base_kernel.lengthscale  # pyre-ignore: [16]
        return cast(Tensor, (1 / ls)).detach().cpu().numpy()
def get_rounding_func(
    rounding_func: Optional[Callable[[Tensor], Tensor]]
) -> Optional[Callable[[Tensor], Tensor]]:
    """Wrap a per-point rounding function so it works on batched tensors.

    Args:
        rounding_func: A callable applied to one point at a time (a 1-d tensor
            holding the last dimension of the input), or None.

    Returns:
        None if ``rounding_func`` is None. Otherwise a callable that applies
        ``rounding_func`` to every point along the last dimension of its input
        and returns a tensor with the input's shape.
    """
    if rounding_func is None:
        return None

    # make sure rounding_func is properly applied to q- and t-batches
    def botorch_rounding_func(X: Tensor) -> Tensor:
        # Nothing to round; also avoids `torch.stack` failing on an empty list.
        if X.numel() == 0:
            return X.clone()
        batch_shape, d = X.shape[:-1], X.shape[-1]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g.
        # tensors that were transposed or permuted upstream.
        points = X.reshape(-1, d)
        X_round = torch.stack([rounding_func(x) for x in points])
        return X_round.view(*batch_shape, d)

    return botorch_rounding_func