#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch import BotorchModel, get_rounding_func
from ax.models.torch.botorch_defaults import recommend_best_out_of_sample_point
from ax.models.torch.utils import (
_get_X_pending_and_observed,
get_out_of_sample_best_point_acqf,
)
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import not_none
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.max_value_entropy_search import (
qMaxValueEntropy,
qMultiFidelityMaxValueEntropy,
)
from botorch.acquisition.utils import (
expand_trace_observations,
project_to_target_fidelity,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.model import Model
from botorch.optim.optimize import optimize_acqf
from torch import Tensor
from .utils import subset_model


class MaxValueEntropySearch(BotorchModel):
r"""Max-value entropy search.
Args:
cost_intercept: The cost intercept for the affine cost of the form
`cost_intercept + n`, where `n` is the number of generated points.
Only used for multi-fidelity optimzation (i.e., if fidelity_features
are present).
linear_truncated: If `False`, use an alternate downsampling + exponential
decay Kernel instead of the default `LinearTruncatedFidelityKernel`
(only relevant for multi-fidelity optimization).
kwargs: Model-specific kwargs.
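
    Example:
        A minimal construction sketch (illustrative only; in normal use, Ax's
        `ModelBridge` constructs this model and drives `fit`/`gen`):

        >>> model = MaxValueEntropySearch(cost_intercept=0.5)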
"""
def __init__(
self,
cost_intercept: float = 1.0,
linear_truncated: bool = True,
use_input_warping: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
best_point_recommender=recommend_best_out_of_sample_point,
linear_truncated=linear_truncated,
use_input_warping=use_input_warping,
**kwargs,
)
self.cost_intercept = cost_intercept

    @copy_doc(TorchModel.gen)
def gen(
self,
n: int,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
) -> TorchGenResults:
if (
torch_opt_config.linear_constraints is not None
or torch_opt_config.outcome_constraints is not None
):
raise UnsupportedError(
"Constraints are not yet supported by max-value entropy search!"
)
if len(torch_opt_config.objective_weights) > 1:
raise UnsupportedError(
"Models with multiple outcomes are not yet supported by MES!"
)
options = torch_opt_config.model_gen_options or {}
acf_options = options.get("acquisition_function_kwargs", {})
optimizer_options = options.get("optimizer_kwargs", {})
X_pending, X_observed = _get_X_pending_and_observed(
Xs=self.Xs,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
pending_observations=torch_opt_config.pending_observations,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
)
model = self.model
# subset model only to the outcomes we need for the optimization
if options.get("subset_model", True):
subset_model_results = subset_model(
model=model,
objective_weights=torch_opt_config.objective_weights,
outcome_constraints=torch_opt_config.outcome_constraints,
)
model = subset_model_results.model
objective_weights = subset_model_results.objective_weights
else:
objective_weights = torch_opt_config.objective_weights
# get the acquisition function
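        # MES-specific knobs (BoTorch defaults): `num_mv_samples` is the number
        # of sampled max values, `num_y_samples` the number of posterior draws
        # used in the information-gain estimate, `candidate_size` the size of
        # the discrete set from which max values are sampled, and
        # `num_fantasies` the number of fantasy samples used when conditioning
        # on pending points.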
num_fantasies = acf_options.get("num_fantasies", 16)
num_mv_samples = acf_options.get("num_mv_samples", 10)
num_y_samples = acf_options.get("num_y_samples", 128)
candidate_size = acf_options.get("candidate_size", 1000)
num_restarts = optimizer_options.get("num_restarts", 40)
raw_samples = optimizer_options.get("raw_samples", 1024)
# generate the discrete points in the design space to sample max values
bounds_ = torch.tensor(
search_space_digest.bounds, dtype=self.dtype, device=self.device
)
bounds_ = bounds_.transpose(0, 1)
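        # Rescale uniform samples into the design bounds; MES samples its max
        # values over this random discrete candidate set.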
        candidate_set = torch.rand(
            candidate_size, bounds_.size(1), dtype=self.dtype, device=self.device
        )
candidate_set = bounds_[0] + (bounds_[1] - bounds_[0]) * candidate_set
acq_function = _instantiate_MES(
model=model,
candidate_set=candidate_set,
num_fantasies=num_fantasies,
num_trace_observations=options.get("num_trace_observations", 0),
num_mv_samples=num_mv_samples,
num_y_samples=num_y_samples,
X_pending=X_pending,
            maximize=bool(objective_weights[0] == 1),
target_fidelities=search_space_digest.target_fidelities,
fidelity_weights=options.get("fidelity_weights"),
cost_intercept=self.cost_intercept,
)
# optimize and get new points
botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)
opt_options: Dict[str, Union[bool, float, int, str]] = {
"batch_limit": 8,
"maxiter": 200,
"method": "L-BFGS-B",
"nonnegative": False,
}
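        # Any user-supplied optimizer options take precedence over the
        # defaults above.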
opt_options.update(optimizer_options.get("options", {}))
candidates, _ = optimize_acqf(
acq_function=acq_function,
bounds=bounds_,
q=n,
inequality_constraints=None,
fixed_features=torch_opt_config.fixed_features,
post_processing_func=botorch_rounding_func,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=opt_options,
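            # With `sequential=True`, the q candidates are generated greedily,
            # one at a time.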
sequential=True,
)
return TorchGenResults(
points=candidates.detach().cpu(),
weights=torch.ones(n, dtype=self.dtype),
)

    def _get_best_point_acqf(
self,
X_observed: Tensor,
objective_weights: Tensor,
mc_samples: int = 512,
fixed_features: Optional[Dict[int, float]] = None,
target_fidelities: Optional[Dict[int, float]] = None,
outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
seed_inner: Optional[int] = None,
qmc: bool = True,
**kwargs: Any,
) -> Tuple[AcquisitionFunction, Optional[List[int]]]:
# `outcome_constraints` is validated to be None in `gen`
if outcome_constraints is not None:
raise UnsupportedError("Outcome constraints not yet supported.")
return get_out_of_sample_best_point_acqf(
model=not_none(self.model),
Xs=self.Xs,
objective_weights=objective_weights,
# With None `outcome_constraints`, `get_objective` utility
# always returns a `ScalarizedObjective`, which results in
# `get_out_of_sample_best_point_acqf` always selecting
# `PosteriorMean`.
outcome_constraints=outcome_constraints,
X_observed=not_none(X_observed),
seed_inner=seed_inner,
fixed_features=fixed_features,
fidelity_features=self.fidelity_features,
target_fidelities=target_fidelities,
qmc=qmc,
)


def _instantiate_MES(
model: Model,
candidate_set: Tensor,
num_fantasies: int = 16,
num_mv_samples: int = 10,
num_y_samples: int = 128,
use_gumbel: bool = True,
X_pending: Optional[Tensor] = None,
maximize: bool = True,
num_trace_observations: int = 0,
target_fidelities: Optional[Dict[int, float]] = None,
fidelity_weights: Optional[Dict[int, float]] = None,
cost_intercept: float = 1.0,
) -> qMaxValueEntropy:
if target_fidelities:
if fidelity_weights is None:
fidelity_weights = {f: 1.0 for f in target_fidelities}
        if set(target_fidelities) != set(fidelity_weights):
            raise RuntimeError(
                "Must provide the same indices for target_fidelities "
                f"({set(target_fidelities)}) and fidelity_weights "
                f"({set(fidelity_weights)})."
            )
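        # Weight the information gain by the reciprocal of the (affine-in-
        # fidelity) evaluation cost, so cheap low-fidelity points look better.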
cost_model = AffineFidelityCostModel(
fidelity_weights=fidelity_weights, fixed_cost=cost_intercept
)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

        def project(X: Tensor) -> Tensor:
            return project_to_target_fidelity(X=X, target_fidelities=target_fidelities)

        def expand(X: Tensor) -> Tensor:
            return expand_trace_observations(
                X=X,
                fidelity_dims=sorted(target_fidelities),  # pyre-ignore: [6]
                num_trace_obs=num_trace_observations,
            )

        return qMultiFidelityMaxValueEntropy(
            model=model,
            candidate_set=candidate_set,
            num_fantasies=num_fantasies,
            num_mv_samples=num_mv_samples,
            num_y_samples=num_y_samples,
            use_gumbel=use_gumbel,
            X_pending=X_pending,
maximize=maximize,
cost_aware_utility=cost_aware_utility,
project=project,
expand=expand,
)

    return qMaxValueEntropy(
        model=model,
        candidate_set=candidate_set,
        num_fantasies=num_fantasies,
        num_mv_samples=num_mv_samples,
        num_y_samples=num_y_samples,
        use_gumbel=use_gumbel,
        X_pending=X_pending,
maximize=maximize,
)
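

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the Ax API): build a
    # toy single-outcome GP with BoTorch and instantiate the single-fidelity
    # MES acquisition directly. The toy data and sizes below are assumptions.
    from botorch.models.gp_regression import SingleTaskGP

    train_X = torch.rand(10, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    gp = SingleTaskGP(train_X, train_Y)
    # Discrete candidate set over which max values are sampled.
    toy_candidates = torch.rand(100, 2, dtype=torch.double)
    mes = _instantiate_MES(model=gp, candidate_set=toy_candidates)
    # Evaluate the acquisition at five q=1 test points; prints torch.Size([5]).
    print(mes(torch.rand(5, 1, 2, dtype=torch.double)).shape)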