Source code for ax.models.torch.botorch_modular.sebo

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import functools
import warnings
from collections.abc import Callable
from copy import deepcopy
from functools import partial
from logging import Logger
from typing import Any

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.logger import get_logger
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.multi_objective.logei import (
    qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.penalized import L0Approximation
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import ModelList
from botorch.optim import (
    Homotopy,
    HomotopyParameter,
    LogLinearHomotopySchedule,
    optimize_acqf_homotopy,
)
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import unnormalize
from pyre_extensions import none_throws
from torch import Tensor
from torch.quasirandom import SobolEngine

# Absolute per-coordinate tolerance: any coordinate whose distance to the
# target point is at most this value is snapped exactly onto the target point.
CLAMP_TOL: float = 0.01
# Module-level logger.
logger: Logger = get_logger(__name__)


class SEBOAcquisition(Acquisition):
    """
    Implement the acquisition function of Sparsity Exploring Bayesian Optimization
    (SEBO).

    The SEBO is a hyperparameter-free method to simultaneously maximize a target
    objective and sparsity. When L0 norm is used, SEBO uses a novel differentiable
    relaxation based on homotopy continuation to efficiently optimize for sparsity.

    The sparsity penalty is added to (a copy of) the surrogate as an extra,
    noiseless deterministic outcome, and the resulting problem is always optimized
    with ``qLogNoisyExpectedHypervolumeImprovement``.
    """

    def __init__(
        self,
        surrogate: Surrogate,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        botorch_acqf_class: type[AcquisitionFunction],
        options: dict[str, Any] | None = None,
    ) -> None:
        """Build the SEBO acquisition on top of a deep copy of ``surrogate``.

        Args:
            surrogate: Surrogate fit to the original outcomes. It is deep-copied and
                the copy is augmented with the penalty outcome.
            search_space_digest: Search space properties, forwarded to the base
                ``Acquisition``.
            torch_opt_config: Config for the original outcomes; it is extended with
                the penalty outcome before being forwarded to the base class.
            botorch_acqf_class: Not used; SEBO always instantiates
                ``qLogNoisyExpectedHypervolumeImprovement``.
            options: Acquisition options. SEBO consumes ``"penalty"`` (default
                ``"L0_norm"``), ``"target_point"`` (required) and
                ``"sparsity_threshold"`` (default: number of input features of
                ``surrogate.Xs[0]``); the rest is forwarded to the base class.
                The caller's dict is not modified.

        Raises:
            ValueError: If no ``"target_point"`` is given.
            NotImplementedError: If ``"penalty"`` is not a supported norm.
        """
        tkwargs: dict[str, Any] = {"dtype": surrogate.dtype, "device": surrogate.device}
        # Work on a shallow copy so popping SEBO-specific keys below does not
        # mutate the dict owned by the caller.
        options = {} if options is None else dict(options)
        self.penalty_name: str = options.pop("penalty", "L0_norm")
        target_point = options.pop("target_point", None)
        if target_point is None:
            raise ValueError("please provide target point.")
        # `Tensor.to` is not in-place; assign the result so the target point
        # actually lives on the surrogate's dtype/device.
        self.target_point: Tensor = target_point.to(**tkwargs)
        self.sparsity_threshold: int = options.pop(
            "sparsity_threshold", surrogate.Xs[0].shape[-1]
        )
        # Construct the deterministic model for the penalty term.
        # pyre-fixme[4]: Attribute must be annotated.
        self.deterministic_model = self._construct_penalty()
        surrogate_f = deepcopy(surrogate)
        # We need to clamp the training data to the target point here as it may
        # be slightly off due to numerical issues.
        X_sparse = clamp_to_target(
            X=surrogate_f.Xs[0].clone(),
            target_point=self.target_point,
            clamp_tol=CLAMP_TOL,
        )
        # Add the (noiseless) penalty outcome to the training data of the copy.
        none_throws(surrogate_f._training_data).append(
            SupervisedDataset(
                X=X_sparse,
                Y=self.deterministic_model(X_sparse),
                Yvar=torch.zeros(X_sparse.shape[0], 1, **tkwargs),  # noiseless
                feature_names=surrogate_f.training_data[0].feature_names,
                outcome_names=[self.penalty_name],
            )
        )
        # Model the penalty alongside the original outcomes in the copy.
        surrogate_f._model = ModelList(surrogate.model, self.deterministic_model)
        # Extend objective weights, thresholds and constraints with the penalty.
        torch_opt_config_sebo = self._transform_torch_config(
            torch_opt_config=torch_opt_config, **tkwargs
        )
        # Change some options (note: we do not want to do this in-place).
        if options.get("cache_root", False):
            warnings.warn(
                "SEBO doesn't support `cache_root=True`. Changing it to `False`.",
                AxWarning,
                stacklevel=3,
            )
            options = {**options, "cache_root": False}
        # Instantiate the acquisition function. We need to modify `a` before doing
        # this (as it controls the L0 norm approximation) since the baseline will
        # be pruned when the acquisition function is created. With a=1e-6 the
        # deterministic model will be numerically close to the true L0 norm and we
        # will select the baseline according to the last homotopy step.
        if self.penalty_name == "L0_norm":
            self.deterministic_model._f.a.fill_(1e-6)
        super().__init__(
            surrogate=surrogate_f,
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config_sebo,
            botorch_acqf_class=qLogNoisyExpectedHypervolumeImprovement,
            options=options,
        )
        # Update the objective threshold for the deterministic model (penalty
        # term). The penalty has objective weight -1, hence the negated ref point.
        self.acqf.ref_point[-1] = self.sparsity_threshold * -1
        # pyre-ignore
        self._objective_thresholds[-1] = self.sparsity_threshold

    def _construct_penalty(self) -> GenericDeterministicModel:
        """Construct a penalty term as a deterministic model to be included in the
        SEBO acquisition function. Currently only L0 and L1 penalties are supported.

        Returns:
            A ``GenericDeterministicModel`` wrapping the penalty relative to
            ``self.target_point``.

        Raises:
            NotImplementedError: If ``self.penalty_name`` is neither ``"L0_norm"``
                nor ``"L1_norm"``.
        """
        if self.penalty_name == "L0_norm":
            L0 = L0Approximation(target_point=self.target_point)
            return GenericDeterministicModel(f=L0)
        elif self.penalty_name == "L1_norm":
            L1 = functools.partial(
                L1_norm_func,
                init_point=self.target_point,
            )
            return GenericDeterministicModel(f=L1)
        else:
            raise NotImplementedError(
                f"{self.penalty_name} is not currently implemented."
            )

    def _transform_torch_config(
        self,
        torch_opt_config: TorchOptConfig,
        **tkwargs: Any,
    ) -> TorchOptConfig:
        """Transform torch config to include the penalty term (deterministic model)
        as an additional outcome in the BoTorch model.

        Args:
            torch_opt_config: Config for the original outcomes.
            **tkwargs: ``dtype``/``device`` for the tensors created here.

        Returns:
            A new ``TorchOptConfig`` with one extra (penalty) outcome and
            ``is_moo=True``.
        """
        # Append the weight -1 for the sparsity objective.
        objective_weights_sebo = torch.cat(
            [torch_opt_config.objective_weights, -torch.ones(1, **tkwargs)]
        )
        if torch_opt_config.outcome_constraints is not None:
            # Add a zero column to A so constraints do not involve the penalty.
            A, b = none_throws(torch_opt_config.outcome_constraints)
            outcome_constraints_sebo = (
                torch.cat([A, torch.zeros(A.shape[0], 1, **tkwargs)], dim=1),
                b,
            )
        else:
            outcome_constraints_sebo = None
        if torch_opt_config.objective_thresholds is not None:
            objective_thresholds_sebo = torch.cat(
                [
                    torch_opt_config.objective_thresholds,
                    torch.tensor([self.sparsity_threshold], **tkwargs),
                ]
            )
        else:
            # NOTE: The reference point will be inferred in the base class.
            objective_thresholds_sebo = None
        # Update pending observations (if any) by appending an entry for the new
        # penalty outcome, reusing the first outcome's pending points. An empty
        # list is passed through unchanged instead of raising an IndexError.
        pending_observations = torch_opt_config.pending_observations
        if pending_observations:
            pending_observations = pending_observations + [pending_observations[0]]
        return TorchOptConfig(
            objective_weights=objective_weights_sebo,
            outcome_constraints=outcome_constraints_sebo,
            objective_thresholds=objective_thresholds_sebo,
            linear_constraints=torch_opt_config.linear_constraints,
            fixed_features=torch_opt_config.fixed_features,
            pending_observations=pending_observations,
            model_gen_options=torch_opt_config.model_gen_options,
            rounding_func=torch_opt_config.rounding_func,
            opt_config_metrics=torch_opt_config.opt_config_metrics,
            is_moo=True,  # SEBO adds an objective, so it'll always be MOO.
        )

    def optimize(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
        fixed_features: dict[int, float] | None = None,
        rounding_func: Callable[[Tensor], Tensor] | None = None,
        optimizer_options: dict[str, Any] | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Generate a set of candidates via multi-start optimization. Obtains
        candidates and their associated acquisition function values.

        With the L0 penalty, optimization uses homotopy continuation; with the L1
        penalty, the base class's standard optimization is used. In both cases,
        coordinates within ``CLAMP_TOL`` of the target point are snapped onto it.

        Args:
            n: The number of candidates to generate.
            search_space_digest: A ``SearchSpaceDigest`` object containing search
                space properties, e.g. ``bounds`` for optimization.
            inequality_constraints: A list of tuples (indices, coefficients, rhs),
                with each tuple encoding an inequality constraint of the form
                ``sum_i (X[indices[i]] * coefficients[i]) >= rhs``.
            fixed_features: A map `{feature_index: value}` for features that
                should be fixed to a particular value during generation.
            rounding_func: A function that post-processes an optimization
                result appropriately (i.e., according to `round-trip`
                transformations).
            optimizer_options: Options for the optimizer function, e.g.
                ``sequential`` or ``raw_samples``.

        Returns:
            A three-element tuple containing an `n x d`-dim tensor of generated
            candidates, a tensor with the associated acquisition values, and a
            tensor with the weight for each candidate.

        Raises:
            NotImplementedError: If the L0 penalty is combined with
                ``inequality_constraints``.
        """
        if self.penalty_name == "L0_norm":
            if inequality_constraints is not None:
                raise NotImplementedError(
                    "Homotopy does not support optimization with inequality "
                    "constraints. Use L1 penalty norm instead."
                )
            candidates, expected_acquisition_value, weights = (
                self._optimize_with_homotopy(
                    n=n,
                    search_space_digest=search_space_digest,
                    fixed_features=fixed_features,
                    rounding_func=rounding_func,
                    optimizer_options=optimizer_options,
                )
            )
        else:
            # With the L1 norm, use the standard multi-objective optimization.
            candidates, expected_acquisition_value, weights = super().optimize(
                n=n,
                search_space_digest=search_space_digest,
                inequality_constraints=inequality_constraints,
                fixed_features=fixed_features,
                rounding_func=rounding_func,
                optimizer_options=optimizer_options,
            )

        # As with the training data, snap near-target coordinates exactly onto the
        # target point so that sparsity is exact.
        candidates = clamp_to_target(
            X=candidates, target_point=self.target_point, clamp_tol=CLAMP_TOL
        )
        return candidates, expected_acquisition_value, weights

    def _optimize_with_homotopy(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        fixed_features: dict[int, float] | None = None,
        rounding_func: Callable[[Tensor], Tensor] | None = None,
        optimizer_options: dict[str, Any] | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Optimize the SEBO acquisition function with L0 norm using homotopy.

        The L0 approximation parameter ``a`` is annealed with a log-linear
        schedule from 0.2 down to 1e-3 over 30 steps.

        Returns:
            The candidates, their acquisition values, and a tensor of ones (one
            weight per candidate, ``n`` entries).
        """
        _tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
        # `ssd.bounds` is a list of (lower, upper) pairs; transpose to `2 x d`.
        bounds = _tensorize(search_space_digest.bounds).t()
        homotopy_schedule = LogLinearHomotopySchedule(start=0.2, end=1e-3, num_steps=30)

        # Prepare arguments for the optimizer.
        optimizer_options_with_defaults = optimizer_argparse(
            self.acqf,
            optimizer_options=optimizer_options,
            optimizer="optimize_acqf_homotopy",
        )

        homotopy = Homotopy(
            homotopy_parameters=[
                HomotopyParameter(
                    parameter=self.deterministic_model._f.a,
                    schedule=homotopy_schedule,
                )
            ],
        )
        batch_initial_conditions = get_batch_initial_conditions(
            acq_function=self.acqf,
            raw_samples=optimizer_options_with_defaults["raw_samples"],
            X_pareto=self.acqf.X_baseline,
            target_point=self.target_point,
            bounds=bounds,
            num_restarts=optimizer_options_with_defaults["num_restarts"],
        )

        candidates, expected_acquisition_value = optimize_acqf_homotopy(
            q=n,
            acq_function=self.acqf,
            bounds=bounds,
            homotopy=homotopy,
            num_restarts=optimizer_options_with_defaults["num_restarts"],
            raw_samples=optimizer_options_with_defaults["raw_samples"],
            post_processing_func=rounding_func,
            fixed_features=fixed_features,
            batch_initial_conditions=batch_initial_conditions,
        )
        return (
            candidates,
            expected_acquisition_value,
            torch.ones(n, device=candidates.device, dtype=candidates.dtype),
        )
def L1_norm_func(X: Tensor, init_point: Tensor) -> Tensor:
    r"""Compute the L1 distance of each row of `X` from `init_point`.

    Maps a `batch_shape x n x d`-dim input tensor `X` to a
    `batch_shape x n x 1`-dim tensor holding the L1 norm of `X - init_point`
    along the last dimension. To be used for constructing a
    GenericDeterministicModel.
    """
    offset = X - init_point
    return torch.linalg.vector_norm(offset, ord=1, dim=-1, keepdim=True)
def clamp_to_target(X: Tensor, target_point: Tensor, clamp_tol: float) -> Tensor:
    """Clamp generated candidates within the given ranges to the target point.

    The operation is performed in place on ``X``, which is also returned.

    Args:
        X: A `batch_shape x n x d`-dim input tensor `X`.
        target_point: A tensor of size `d` corresponding to the target point.
        clamp_tol: The clamping tolerance. Any value within `clamp_tol` of the
            `target_point` will be clamped to the `target_point`.

    Returns:
        The same tensor object ``X`` after clamping.
    """
    near_target = torch.abs(X - target_point) <= clamp_tol
    # Broadcast view of the target point (no materialized copy) matching X.
    target_full = target_point.expand_as(X)
    X[near_target] = target_full[near_target]
    return X
def get_batch_initial_conditions(
    acq_function: AcquisitionFunction,
    raw_samples: int,
    X_pareto: Tensor,
    target_point: Tensor,
    bounds: Tensor,
    num_restarts: int = 20,
) -> Tensor:
    """Generate starting points for the SEBO acquisition function optimization.

    ``num_restarts // 2`` points are the best (by acquisition value) of
    ``raw_samples`` scrambled Sobol points over ``bounds``. The remaining points
    are the best of ``raw_samples`` Gaussian perturbations of randomly chosen rows
    of ``X_pareto``; only coordinates that differ from ``target_point`` are
    perturbed, and the results are clamped to ``bounds``. If ``X_pareto`` is empty,
    all ``num_restarts`` points come from the Sobol candidates.

    Args:
        acq_function: Acquisition function used to rank candidates; it is called
            on tensors of shape ``raw_samples x 1 x d``.
        raw_samples: Number of Sobol points and of Pareto perturbations to draw.
        X_pareto: Tensor of baseline points (rows) to perturb; may be empty.
        target_point: Tensor of size ``d``; coordinates equal to it are not
            perturbed.
        bounds: ``2 x d`` tensor of lower and upper bounds.
        num_restarts: Number of starting points to return.

    Returns:
        A ``num_restarts x 1 x d`` tensor of starting points.
    """
    tkwargs: dict[str, Any] = {"device": X_pareto.device, "dtype": X_pareto.dtype}
    dim = X_pareto.shape[-1]  # dimension
    num_sobol, num_local = num_restarts // 2, num_restarts - num_restarts // 2

    # (1) Global sparse Sobol points.
    X_cand_sobol = (
        SobolEngine(dimension=dim, scramble=True)
        .draw(raw_samples, dtype=tkwargs["dtype"])
        .to(**tkwargs)
    )
    X_cand_sobol = unnormalize(X_cand_sobol, bounds=bounds)
    # Only the ranking is needed, so avoid building an autograd graph.
    with torch.no_grad():
        acq_vals = acq_function(X_cand_sobol.unsqueeze(1))
    if len(X_pareto) == 0:
        # Return the same `num_restarts x 1 x d` shape as the regular path below.
        return X_cand_sobol[acq_vals.topk(num_restarts).indices].unsqueeze(1)
    X_cand_sobol = X_cand_sobol[acq_vals.topk(num_sobol).indices]

    # (2) Perturbations of points on the Pareto frontier (done by TuRBO/Spearmint).
    # Advanced indexing returns a copy, so `X_pareto` itself is never modified.
    X_cand_local = X_pareto[torch.randint(high=len(X_pareto), size=(raw_samples,))]
    mask = X_cand_local != target_point
    X_cand_local[mask] += (
        0.2 * ((bounds[1] - bounds[0]) * torch.randn_like(X_cand_local))[mask]
    )
    X_cand_local = torch.clamp(X_cand_local, min=bounds[0], max=bounds[1])
    with torch.no_grad():
        local_vals = acq_function(X_cand_local.unsqueeze(1))
    X_cand_local = X_cand_local[local_vals.topk(num_local).indices]
    return torch.cat((X_cand_sobol, X_cand_local), dim=0).unsqueeze(1)