#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from ax.core.types import TConfig
from ax.models.model_utils import best_observed_point
from ax.models.torch.botorch_defaults import (
get_and_fit_model,
get_NEI,
predict_from_model,
scipy_optimizer,
)
from ax.models.torch.utils import _get_X_pending_and_observed
from ax.models.torch_base import TorchModel
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.model import Model
from torch import Tensor
TModelConstructor = Callable[
[
List[Tensor],
List[Tensor],
List[Tensor],
List[int],
Optional[Dict[str, Tensor]],
Any,
],
Model,
]
TModelPredictor = Callable[[Model, Tensor], Tuple[Tensor, Tensor]]
TAcqfConstructor = Callable[
[
Model,
Tensor,
Optional[Tuple[Tensor, Tensor]],
Optional[Tensor],
Optional[Tensor],
Any,
],
AcquisitionFunction,
]
TOptimizer = Callable[
[
AcquisitionFunction,
Tensor,
int,
Optional[List[Tuple[Tensor, Tensor, float]]],
Optional[Dict[int, float]],
Optional[Callable[[Tensor], Tensor]],
Any,
],
Tensor,
]
class BotorchModel(TorchModel):
r"""
Customizable botorch model.
    By default, this uses a Noisy Expected Improvement (NEI) acquisition function on
top of a model made up of separate GPs, one for each outcome. This behavior
can be modified by providing custom implementations of the following
components:
- a `model_constructor` that instantiates and fits a model on data
- a `model_predictor` that predicts using the fitted model
    - an `acqf_constructor` that creates an acquisition function from a fitted model
    - an `acqf_optimizer` that optimizes the acquisition function
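
    For example, a customized model can be assembled by overriding any subset
    of these components (a sketch; `my_model_constructor` and
    `my_acqf_constructor` stand in for user-supplied callables with the
    signatures described below):

    ::

        model = BotorchModel(
            model_constructor=my_model_constructor,
            acqf_constructor=my_acqf_constructor,
        )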
Args:
model_constructor: A callable that instantiates and fits a model on data,
with signature as described below.
model_predictor: A callable that predicts using the fitted model, with
signature as described below.
acqf_constructor: A callable that creates an acquisition function from a
fitted model, with signature as described below.
acqf_optimizer: A callable that optimizes the acquisition function, with
signature as described below.
        refit_on_cv: If True, refit the model for each fold when performing
            cross-validation.
        refit_on_update: If True, refit the model after the training data is
            updated via `update`.
        warm_start_refitting: If True, initialize refitting from the current
            model state rather than from scratch.
Call signatures:
::
model_constructor(
Xs,
Ys,
Yvars,
task_features,
state_dict,
fidelity_features,
**kwargs
) -> model
Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
    `task_features` identifies columns of `Xs` that should be modeled
    as a task, `state_dict` is a PyTorch module state dict, `fidelity_features`
    is a list of ints that specify the positions of fidelity parameters in
    `Xs`, and `model` is a botorch `Model`. Optional kwargs are passed through
from the `BotorchModel` constructor. This callable is assumed to return a fitted
botorch model that has the same dtype and lives on the same device as the
input tensors.
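
    As an illustration, a custom `model_constructor` might be sketched as
    follows (hypothetical, not part of Ax: it fits an independent
    `SingleTaskGP` to the first outcome only and ignores `Yvars`,
    `task_features`, and `fidelity_features`):

    ::

        from botorch.fit import fit_gpytorch_model
        from botorch.models import SingleTaskGP
        from gpytorch.mlls import ExactMarginalLogLikelihood

        def sketch_model_constructor(
            Xs, Ys, Yvars, task_features, state_dict=None,
            fidelity_features=None, **kwargs
        ):
            # Model the first outcome only; Ys[0] is an `n x 1` tensor.
            model = SingleTaskGP(train_X=Xs[0], train_Y=Ys[0])
            if state_dict is not None:
                # Reuse previously fitted parameters instead of refitting.
                model.load_state_dict(state_dict)
            else:
                mll = ExactMarginalLogLikelihood(model.likelihood, model)
                fit_gpytorch_model(mll)
            return model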
::
model_predictor(model, X) -> [mean, cov]
Here `model` is a fitted botorch model, `X` is a tensor of candidate
points, and `mean` and `cov` are the posterior mean and covariance,
respectively.
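
    For instance, a bare-bones `model_predictor` could be sketched as below
    (hypothetical; it assumes independent outcomes, so the covariance is
    diagonal):

    ::

        def sketch_model_predictor(model, X):
            with torch.no_grad():
                posterior = model.posterior(X)
            mean = posterior.mean.cpu().detach()
            # Expand the per-outcome variances into a diagonal covariance.
            cov = torch.diag_embed(posterior.variance.cpu().detach())
            return mean, cov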
::
acqf_constructor(
model,
objective_weights,
outcome_constraints,
X_observed,
X_pending,
**kwargs,
) -> acq_function
Here `model` is a botorch `Model`, `objective_weights` is a tensor of
weights for the model outputs, `outcome_constraints` is a tuple of
tensors describing the (linear) outcome constraints, `X_observed` are
previously observed points, and `X_pending` are points whose evaluation
is pending. `acq_function` is a botorch acquisition function crafted
from these inputs. For additional details on the arguments, see `get_NEI`.
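
    As an illustration, a custom `acqf_constructor` could swap in q-Expected
    Improvement (a hypothetical sketch; it ignores `outcome_constraints` and
    assumes `X_observed` is not None):

    ::

        from botorch.acquisition import qExpectedImprovement
        from botorch.acquisition.objective import LinearMCObjective

        def sketch_acqf_constructor(
            model, objective_weights, outcome_constraints=None,
            X_observed=None, X_pending=None, **kwargs
        ):
            objective = LinearMCObjective(weights=objective_weights)
            # Incumbent value: best posterior mean over observed points.
            best_f = objective(model.posterior(X_observed).mean).max().item()
            return qExpectedImprovement(
                model=model,
                best_f=best_f,
                objective=objective,
                X_pending=X_pending,
            )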
::
acqf_optimizer(
acq_function,
bounds,
n,
inequality_constraints,
fixed_features,
rounding_func,
**kwargs,
) -> candidates
    Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a
    `2 x d` tensor of lower and upper bounds on the parameters, `n` is the
    number of candidates to be generated, `inequality_constraints` are
    inequality constraints on parameter values, `fixed_features` specifies
    features that should be fixed during generation, and `rounding_func` is
    a callback that rounds an optimization result appropriately.
    `candidates` is an `n x d` tensor of generated candidates. For
    additional details on the arguments, see `scipy_optimizer`.
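
    For example, an `acqf_optimizer` could be sketched on top of botorch's
    `optimize_acqf` (hypothetical; the default `scipy_optimizer` plays a
    similar role):

    ::

        from botorch.optim import optimize_acqf

        def sketch_acqf_optimizer(
            acq_function, bounds, n, inequality_constraints=None,
            fixed_features=None, rounding_func=None, **kwargs
        ):
            # Jointly optimize a q-batch of `n` candidates.
            candidates, _ = optimize_acqf(
                acq_function=acq_function,
                bounds=bounds,
                q=n,
                num_restarts=kwargs.get("num_restarts", 20),
                raw_samples=kwargs.get("raw_samples", 1024),
                inequality_constraints=inequality_constraints,
                fixed_features=fixed_features,
                post_processing_func=rounding_func,
            )
            return candidates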
"""
dtype: Optional[torch.dtype]
device: Optional[torch.device]
Xs: List[Tensor]
Ys: List[Tensor]
Yvars: List[Tensor]
def __init__(
self,
# pyre-fixme[9]: model_constructor has type `Callable[[List[Tensor],
# List[Tensor], List[Tensor], List[int], Optional[Dict[str, Tensor]], Any],
# Model]`; used as `Callable[[List[Tensor], List[Tensor], List[Tensor],
# List[int], Optional[Dict[str, Tensor]], **(Any)], MultiOutputGP]`.
model_constructor: TModelConstructor = get_and_fit_model,
model_predictor: TModelPredictor = predict_from_model,
# pyre-fixme[9]: acqf_constructor has type `Callable[[Model, Tensor,
# Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor], Any],
# AcquisitionFunction]`; used as `Callable[[Model, Tensor,
# Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor],
# **(Any)], AcquisitionFunction]`.
acqf_constructor: TAcqfConstructor = get_NEI,
# pyre-fixme[9]: acqf_optimizer has type `Callable[[AcquisitionFunction,
# Tensor, int, Optional[Dict[int, float]], Optional[Callable[[Tensor],
# Tensor]], Any], Tensor]`; used as `Callable[[AcquisitionFunction, Tensor,
# int, Optional[Dict[int, float]], Optional[Callable[[Tensor], Tensor]],
# **(Any)], Tensor]`.
acqf_optimizer: TOptimizer = scipy_optimizer,
refit_on_cv: bool = False,
refit_on_update: bool = True,
warm_start_refitting: bool = True,
**kwargs: Any,
) -> None:
self.model_constructor = model_constructor
self.model_predictor = model_predictor
self.acqf_constructor = acqf_constructor
self.acqf_optimizer = acqf_optimizer
self.refit_on_cv = refit_on_cv
self.refit_on_update = refit_on_update
        self.model: Optional[Model] = None
self.Xs = []
self.Ys = []
self.Yvars = []
self.dtype = None
self.device = None
self.task_features: List[int] = []
self.fidelity_features: List[int] = []
self.fidelity_model_id = kwargs.get("fidelity_model_id", None)
self.warm_start_refitting = warm_start_refitting
    @copy_doc(TorchModel.fit)
def fit(
self,
Xs: List[Tensor],
Ys: List[Tensor],
Yvars: List[Tensor],
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
self.dtype = Xs[0].dtype
self.device = Xs[0].device
self.Xs = Xs
self.Ys = Ys
self.Yvars = Yvars
self.task_features = task_features
self.fidelity_features = fidelity_features
self.model = self.model_constructor( # pyre-ignore [28]
Xs=Xs,
Ys=Ys,
Yvars=Yvars,
task_features=self.task_features,
fidelity_features=self.fidelity_features,
fidelity_model_id=self.fidelity_model_id,
)
    @copy_doc(TorchModel.predict)
def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
return self.model_predictor(model=self.model, X=X) # pyre-ignore [28]
    def gen(
self,
n: int,
bounds: List[Tuple[float, float]],
objective_weights: Tensor,
outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
fixed_features: Optional[Dict[int, float]] = None,
pending_observations: Optional[List[Tensor]] = None,
model_gen_options: Optional[TConfig] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
) -> Tuple[Tensor, Tensor]:
"""Generate new candidates.
An initialized acquisition function can be passed in as
model_gen_options["acquisition_function"].
Args:
n: Number of candidates to generate.
bounds: A list of (lower, upper) tuples for each column of X.
objective_weights: The objective is to maximize a weighted sum of
the columns of f(x). These are the weights.
outcome_constraints: A tuple of (A, b). For k outcome constraints
and m outputs at f(x), A is (k x m) and b is (k x 1) such that
A f(x) <= b. (Not used by single task models)
linear_constraints: A tuple of (A, b). For k linear constraints on
d-dimensional x, A is (k x d) and b is (k x 1) such that
A x <= b.
fixed_features: A map {feature_index: value} for features that
should be fixed to a particular value during generation.
pending_observations: A list of m (k_i x d) feature tensors X
for m outcomes and k_i pending observations for outcome i.
model_gen_options: A config dictionary that can contain
model-specific options.
rounding_func: A function that rounds an optimization result
appropriately (i.e., according to `round-trip` transformations).
Returns:
Tensor: `n x d`-dim Tensor of generated points.
Tensor: `n`-dim Tensor of weights for each point.
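
        Example:
            A minimal sketch that generates two candidates on the unit
            square, with options passed through to the acquisition function
            and optimizer (the option names depend on the particular
            `acqf_constructor` and `acqf_optimizer` in use):

            ::

                Xcand, weights = model.gen(
                    n=2,
                    bounds=[(0.0, 1.0), (0.0, 1.0)],
                    objective_weights=torch.tensor([1.0]),
                    model_gen_options={
                        "acquisition_function_kwargs": {"mc_samples": 512},
                        "optimizer_kwargs": {"num_restarts": 10},
                    },
                )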
"""
options = model_gen_options or {}
acf_options = options.get("acquisition_function_kwargs", {})
optimizer_options = options.get("optimizer_kwargs", {})
X_pending, X_observed = _get_X_pending_and_observed(
Xs=self.Xs,
pending_observations=pending_observations,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
bounds=bounds,
linear_constraints=linear_constraints,
fixed_features=fixed_features,
)
acquisition_function = self.acqf_constructor( # pyre-ignore: [28]
model=self.model,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
X_observed=X_observed,
X_pending=X_pending,
**acf_options,
)
bounds_ = torch.tensor(bounds, dtype=self.dtype, device=self.device)
bounds_ = bounds_.transpose(0, 1)
        if linear_constraints is not None:
            A, b = linear_constraints
            inequality_constraints = []
            k, d = A.shape
            # Convert each linear constraint `A x <= b` into the
            # `(indices, coefficients, rhs)` form expected by the optimizer,
            # which encodes `sum_i coefficients[i] * x[indices[i]] >= rhs`
            # (hence the sign flip on A and b).
            for i in range(k):
                indices = A[i, :].nonzero().view(-1)
                coefficients = -A[i, indices]
                rhs = -b[i, 0]
                inequality_constraints.append((indices, coefficients, rhs))
        else:
            inequality_constraints = None
botorch_rounding_func = get_rounding_func(rounding_func)
candidates = self.acqf_optimizer( # pyre-ignore: [28]
acq_function=checked_cast(AcquisitionFunction, acquisition_function),
bounds=bounds_,
n=n,
inequality_constraints=inequality_constraints,
fixed_features=fixed_features,
rounding_func=botorch_rounding_func,
**optimizer_options,
)
return candidates.detach().cpu(), torch.ones(n, dtype=self.dtype)
    @copy_doc(TorchModel.best_point)
def best_point(
self,
bounds: List[Tuple[float, float]],
objective_weights: Tensor,
outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
fixed_features: Optional[Dict[int, float]] = None,
model_gen_options: Optional[TConfig] = None,
) -> Optional[Tensor]:
x_best = best_observed_point(
model=self,
bounds=bounds,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
linear_constraints=linear_constraints,
fixed_features=fixed_features,
options=model_gen_options,
)
if x_best is None:
return None
# pyre-fixme[19]: Expected 0 positional arguments.
return x_best.to(dtype=self.dtype, device=torch.device("cpu"))
    @copy_doc(TorchModel.cross_validate)
def cross_validate(
self,
Xs_train: List[Tensor],
Ys_train: List[Tensor],
Yvars_train: List[Tensor],
X_test: Tensor,
) -> Tuple[Tensor, Tensor]:
if self.model is None:
raise RuntimeError("Cannot cross-validate model that has not been fitted")
if self.refit_on_cv:
state_dict = None
else:
state_dict = deepcopy(self.model.state_dict()) # pyre-ignore: [16]
model = self.model_constructor( # pyre-ignore: [28]
Xs=Xs_train,
Ys=Ys_train,
Yvars=Yvars_train,
task_features=self.task_features,
state_dict=state_dict,
fidelity_features=self.fidelity_features,
fidelity_model_id=self.fidelity_model_id,
)
return self.model_predictor(model=model, X=X_test) # pyre-ignore: [28]
    @copy_doc(TorchModel.update)
def update(self, Xs: List[Tensor], Ys: List[Tensor], Yvars: List[Tensor]) -> None:
if self.model is None:
raise RuntimeError("Cannot update model that has not been fitted")
self.Xs = Xs
self.Ys = Ys
self.Yvars = Yvars
if self.refit_on_update and not self.warm_start_refitting:
state_dict = None # pragma: no cover
else:
state_dict = deepcopy(self.model.state_dict()) # pyre-ignore: [16]
self.model = self.model_constructor( # pyre-ignore: [28]
Xs=self.Xs,
Ys=self.Ys,
Yvars=self.Yvars,
task_features=self.task_features,
state_dict=state_dict,
fidelity_features=self.fidelity_features,
fidelity_model_id=self.fidelity_model_id,
refit_model=self.refit_on_update,
)
def get_rounding_func(
    rounding_func: Optional[Callable[[Tensor], Tensor]]
) -> Optional[Callable[[Tensor], Tensor]]:
    """Wrap a rounding function so it can be applied to batched inputs.

    Args:
        rounding_func: A function that rounds a single `d`-dim tensor of
            parameter values.

    Returns:
        None if `rounding_func` is None; otherwise, a function that applies
        `rounding_func` to each `d`-dim row of an input tensor with
        arbitrary batch dimensions.
    """
    if rounding_func is None:
        botorch_rounding_func = rounding_func
    else:
        # make sure rounding_func is properly applied to q- and t-batches
        def botorch_rounding_func(X: Tensor) -> Tensor:
            # Flatten all batch dimensions, round each candidate point
            # individually, then restore the original batch shape.
            batch_shape, d = X.shape[:-1], X.shape[-1]
            X_round = torch.stack(
                [rounding_func(x) for x in X.view(-1, d)]  # pyre-ignore: [16]
            )
            return X_round.view(*batch_shape, d)

    return botorch_rounding_func