#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import dataclasses
import functools
import operator
import warnings
from collections.abc import Mapping
from functools import partial, reduce
from itertools import product
from logging import Logger
from typing import Any, Callable, Optional
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.utils import (
_tensor_difference,
get_post_processing_func,
)
from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds
from ax.models.torch.utils import (
_get_X_pending_and_observed,
get_botorch_objective_and_transform,
subset_model,
)
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.models.model import Model, ModelDict
from botorch.optim.optimize import (
optimize_acqf,
optimize_acqf_discrete,
optimize_acqf_discrete_local_search,
optimize_acqf_mixed,
)
from botorch.utils.constraints import get_outcome_constraint_transforms
from torch import Tensor
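
# Tolerance below which two candidate points are treated as duplicates when
# filtering observed/pending points out of an enumerated discrete search space.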
DUPLICATE_TOL = 1e-6
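# If a fully discrete search space has more than this many total combinations,
# `optimize` uses discrete local search instead of enumerating all choices.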
MAX_CHOICES_ENUMERATE = 100_000
logger: Logger = get_logger(__name__)


class Acquisition(Base):
"""
**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
versions only.**
Ax wrapper for BoTorch `AcquisitionFunction`, subcomponent
of `BoTorchModel` and is not meant to be used outside of it.
Args:
surrogates: Dict of name => Surrogate model pairs, with which this acquisition
function will be used.
search_space_digest: A SearchSpaceDigest object containing metadata
about the search space (e.g. bounds, parameter types).
torch_opt_config: A TorchOptConfig object containing optimization
arguments (e.g., objective weights, constraints).
botorch_acqf_class: Type of BoTorch `AcquistitionFunction` that
should be used. Subclasses of `Acquisition` often specify
these via `default_botorch_acqf_class` attribute, in which
case specifying one here is not required.
options: Optional mapping of kwargs to the underlying `Acquisition
Function` in BoTorch.
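
    Example:
        A minimal construction sketch. In practice an ``Acquisition`` is
        created inside ``BoTorchModel``; the ``surrogate``,
        ``search_space_digest``, and ``torch_opt_config`` below are assumed
        to already exist, and the acqf class and options are illustrative
        only::

            from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement

            acquisition = Acquisition(
                surrogates={"default": surrogate},
                search_space_digest=search_space_digest,
                torch_opt_config=torch_opt_config,
                botorch_acqf_class=qNoisyExpectedImprovement,
                # Forwarded to the acqf input constructor; `prune_baseline`
                # is a qNEI-specific option.
                options={"prune_baseline": True},
            )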
"""
surrogates: dict[str, Surrogate]
acqf: AcquisitionFunction
options: dict[str, Any]

    def __init__(
self,
        # If using multiple Surrogates, the primary Surrogate (typically the
        # regression Surrogate) must be labeled `Keys.PRIMARY_SURROGATE`.
surrogates: dict[str, Surrogate],
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
botorch_acqf_class: type[AcquisitionFunction],
options: Optional[dict[str, Any]] = None,
) -> None:
self.surrogates = surrogates
self.options = options or {}
# Compute pending and observed points for each surrogate
Xs_pending_and_observed = {
name: _get_X_pending_and_observed(
Xs=surrogate.Xs,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
pending_observations=torch_opt_config.pending_observations,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
)
for name, surrogate in self.surrogates.items()
}
Xs_pending_list = [
Xs_pending
for Xs_pending, _ in Xs_pending_and_observed.values()
if Xs_pending is not None
]
unique_Xs_pending = (
torch.unique(
input=torch.cat(
tensors=Xs_pending_list,
dim=0,
),
dim=0,
)
if len(Xs_pending_list) > 0
else None
)
# This tensor may have some Xs that are also in pending (because they are
# observed for some models but not others)
Xs_observed_maybe_pending_list = [
Xs_observed
for _, Xs_observed in Xs_pending_and_observed.values()
if Xs_observed is not None
]
unique_Xs_observed_maybe_pending = (
torch.unique(
input=torch.cat(
tensors=Xs_observed_maybe_pending_list,
dim=0,
),
dim=0,
)
if len(Xs_observed_maybe_pending_list) > 0
else None
)
# If a point is pending on any model do not count it as observed.
# Do this by stacking pending on top of observed, filtering repeats, then
# removing pending points.
# TODO[sdaulton] Is this a sound approach? Should we be doing something more
# sophisticated here?
if unique_Xs_pending is None:
unique_Xs_observed = unique_Xs_observed_maybe_pending
elif unique_Xs_observed_maybe_pending is None:
unique_Xs_observed = None
else:
unique_Xs_observed = _tensor_difference(
A=unique_Xs_pending, B=unique_Xs_observed_maybe_pending
)
if torch.numel(unique_Xs_observed_maybe_pending) != torch.numel(
unique_Xs_observed
):
logger.warning(
"Encountered Xs pending for some Surrogates but observed for "
"others. Considering these points to be pending."
)
# Store objective thresholds for all outcomes (including non-objectives).
self._objective_thresholds: Optional[Tensor] = (
torch_opt_config.objective_thresholds
)
self._full_objective_weights: Tensor = torch_opt_config.objective_weights
full_outcome_constraints = torch_opt_config.outcome_constraints
# TODO[mpolson64] Handle more elegantly in the future. Since right now we
# only use one objective and posterior_transform this should be fine.
primary_surrogate = (
self.surrogates[Keys.PRIMARY_SURROGATE]
if len(self.surrogates) > 1
else next(iter(self.surrogates.values()))
)
primary_Xs_pending, primary_Xs_observed = Xs_pending_and_observed[
(
Keys.PRIMARY_SURROGATE
if len(self.surrogates) > 1
else next(iter(Xs_pending_and_observed.keys()))
)
]
# Subset model only to the outcomes we need for the optimization.
if self.options.pop(Keys.SUBSET_MODEL, True):
subset_model_results = subset_model(
model=primary_surrogate.model,
objective_weights=torch_opt_config.objective_weights,
outcome_constraints=torch_opt_config.outcome_constraints,
objective_thresholds=torch_opt_config.objective_thresholds,
)
model = subset_model_results.model
objective_weights = subset_model_results.objective_weights
outcome_constraints = subset_model_results.outcome_constraints
objective_thresholds = subset_model_results.objective_thresholds
subset_idcs = subset_model_results.indices
else:
model = primary_surrogate.model
objective_weights = torch_opt_config.objective_weights
outcome_constraints = torch_opt_config.outcome_constraints
objective_thresholds = torch_opt_config.objective_thresholds
subset_idcs = None
        # If objective weights suggest multiple objectives but objective
        # thresholds are not specified, infer them using the model that
        # has already been subset to avoid re-subsetting it within
        # `infer_objective_thresholds`.
if (
objective_weights.nonzero().numel() > 1
and (
self._objective_thresholds is None
or self._objective_thresholds[torch_opt_config.objective_weights != 0]
.isnan()
.any()
)
and primary_Xs_observed is not None
):
if torch_opt_config.risk_measure is not None:
# TODO[T131759263]: modify the heuristic to support risk measures.
raise NotImplementedError(
"Objective thresholds must be provided when using risk measures."
)
self._objective_thresholds = infer_objective_thresholds(
model=model,
objective_weights=self._full_objective_weights,
outcome_constraints=full_outcome_constraints,
X_observed=primary_Xs_observed,
subset_idcs=subset_idcs,
objective_thresholds=self._objective_thresholds,
)
objective_thresholds = (
not_none(self._objective_thresholds)[subset_idcs]
if subset_idcs is not None
else self._objective_thresholds
)
objective, posterior_transform = self.get_botorch_objective_and_transform(
botorch_acqf_class=botorch_acqf_class,
model=model,
objective_weights=objective_weights,
objective_thresholds=objective_thresholds,
outcome_constraints=outcome_constraints,
X_observed=primary_Xs_observed,
risk_measure=torch_opt_config.risk_measure,
)
model_deps = self.compute_model_dependencies(
surrogates=self.surrogates,
search_space_digest=search_space_digest,
torch_opt_config=dataclasses.replace(
torch_opt_config,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
objective_thresholds=objective_thresholds,
),
options=self.options,
)
acqf_model_kwarg = (
{
"model_dict": ModelDict(
**{
name: surrogate.model
for name, surrogate in self.surrogates.items()
}
)
}
if len(self.surrogates) > 1
else {"model": model}
)
target_fidelities = {
k: v
for k, v in search_space_digest.target_values.items()
if k in search_space_digest.fidelity_features
}
input_constructor_kwargs = {
"X_baseline": unique_Xs_observed,
"X_pending": unique_Xs_pending,
"objective_thresholds": objective_thresholds,
"constraints": get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
),
"objective": objective,
"posterior_transform": posterior_transform,
**acqf_model_kwarg,
**model_deps,
**self.options,
}
if len(target_fidelities) > 0:
input_constructor_kwargs["target_fidelities"] = target_fidelities
input_constructor = get_acqf_input_constructor(botorch_acqf_class)
# Handle multi-dataset surrogates - TODO: Improve this
# If there is only one SupervisedDataset return it alone
if (
len(self.surrogates) == 1
and len(next(iter(self.surrogates.values())).training_data) == 1
):
training_data = next(iter(self.surrogates.values())).training_data[0]
else:
tdicts = (
dict(zip(not_none(surrogate._outcomes), surrogate.training_data))
for surrogate in self.surrogates.values()
)
# outcome_name => Dataset
training_data = functools.reduce(lambda x, y: {**x, **y}, tdicts)
acqf_inputs = input_constructor(
training_data=training_data,
bounds=search_space_digest.bounds,
**{k: v for k, v in input_constructor_kwargs.items() if v is not None},
)
self.acqf = botorch_acqf_class(**acqf_inputs) # pyre-ignore [45]
self.X_pending: Optional[Tensor] = unique_Xs_pending
self.X_observed: Optional[Tensor] = unique_Xs_observed

    @property
def botorch_acqf_class(self) -> type[AcquisitionFunction]:
"""BoTorch ``AcquisitionFunction`` class underlying this ``Acquisition``."""
return self.acqf.__class__

    @property
def dtype(self) -> Optional[torch.dtype]:
"""Torch data type of the tensors in the training data used in the model,
of which this ``Acquisition`` is a subcomponent.
"""
dtypes = {
label: surrogate.dtype for label, surrogate in self.surrogates.items()
}
dtypes_list = list(dtypes.values())
if dtypes_list.count(dtypes_list[0]) != len(dtypes_list):
raise ValueError(
f"Expected all Surrogates to have same dtype, found {dtypes}"
)
return dtypes_list[0]

    @property
def device(self) -> Optional[torch.device]:
"""Torch device type of the tensors in the training data used in the model,
of which this ``Acquisition`` is a subcomponent.
"""
devices = {
label: surrogate.device for label, surrogate in self.surrogates.items()
}
devices_list = list(devices.values())
if devices_list.count(devices_list[0]) != len(devices_list):
raise ValueError(
f"Expected all Surrogates to have same device, found {devices}"
)
return devices_list[0]

    @property
def objective_thresholds(self) -> Optional[Tensor]:
"""The objective thresholds for all outcomes.
For non-objective outcomes, the objective thresholds are nans.
"""
return self._objective_thresholds

    @property
def objective_weights(self) -> Optional[Tensor]:
"""The objective weights for all outcomes."""
return self._full_objective_weights

    def optimize(
self,
n: int,
search_space_digest: SearchSpaceDigest,
inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None,
fixed_features: Optional[dict[int, float]] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
optimizer_options: Optional[dict[str, Any]] = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""Generate a set of candidates via multi-start optimization. Obtains
candidates and their associated acquisition function values.
Args:
n: The number of candidates to generate.
search_space_digest: A ``SearchSpaceDigest`` object containing search space
properties, e.g. ``bounds`` for optimization.
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
``sum_i (X[indices[i]] * coefficients[i]) >= rhs``.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
rounding_func: A function that post-processes an optimization
result appropriately. This is typically passed down from
`ModelBridge` to ensure compatibility of the candidates with
with Ax transforms. For additional post processing, use
`post_processing_func` option in `optimizer_options`.
optimizer_options: Options for the optimizer function, e.g. ``sequential``
or ``raw_samples``. This can also include a `post_processing_func`
which is applied to the candidates before the `rounding_func`.
`post_processing_func` can be used to support more customized options
that typically only exist in MBM, such as BoTorch transforms.
See the docstring of `TorchOptConfig` for more information on passing
down these options while constructing a generation strategy.
Returns:
A three-element tuple containing an `n x d`-dim tensor of generated
candidates, a tensor with the associated acquisition values, and a tensor
with the weight for each candidate.
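
        Example:
            A sketch of a typical call, assuming ``acquisition`` is an
            already-constructed ``Acquisition`` and ``digest`` is the matching
            ``SearchSpaceDigest``; the constraint encodes ``x_0 + x_1 >= 1.0``
            and the optimizer options shown are illustrative, not required::

                candidates, acqf_values, arm_weights = acquisition.optimize(
                    n=3,
                    search_space_digest=digest,
                    inequality_constraints=[
                        (torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.0)
                    ],
                    optimizer_options={"num_restarts": 20, "raw_samples": 1024},
                )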
"""
        # NOTE: Could make use of `optimizer_class` when it's added to BoTorch
        # instead of calling `optimize_acqf` or `optimize_acqf_discrete` etc.
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
if (
optimizer_options is not None
and "force_use_optimize_acqf" in optimizer_options
):
force_use_optimize_acqf = optimizer_options.pop("force_use_optimize_acqf")
else:
force_use_optimize_acqf = False
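        # Select the optimization routine based on the search space structure:
        #   - no discrete features (or forced above):  `optimize_acqf`
        #   - fully discrete, enumerable:              `optimize_acqf_discrete`
        #   - fully discrete, too many combinations:   `optimize_acqf_discrete_local_search`
        #   - mixed continuous and discrete:           `optimize_acqf_mixed`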
if (len(discrete_features) == 0) or force_use_optimize_acqf:
optimizer = "optimize_acqf"
else:
fully_discrete = len(discrete_choices) == len(ssd.feature_names)
if fully_discrete:
total_discrete_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_discrete_choices > MAX_CHOICES_ENUMERATE:
optimizer = "optimize_acqf_discrete_local_search"
else:
optimizer = "optimize_acqf_discrete"
# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples", None)
else:
optimizer = "optimize_acqf_mixed"
# Prepare arguments for optimizer
optimizer_options_with_defaults = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer=optimizer,
)
post_processing_func = get_post_processing_func(
rounding_func=rounding_func,
optimizer_options=optimizer_options_with_defaults,
)
if fixed_features is not None:
for i in fixed_features:
if not 0 <= i < len(ssd.feature_names):
raise ValueError(f"Invalid fixed_feature index: {i}")
# Return a weight of 1 for each arm by default. This can be
# customized in subclasses if necessary.
arm_weights = torch.ones(n, dtype=self.dtype)
# 1. Handle the fully continuous search space.
if optimizer == "optimize_acqf":
candidates, acqf_values = optimize_acqf(
acq_function=self.acqf,
bounds=bounds,
q=n,
inequality_constraints=inequality_constraints,
fixed_features=fixed_features,
post_processing_func=post_processing_func,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights
# 2. Handle search spaces with discrete features.
# 2a. Handle the fully discrete search space.
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
):
X_observed = self.X_observed
if self.X_pending is not None:
if X_observed is None:
X_observed = self.X_pending
else:
X_observed = torch.cat([X_observed, self.X_pending], dim=0)
# Special handling for search spaces with a large number of choices
if optimizer == "optimize_acqf_discrete_local_search":
discrete_choices = [
torch.tensor(c, device=self.device, dtype=self.dtype)
for c in discrete_choices.values()
]
candidates, acqf_values = optimize_acqf_discrete_local_search(
acq_function=self.acqf,
q=n,
discrete_choices=discrete_choices,
inequality_constraints=inequality_constraints,
X_avoid=X_observed,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights
# Else, optimizer is `optimize_acqf_discrete`
# Enumerate all possible choices
all_choices = (discrete_choices[i] for i in range(len(discrete_choices)))
all_choices = _tensorize(tuple(product(*all_choices)))
# This can be vectorized, but using a for-loop to avoid memory issues
if X_observed is not None:
for x in X_observed:
all_choices = all_choices[
(all_choices - x).abs().max(dim=-1).values > DUPLICATE_TOL
]
# Filter out candidates that violate the constraints
# TODO: It will be more memory-efficient to do this filtering before
# converting the generator into a tensor. However, if we run into memory
# issues we are likely better off being smarter in how we optimize the
# acquisition function.
inequality_constraints = inequality_constraints or []
is_feasible = torch.ones(all_choices.shape[0], dtype=torch.bool)
for inds, weights, bound in inequality_constraints:
is_feasible &= (all_choices[..., inds] * weights).sum(dim=-1) >= bound
all_choices = all_choices[is_feasible]
num_choices = all_choices.size(dim=0)
if num_choices == 0:
raise SearchSpaceExhausted(
"No more feasible choices in a fully discrete search space."
)
if num_choices < n:
warnings.warn(
(
f"Requested n={n} candidates from fully discrete search "
f"space, but only {num_choices} possible choices remain."
),
AxWarning,
stacklevel=2,
)
n = num_choices
candidates, acqf_values = optimize_acqf_discrete(
acq_function=self.acqf,
q=n,
choices=all_choices,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights
# 2b. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
q=n,
            # For now we just enumerate all possible discrete combinations. This
            # is not scalable and only works for a reasonably small number of
            # choices. A slowdown warning is logged in
            # `enumerate_discrete_combinations` if needed.
fixed_features_list=enumerate_discrete_combinations(
discrete_choices=discrete_choices
),
inequality_constraints=inequality_constraints,
post_processing_func=post_processing_func,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights

    def evaluate(self, X: Tensor) -> Tensor:
"""Evaluate the acquisition function on the candidate set `X`.
Args:
X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim design
points each.
Returns:
A `batch_shape'`-dim Tensor of acquisition values at the given
design points `X`, where `batch_shape'` is the broadcasted batch shape of
model and input `X`.
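
        Example:
            A minimal sketch: evaluating three singleton candidate batches in
            a 2d search space, assuming ``acquisition`` is an
            already-constructed ``Acquisition``::

                X = torch.rand(3, 1, 2, dtype=acquisition.dtype)
                values = acquisition.evaluate(X=X)  # shape: (3,)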
"""
if isinstance(self.acqf, qKnowledgeGradient):
return self.acqf.evaluate(X=X)
else:
# NOTE: `AcquisitionFunction.__call__` calls `forward`,
# so below is equivalent to `self.acqf.forward(X=X)`.
return self.acqf(X=X)

    def compute_model_dependencies(
self,
surrogates: Mapping[str, Surrogate],
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""Computes inputs to acquisition function class based on the given
surrogate model.
NOTE: When subclassing `Acquisition` from a superclass where this
method returns a non-empty dictionary of kwargs to `AcquisitionFunction`,
call `super().compute_model_dependencies` and then update that
dictionary of options with the options for the subclass you are creating
(unless the superclass' model dependencies should not be propagated to
the subclass). See `MultiFidelityKnowledgeGradient.compute_model_dependencies`
for an example.
Args:
surrogates: Mapping from names to Surrogate objects containing BoTorch
`Model`s, with which this `Acquisition` is to be used.
search_space_digest: A SearchSpaceDigest object containing metadata
about the search space (e.g. bounds, parameter types).
torch_opt_config: A TorchOptConfig object containing optimization
arguments (e.g., objective weights, constraints).
options: The `options` kwarg dict, passed on initialization of
the `Acquisition` object.
Returns: A dictionary of surrogate model-dependent options, to be passed
as kwargs to BoTorch`AcquisitionFunction` constructor.
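
        Example:
            A sketch of a subclass override that extends, rather than
            replaces, the superclass dependencies. `current_value` and
            `_compute_current_value` are hypothetical names used purely for
            illustration::

                def compute_model_dependencies(self, surrogates, search_space_digest,
                                               torch_opt_config, options=None):
                    dependencies = super().compute_model_dependencies(
                        surrogates=surrogates,
                        search_space_digest=search_space_digest,
                        torch_opt_config=torch_opt_config,
                        options=options,
                    )
                    dependencies["current_value"] = self._compute_current_value()
                    return dependencies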
"""
return {}