#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError
from ax.models.torch.botorch import (
BotorchModel,
get_rounding_func,
TAcqfConstructor,
TBestPointRecommender,
TModelConstructor,
TModelPredictor,
TOptimizer,
)
from ax.models.torch.botorch_defaults import (
get_and_fit_model,
recommend_best_observed_point,
scipy_optimizer,
)
from ax.models.torch.botorch_moo_defaults import (
get_NEHVI,
infer_objective_thresholds,
pareto_frontier_evaluator,
scipy_optimizer_list,
)
from ax.models.torch.frontier_utils import TFrontierEvaluator
from ax.models.torch.utils import (
_get_X_pending_and_observed,
_to_inequality_constraints,
predict_from_model,
randomize_objective_weights,
subset_model,
)
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.model import Model
from torch import Tensor
# Module-level logger named after this module.
# NOTE(review): not referenced by any code visible in this part of the file.
# pyre-fixme[5]: Global expression must be annotated.
logger = get_logger(__name__)
# Type alias for an optimizer that works on a *list* of acquisition functions
# and returns a pair of tensors.
# The positional argument types line up with the keyword call made to
# `scipy_optimizer_list` in `MultiObjectiveBotorchModel.gen`, i.e. presumably:
#   1. the list of acquisition functions,
#   2. the bounds tensor,
#   3. optional inequality constraints,
#   4. optional fixed features (feature index -> value),
#   5. optional rounding function (Tensor -> Tensor),
#   6. additional optimizer options.
# NOTE(review): this mapping is inferred from that call site -- confirm against
# `scipy_optimizer_list` before relying on it.
# pyre-fixme[33]: Aliased annotation cannot contain `Any`.
TOptimizerList = Callable[
[
List[AcquisitionFunction],
Tensor,
Optional[List[Tuple[Tensor, Tensor, float]]],
Optional[Dict[int, float]],
Optional[Callable[[Tensor], Tensor]],
Any,
],
Tuple[Tensor, Tensor],
]
class MultiObjectiveBotorchModel(BotorchModel):
    r"""
    Customizable multi-objective model.

    By default, this uses an Expected Hypervolume Improvement function (the
    default `acqf_constructor` is `get_NEHVI`) to find the pareto frontier of a
    function with multiple outcomes. This behavior can be modified by providing
    custom implementations of the following components:

    - a `model_constructor` that instantiates and fits a model on data
    - a `model_predictor` that predicts outcomes using the fitted model
    - a `acqf_constructor` that creates an acquisition function from a fitted model
    - a `acqf_optimizer` that optimizes the acquisition function
    - a `frontier_evaluator` with the signature described below

    Args:
        model_constructor: A callable that instantiates and fits a model on data,
            with signature as described below.
        model_predictor: A callable that predicts using the fitted model, with
            signature as described below.
        acqf_constructor: A callable that creates an acquisition function from a
            fitted model, with signature as described below.
        acqf_optimizer: A callable that optimizes an acquisition
            function, with signature as described below.
        best_point_recommender: A callable stored on the instance. It is not
            invoked by `gen`; the in-code TODO notes it is used by the
            modelbridge and is slated for removal from this model.
        frontier_evaluator: A callable with signature as described below. It is
            stored on the instance and is not invoked by `gen`.
        refit_on_cv: Stored on the instance (presumably consumed by inherited
            `BotorchModel` logic -- confirm there).
        refit_on_update: Stored on the instance (see `refit_on_cv`).
        warm_start_refitting: Stored on the instance (see `refit_on_cv`).
        use_input_warping: Stored on the instance (see `refit_on_cv`).
        use_loocv_pseudo_likelihood: Stored on the instance (see `refit_on_cv`).
        kwargs: Additional keyword arguments, stored as `self._kwargs`.

    Call signatures:

    ::

        model_constructor(
            Xs,
            Ys,
            Yvars,
            task_features,
            fidelity_features,
            metric_names,
            state_dict,
            **kwargs,
        ) -> model

    Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
    `task_features` identifies columns of Xs that should be modeled as a task,
    `fidelity_features` is a list of ints that specify the positions of fidelity
    parameters in 'Xs', `metric_names` provides the names of each `Y` in `Ys`,
    `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`.
    Optional kwargs are being passed through from the `BotorchModel` constructor.
    This callable is assumed to return a fitted BoTorch model that has the same
    dtype and lives on the same device as the input tensors.

    ::

        model_predictor(model, X) -> [mean, cov]

    Here `model` is a fitted botorch model, `X` is a tensor of candidate points,
    and `mean` and `cov` are the posterior mean and covariance, respectively.

    ::

        acqf_constructor(
            model,
            objective_weights,
            objective_thresholds,
            outcome_constraints,
            X_observed,
            X_pending,
            **kwargs,
        ) -> acq_function

    Here `model` is a botorch `Model`, `objective_weights` is a tensor of weights
    for the model outputs, `objective_thresholds` are the thresholds used for the
    objectives, `outcome_constraints` is a tuple of tensors describing
    the (linear) outcome constraints, `X_observed` are previously observed points,
    and `X_pending` are points whose evaluation is pending. `acq_function` is a
    BoTorch acquisition function crafted from these inputs. All arguments are
    passed by keyword. `objective_thresholds` is NOT passed when `gen` runs with
    the `random_scalarization` or `chebyshev_scalarization` option enabled.
    For additional details on the arguments, see `get_NEHVI` (the default).

    ::

        acqf_optimizer(
            acq_function,
            bounds,
            n,
            inequality_constraints,
            fixed_features,
            rounding_func,
            **kwargs,
        ) -> candidates, expected_acquisition_value

    Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a tensor
    containing bounds on the parameters, `n` is the number of candidates to be
    generated, `inequality_constraints` are inequality constraints on parameter
    values, `fixed_features` specifies features that should be fixed during
    generation, and `rounding_func` is a callback that rounds an optimization
    result appropriately. `candidates` is a tensor of generated candidates and
    `expected_acquisition_value` is a tensor that `gen` reports (as a list) in
    its generation metadata.
    For additional details on the arguments, see `scipy_optimizer`.

    ::

        frontier_evaluator(
            model,
            objective_weights,
            objective_thresholds,
            X,
            Y,
            Yvar,
            outcome_constraints,
        )

    Here `model` is a botorch `Model`, `objective_thresholds` is used in hypervolume
    evaluations, `objective_weights` is a tensor of weights applied to the objectives
    (sign represents direction), `X`, `Y`, `Yvar` are tensors, `outcome_constraints` is
    a tuple of tensors describing the (linear) outcome constraints.
    """

    dtype: Optional[torch.dtype]
    device: Optional[torch.device]
    Xs: List[Tensor]
    Ys: List[Tensor]
    Yvars: List[Tensor]

    def __init__(
        self,
        model_constructor: TModelConstructor = get_and_fit_model,
        model_predictor: TModelPredictor = predict_from_model,
        # pyre-fixme[9]: acqf_constructor has type `Callable[[Model, Tensor,
        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor], Any],
        #  AcquisitionFunction]`; used as `Callable[[Model, Tensor,
        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor],
        #  **(Any)], AcquisitionFunction]`.
        acqf_constructor: TAcqfConstructor = get_NEHVI,
        # pyre-fixme[9]: acqf_optimizer has type `Callable[[AcquisitionFunction,
        #  Tensor, int, Optional[Dict[int, float]], Optional[Callable[[Tensor],
        #  Tensor]], Any], Tensor]`; used as `Callable[[AcquisitionFunction, Tensor,
        #  int, Optional[Dict[int, float]], Optional[Callable[[Tensor], Tensor]],
        #  **(Any)], Tensor]`.
        acqf_optimizer: TOptimizer = scipy_optimizer,
        # TODO: Remove best_point_recommender for botorch_moo. Used in modelbridge._gen.
        best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
        frontier_evaluator: TFrontierEvaluator = pareto_frontier_evaluator,
        refit_on_cv: bool = False,
        refit_on_update: bool = True,
        warm_start_refitting: bool = False,
        use_input_warping: bool = False,
        use_loocv_pseudo_likelihood: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the pluggable components and reset all fitted-model state.

        See the class docstring for the meaning and call signature of each
        argument. The parent initializer is intentionally not invoked here;
        every attribute is assigned explicitly below.
        """
        # Pluggable components.
        self.model_constructor = model_constructor
        self.model_predictor = model_predictor
        self.acqf_constructor = acqf_constructor
        self.acqf_optimizer = acqf_optimizer
        self.best_point_recommender = best_point_recommender
        self.frontier_evaluator = frontier_evaluator
        # pyre-fixme[4]: Attribute must be annotated.
        self._kwargs = kwargs
        # Fitting-related flags.
        self.refit_on_cv = refit_on_cv
        self.refit_on_update = refit_on_update
        self.warm_start_refitting = warm_start_refitting
        self.use_input_warping = use_input_warping
        self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood
        # State that is empty until a model has been fitted.
        self.model: Optional[Model] = None
        self.Xs = []
        self.Ys = []
        self.Yvars = []
        self.dtype = None
        self.device = None
        self.task_features: List[int] = []
        self.fidelity_features: List[int] = []
        self.metric_names: List[str] = []

    # NOTE: `gen` deliberately has no docstring of its own; `copy_doc` supplies
    # it from `TorchModel.gen`. Document behavior with `#` comments instead.
    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        # Overview:
        #   1. validate the request,
        #   2. collect pending / observed points,
        #   3. optionally subset the model to the outcomes that matter,
        #   4. infer objective thresholds if none were supplied,
        #   5. build and optimize the acquisition function(s),
        #   6. return candidates with the full (un-subsetted) thresholds and
        #      weights recorded in the generation metadata.
        options = torch_opt_config.model_gen_options or {}
        acf_options = options.get("acquisition_function_kwargs", {})
        optimizer_options = options.get("optimizer_kwargs", {})

        if search_space_digest.target_fidelities:  # untested
            raise NotImplementedError(
                "target_fidelities not implemented for base BotorchModel"
            )
        if (
            torch_opt_config.objective_thresholds is not None
            and torch_opt_config.objective_weights.shape[0]
            != not_none(torch_opt_config.objective_thresholds).shape[0]
        ):
            raise AxError(
                "Objective weights and thresholds must both contain an element for"
                " each modeled metric."
            )

        X_pending, X_observed = _get_X_pending_and_observed(
            Xs=self.Xs,
            objective_weights=torch_opt_config.objective_weights,
            bounds=search_space_digest.bounds,
            pending_observations=torch_opt_config.pending_observations,
            outcome_constraints=torch_opt_config.outcome_constraints,
            linear_constraints=torch_opt_config.linear_constraints,
            fixed_features=torch_opt_config.fixed_features,
        )

        model = not_none(self.model)
        # Keep the un-subsetted quantities around: they are what threshold
        # inference receives and what is reported in `gen_metadata`.
        full_objective_thresholds = torch_opt_config.objective_thresholds
        full_objective_weights = torch_opt_config.objective_weights
        full_outcome_constraints = torch_opt_config.outcome_constraints

        # subset model only to the outcomes we need for the optimization
        if options.get(Keys.SUBSET_MODEL, True):
            subset_model_results = subset_model(
                model=model,
                objective_weights=torch_opt_config.objective_weights,
                outcome_constraints=torch_opt_config.outcome_constraints,
                objective_thresholds=torch_opt_config.objective_thresholds,
            )
            model = subset_model_results.model
            objective_weights = subset_model_results.objective_weights
            outcome_constraints = subset_model_results.outcome_constraints
            objective_thresholds = subset_model_results.objective_thresholds
            idcs = subset_model_results.indices
        else:
            objective_weights = torch_opt_config.objective_weights
            outcome_constraints = torch_opt_config.outcome_constraints
            objective_thresholds = torch_opt_config.objective_thresholds
            idcs = None

        if objective_thresholds is None:
            # No thresholds were supplied: infer them using the (possibly
            # subsetted) model together with the full weights / constraints.
            full_objective_thresholds = infer_objective_thresholds(
                model=model,
                X_observed=not_none(X_observed),
                objective_weights=full_objective_weights,
                outcome_constraints=full_outcome_constraints,
                subset_idcs=idcs,
            )
            # subset the objective thresholds
            objective_thresholds = (
                full_objective_thresholds
                if idcs is None
                else full_objective_thresholds[idcs].clone()
            )

        # Tensor of the digest's bounds with its first two dimensions swapped
        # (presumably turning per-parameter (lower, upper) pairs into a
        # 2 x d tensor -- confirm against `SearchSpaceDigest.bounds`).
        bounds_ = torch.tensor(
            search_space_digest.bounds, dtype=self.dtype, device=self.device
        )
        bounds_ = bounds_.transpose(0, 1)

        botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)

        if acf_options.get("random_scalarization", False) or acf_options.get(
            "chebyshev_scalarization", False
        ):
            # If using a list of acquisition functions, the algorithm to generate
            # that list is configured by acquisition_function_kwargs.
            objective_weights_list = [
                randomize_objective_weights(objective_weights, **acf_options)
                for _ in range(n)
            ]
            # One acquisition function per randomized weight vector. Note that
            # `objective_thresholds` is not passed on this path.
            acquisition_function_list = [
                checked_cast(
                    AcquisitionFunction,
                    self.acqf_constructor(  # pyre-ignore: [28]
                        model=model,
                        objective_weights=randomized_weights,
                        outcome_constraints=outcome_constraints,
                        X_observed=X_observed,
                        X_pending=X_pending,
                        **acf_options,
                    ),
                )
                for randomized_weights in objective_weights_list
            ]
            # Multiple acquisition functions require a sequential optimizer
            # always use scipy_optimizer_list.
            # TODO(jej): Allow any optimizer.
            candidates, expected_acquisition_value = scipy_optimizer_list(
                acq_function_list=acquisition_function_list,
                bounds=bounds_,
                inequality_constraints=_to_inequality_constraints(
                    linear_constraints=torch_opt_config.linear_constraints
                ),
                fixed_features=torch_opt_config.fixed_features,
                rounding_func=botorch_rounding_func,
                **optimizer_options,
            )
        else:
            acquisition_function = checked_cast(
                AcquisitionFunction,
                self.acqf_constructor(  # pyre-ignore: [28]
                    model=model,
                    objective_weights=objective_weights,
                    objective_thresholds=objective_thresholds,
                    outcome_constraints=outcome_constraints,
                    X_observed=X_observed,
                    X_pending=X_pending,
                    **acf_options,
                ),
            )
            # pyre-ignore: [28]
            candidates, expected_acquisition_value = self.acqf_optimizer(
                acq_function=acquisition_function,
                bounds=bounds_,
                n=n,
                inequality_constraints=_to_inequality_constraints(
                    linear_constraints=torch_opt_config.linear_constraints
                ),
                fixed_features=torch_opt_config.fixed_features,
                rounding_func=botorch_rounding_func,
                **optimizer_options,
            )

        # Report the full (not subsetted) thresholds and weights, moved to CPU.
        gen_metadata = {
            "expected_acquisition_value": expected_acquisition_value.tolist(),
            "objective_thresholds": not_none(full_objective_thresholds).cpu(),
            "objective_weights": full_objective_weights.cpu(),
        }
        return TorchGenResults(
            points=candidates.detach().cpu(),
            weights=torch.ones(n, dtype=self.dtype),
            gen_metadata=gen_metadata,
        )