import sys
import plotly.io as pio
if 'google.colab' in sys.modules:
    pio.renderers.default = "colab"
    %pip install ax-platform
from typing import Any, Dict, Optional, Tuple, Type
from ax.modelbridge.registry import Models
# Ax data transformation layer
from ax.models.torch.botorch_modular.acquisition import Acquisition
# Ax wrappers for BoTorch components
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import ModelConfig
# Experiment examination utilities
from ax.service.utils.report_utils import exp_to_df
# Test Ax objects
from ax.utils.testing.core_stubs import (
get_branin_data,
get_branin_data_multi_objective,
get_branin_experiment,
get_branin_experiment_with_multi_objective,
)
from botorch.acquisition.logei import (
qLogExpectedImprovement,
qLogNoisyExpectedImprovement,
)
from botorch.models.gp_regression import SingleTaskGP
# BoTorch components
from botorch.models.model import Model
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
Setup and Usage of BoTorch Models in Ax
Ax provides a set of flexible wrapper abstractions to mix-and-match BoTorch components
like Model and AcquisitionFunction and combine them into a single Model object in
Ax. The wrapper abstractions (Surrogate, Acquisition, and BoTorchModel) are
located in the ax/models/torch/botorch_modular directory and aim to encapsulate the
boilerplate code that interfaces between Ax and BoTorch. This functionality is in
beta release and still evolving.
This tutorial walks through setting up a custom combination of BoTorch components in Ax in the following steps:
- Quick-start example of BoTorchModel use
- BoTorchModel = Surrogate + Acquisition (overview)
  - Example with minimal options that uses the defaults
  - Example showing all possible options
  - Surrogate and Acquisition Q&A
- I know which BoTorch Model and AcquisitionFunction I'd like to combine in Ax. How do I set this up?
  - Making a Surrogate from a BoTorch Model
  - Using an arbitrary BoTorch AcquisitionFunction in Ax
- Using Models.BOTORCH_MODULAR (convenience wrapper that enables storage and resumability)
- Utilizing BoTorchModel in generation strategies (abstraction that allows chaining models together and using them in the Ax Service API etc.)
  - Specifying pending_observations to avoid the model re-suggesting points that are part of RUNNING or ABANDONED trials
- Customizing a Surrogate or Acquisition (for cases where existing subcomponent classes are not sufficient)
1. Quick-start example
Here we set up a BoTorchModel with SingleTaskGP and qLogNoisyExpectedImprovement,
one of the most popular combinations in Ax:
experiment = get_branin_experiment(with_trial=True)
data = get_branin_data(trials=[experiment.trials[0]])
[INFO 02-03 18:39:57] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
# `Models` automatically selects a model + model bridge combination.
# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.
model_bridge_with_GPEI = Models.BOTORCH_MODULAR(
experiment=experiment,
data=data,
surrogate_spec=SurrogateSpec(
model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]
), # Optional, will use default if unspecified
botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified
)
[INFO 02-03 18:39:57] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.
Now we can use this model to generate candidates (gen), predict outcome at a point
(predict), or evaluate acquisition function value at a given point
(evaluate_acquisition_function).
generator_run = model_bridge_with_GPEI.gen(n=1)
generator_run.arms[0]
Arm(parameters={'x1': 10.0, 'x2': 0.0})
Before you read the rest of this tutorial:
- Note that the concept of 'model' in Ax is somewhat of a misnomer; we use 'model' to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with the exception of quasi-random generators). See the Models documentation page for more information.
- Learn about ModelBridge in Ax, as users should rarely be interacting with a Model object directly (ModelBridge is a data transformation layer in Ax).
2. BoTorchModel = Surrogate + Acquisition
A BoTorchModel in Ax consists of two main subcomponents: a surrogate model and an
acquisition function. A surrogate model is represented as an instance of Ax’s
Surrogate class, which is a wrapper around BoTorch's Model class. The Surrogate is
defined by a SurrogateSpec. The acquisition function is represented as an instance of
Ax’s Acquisition class, a wrapper around BoTorch's AcquisitionFunction class.
2A. Example that uses defaults and requires no options
BoTorchModel does not always require surrogate and acquisition specification. If instantiated without one or both components specified, defaults are selected based on properties of experiment and data (see Appendix 2 for auto-selection logic).
# The surrogate is not specified, so it will be auto-selected
# during `model.fit`.
GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)
# The acquisition class is not specified, so it will be
# auto-selected during `model.gen` or `model.evaluate_acquisition_function`
GPEI_model = BoTorchModel(
surrogate_spec=SurrogateSpec(
model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]
)
)
# Both the surrogate and acquisition class will be auto-selected.
GPEI_model = BoTorchModel()
2B. Example with all the options
Below is the full set of configurable settings of a BoTorchModel with their
descriptions:
model = BoTorchModel(
# Optional `Surrogate` specification to use instead of default
surrogate_spec=SurrogateSpec(
model_configs=[
ModelConfig(
# BoTorch `Model` type
botorch_model_class=SingleTaskGP,
# Optional, MLL class with which to optimize model parameters
mll_class=ExactMarginalLogLikelihood,
# Optional, dictionary of keyword arguments to underlying
# BoTorch `Model` constructor
model_options={},
)
]
),
# Optional BoTorch `AcquisitionFunction` to use instead of default
botorch_acqf_class=qLogExpectedImprovement,
# Optional dict of keyword arguments, passed to the input
# constructor for the given BoTorch `AcquisitionFunction`
acquisition_options={},
# Optional Ax `Acquisition` subclass (if the given BoTorch
# `AcquisitionFunction` requires one, which is rare)
acquisition_class=None,
# Less common model settings shown with default values, refer
# to `BoTorchModel` documentation for detail
refit_on_cv=False,
warm_start_refit=True,
)
2C. Surrogate and Acquisition Q&A
Why is the surrogate argument expected to be an instance, but botorch_acqf_class
a class? Because a BoTorch AcquisitionFunction object (and therefore its Ax
wrapper, Acquisition) is ephemeral: it is constructed, immediately used, and destroyed
during BoTorchModel.gen, so there is no reason to keep around an Acquisition
instance. A Surrogate, on the other hand, is kept in memory as long as its parent
BoTorchModel is.
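The lifecycle described in this answer can be sketched with plain Python stand-ins. This is only an illustration under simplifying assumptions: `ToySurrogate`, `ToyImprovement`, and `ToyModel` are hypothetical names invented here and are not Ax or BoTorch classes; the point is that the model stores a surrogate instance but only the acquisition class.

```python
from typing import List, Type


class ToySurrogate:
    """Hypothetical stand-in for a surrogate: remembers the data it was fit to."""

    def __init__(self) -> None:
        self.train_y: List[float] = []

    def fit(self, train_y: List[float]) -> None:
        self.train_y = list(train_y)

    def best_observed(self) -> float:
        return max(self.train_y)


class ToyImprovement:
    """Hypothetical stand-in for an acquisition function built from a surrogate."""

    def __init__(self, surrogate: ToySurrogate) -> None:
        self.best_f = surrogate.best_observed()

    def __call__(self, predicted: float) -> float:
        # Improvement over the best value seen so far, floored at zero.
        return max(predicted - self.best_f, 0.0)


class ToyModel:
    """Keeps one surrogate instance; builds a fresh acquisition on every `gen`."""

    def __init__(self, acqf_class: Type[ToyImprovement]) -> None:
        self.surrogate = ToySurrogate()  # long-lived instance
        self.acqf_class = acqf_class  # only the class is stored

    def fit(self, train_y: List[float]) -> None:
        self.surrogate.fit(train_y)

    def gen(self, candidates: List[float]) -> float:
        acqf = self.acqf_class(self.surrogate)  # constructed here...
        return max(candidates, key=acqf)  # ...used once, then discarded


toy_model = ToyModel(acqf_class=ToyImprovement)
toy_model.fit([1.0, 3.0])
print(toy_model.gen([2.0, 5.0, 4.0]))
```

Because the acquisition object depends on the surrogate's current state, rebuilding it on each call keeps it in sync with the latest fit without any extra bookkeeping.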
How do I know when to specify acquisition_class (and thereby a non-default
Acquisition type) instead of just passing in botorch_acqf_class? In short, custom
Acquisition subclasses are needed when a given AcquisitionFunction in BoTorch needs
some non-standard subcomponents or inputs (e.g. a custom BoTorch
MCAcquisitionObjective).
Please post any other questions you have to our dedicated issue on GitHub: https://github.com/facebook/Ax/issues/363. This functionality is in beta release and your feedback will be of great help to us!
3. I know which BoTorch Model and AcquisitionFunction I'd like to combine in Ax. How do I set this up?
3A. Making a Surrogate from a BoTorch Model
Most models should work with base Surrogate in Ax, except for BoTorch ModelListGP.
ModelListGP is a special case because its purpose is to combine multiple sub-models
into a single Model in BoTorch. It is most commonly used for multi-objective and
constrained optimization. Whether or not ModelListGP is used is determined
automatically based on the Model class and the data being used via the
ax.models.torch.botorch_modular.utils.use_model_list function.
If your Model is not a ModelListGP, the steps to set it up as a Surrogate are:
- Implement a construct_inputs class method. The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch Model-s from Surrogate.construct in Ax. It should accept training data in the form of a SupervisedDataset container and optionally other keyword arguments, and produce a dictionary of arguments to __init__ of the Model. See SingleTaskMultiFidelityGP.construct_inputs for an example.
- Pass any additional needed keyword arguments for the Model constructor (that cannot be constructed from the training data and other arguments to construct_inputs) via the model_options argument to ModelConfig in SurrogateSpec.
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
class MyModelClass(Model):
    ...  # Implementation of `MyModelClass`

    @classmethod
    def construct_inputs(
        cls, training_data: SupervisedDataset, **kwargs
    ) -> Dict[str, Any]:
        fidelity_features = kwargs.get("fidelity_features")
        if fidelity_features is None:
            raise ValueError(f"Fidelity features required for {cls.__name__}.")
        return {
            **super().construct_inputs(training_data=training_data, **kwargs),
            "fidelity_features": fidelity_features,
        }
surrogate_spec = SurrogateSpec(
model_configs=[
ModelConfig(
botorch_model_class=MyModelClass, # Must implement `construct_inputs`
# Optional dict of additional keyword arguments to `MyModelClass`
model_options={},
)
]
)
NOTE: if you run into a case where the base Surrogate does not work with your BoTorch
Model, please let us know in this GitHub issue:
https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment
this tutorial.
3B. Using an arbitrary BoTorch AcquisitionFunction in Ax
Steps to set up any AcquisitionFunction in Ax are:
- Define an input constructor function. The purpose of this function is to produce arguments to an acquisition function from a standardized set of inputs passed to BoTorch AcquisitionFunction-s from Acquisition.__init__ in Ax. For example, see construct_inputs_qEHVI, which creates a fairly complex set of arguments needed by qExpectedHypervolumeImprovement, a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: botorch/acquisition/input_constructors.py
  - Note that the new input constructor needs to be decorated with @acqf_input_constructor(AcquisitionFunctionClass) to register it.
- Specify the BoTorch AcquisitionFunction class as botorch_acqf_class to BoTorchModel.
- (Optional) Pass any additional keyword arguments to the acquisition function constructor or to the optimizer function via the acquisition_options argument to BoTorchModel.
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor
class MyAcquisitionFunctionClass(AcquisitionFunction):
    ...  # Actual contents of the acquisition function class.


# 1. Add input constructor
@acqf_input_constructor(MyAcquisitionFunctionClass)
def construct_inputs_my_acqf(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    objective_thresholds: Tensor,
    **kwargs: Any,
) -> Dict[str, Any]:
    pass
# 2-3. Specifying `botorch_acqf_class` and `acquisition_options`
BoTorchModel(
botorch_acqf_class=MyAcquisitionFunctionClass,
acquisition_options={
"alpha": 10**-6,
# The sub-dict by the key "optimizer_options" can be passed
# to propagate options to `optimize_acqf`, used in
# `Acquisition.optimize`, to add/override the default
# optimizer options.
"optimizer_options": {"sequential": False},
},
)
BoTorchModel
See section 2A for combining the resulting Surrogate instance and Acquisition type
into a BoTorchModel. You can also leverage Models.BOTORCH_MODULAR for ease of use;
more on it in section 4 below or in section 1 quick-start example.
4. Using Models.BOTORCH_MODULAR
To simplify the instantiation of an Ax ModelBridge and its underlying Model, Ax provides
a Models registry enum.
When calling entries of that enum (e.g. Models.BOTORCH_MODULAR(experiment, data)), the
inputs are automatically distributed between a Model and a ModelBridge for a given
setup. A call to a Models enum member yields a model bridge with an underlying model,
ready for use to generate candidates.
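To make "automatically distributed" more concrete, here is a minimal sketch of one way a registry entry could route keyword arguments to a model and a bridge by inspecting constructor signatures. It is an illustration only, not Ax's actual implementation: `ToyModel`, `ToyBridge`, `split_kwargs`, and `make_bridge` are hypothetical names introduced here.

```python
import inspect
from typing import Any, Dict, Tuple


class ToyModel:
    """Hypothetical model taking only model-level options."""

    def __init__(self, acqf_name: str = "default_acqf") -> None:
        self.acqf_name = acqf_name


class ToyBridge:
    """Hypothetical bridge holding the model plus experiment and data."""

    def __init__(self, model: ToyModel, experiment: str, data: str) -> None:
        self.model = model
        self.experiment = experiment
        self.data = data


def split_kwargs(
    kwargs: Dict[str, Any], model_cls: type, bridge_cls: type
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Route each keyword argument to the constructor that accepts it."""
    model_params = set(inspect.signature(model_cls.__init__).parameters) - {"self"}
    bridge_params = set(inspect.signature(bridge_cls.__init__).parameters) - {
        "self",
        "model",
    }
    unknown = set(kwargs) - model_params - bridge_params
    if unknown:
        raise TypeError(f"Unexpected arguments: {sorted(unknown)}")
    model_kwargs = {k: v for k, v in kwargs.items() if k in model_params}
    bridge_kwargs = {k: v for k, v in kwargs.items() if k in bridge_params}
    return model_kwargs, bridge_kwargs


def make_bridge(**kwargs: Any) -> ToyBridge:
    """Plays the role of calling a registry entry with mixed arguments."""
    model_kwargs, bridge_kwargs = split_kwargs(kwargs, ToyModel, ToyBridge)
    return ToyBridge(model=ToyModel(**model_kwargs), **bridge_kwargs)


bridge = make_bridge(experiment="my_experiment", data="my_data", acqf_name="qLogNEI")
print(type(bridge).__name__, bridge.model.acqf_name)
```

The caller passes one flat set of arguments and gets back a bridge with a configured model inside, which mirrors how the calls in this section look from the outside.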
Here we use Models.BOTORCH_MODULAR to set up a model with all-default subcomponents:
model_bridge_with_GPEI = Models.BOTORCH_MODULAR(
experiment=experiment,
data=data,
)
model_bridge_with_GPEI.gen(1)
[INFO 02-03 18:39:58] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.
GeneratorRun(1 arms, total weight 1.0)
model_bridge_with_GPEI.model.botorch_acqf_class
botorch.acquisition.logei.qLogNoisyExpectedImprovement
model_bridge_with_GPEI.model.surrogate.model.__class__
botorch.models.gp_regression.SingleTaskGP
We can use the same Models.BOTORCH_MODULAR to set up a model for multi-objective
optimization:
model_bridge_with_EHVI = Models.BOTORCH_MODULAR(
experiment=get_branin_experiment_with_multi_objective(
has_objective_thresholds=True, with_batch=True
),
data=get_branin_data_multi_objective(),
)
model_bridge_with_EHVI.gen(1)
[INFO 02-03 18:39:58] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
[INFO 02-03 18:39:58] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.
[INFO 02-03 18:39:58] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.
GeneratorRun(1 arms, total weight 1.0)
model_bridge_with_EHVI.model.botorch_acqf_class
botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement
model_bridge_with_EHVI.model.surrogate.model.__class__
botorch.models.gp_regression.SingleTaskGP
Furthermore, the quick-start example at the top of this tutorial shows how to specify
surrogate and acquisition subcomponents to Models.BOTORCH_MODULAR.
5. Utilizing BoTorchModel in generation strategies
Generation strategy is a key concept in Ax, enabling use of the Service API (a.k.a.
AxClient) and many other higher-level abstractions. A GenerationStrategy allows
chaining multiple models in Ax and thereby automating candidate generation. Refer to the
"Generation Strategy" tutorial for more detail on generation strategies.
An example generation strategy with the modular BoTorchModel would look like this:
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
gs = GenerationStrategy(
steps=[
GenerationStep( # Initialization step
# Which model to use for this step
model=Models.SOBOL,
# How many generator runs (each of which is then made a trial)
# to produce with this step
num_trials=5,
# How many trials generated from this step must be `COMPLETED`
# before the next one
min_trials_observed=5,
),
GenerationStep( # BayesOpt step
model=Models.BOTORCH_MODULAR,
# No limit on how many generator runs will be produced
num_trials=-1,
model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`
"surrogate_spec": SurrogateSpec(
model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]
),
"botorch_acqf_class": qLogNoisyExpectedImprovement,
},
),
]
)
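The roles of num_trials and min_trials_observed in the steps above can be illustrated with a small stand-alone sketch. It is a simplified reading of the comments in the code, not Ax's transition logic: `ToyStep` and `pick_step` are hypothetical names, and raising an error when a step still lacks completed trials is an assumption made for this illustration.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ToyStep:
    model_name: str
    num_trials: int  # -1 means "no limit"
    min_trials_observed: int = 0


def pick_step(
    steps: Sequence[ToyStep],
    generated: Sequence[int],
    completed: Sequence[int],
) -> ToyStep:
    """Return the step that should produce the next trial.

    generated[i] and completed[i] count the trials produced by step i and
    how many of those are completed.
    """
    for step, n_generated, n_completed in zip(steps, generated, completed):
        if step.num_trials == -1 or n_generated < step.num_trials:
            return step  # this step still has trials left to produce
        if n_completed < step.min_trials_observed:
            raise RuntimeError(
                f"{step.model_name} step needs {step.min_trials_observed} "
                f"completed trials, has {n_completed}."
            )
    raise RuntimeError("All steps are exhausted.")


toy_steps = [
    ToyStep("Sobol", num_trials=5, min_trials_observed=5),
    ToyStep("BoTorch", num_trials=-1),
]
print(pick_step(toy_steps, generated=[4, 0], completed=[4, 0]).model_name)
print(pick_step(toy_steps, generated=[5, 0], completed=[5, 0]).model_name)
```

With these settings the first five trials come from the initialization step and every later trial from the Bayesian optimization step, provided the first five have completed.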
Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:
experiment = get_branin_experiment(minimize=True)
assert len(experiment.trials) == 0
experiment.search_space
[INFO 02-03 18:39:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])
5A. Specifying pending_observations
Note that it's important to pass pending observations to the call to gen to
avoid getting the same points re-suggested. Without the pending_observations argument, Ax
models are not aware of points that should be excluded from generation. Points are
considered "pending" when they belong to STAGED, RUNNING, or ABANDONED trials
(with the latter included so that the model does not re-suggest points that are
considered "bad").
If the call to get_pending_observation_features becomes slow in your setup (since it
performs data-fetching etc.), you can opt for
get_pending_observation_features_based_on_trial_status (also from
ax.modelbridge.modelbridge_utils), but note the limitations of that utility (detailed
in its docstring).
for _ in range(10):
    # Produce a new generator run and attach it to experiment as a trial
    generator_run = gs.gen(
        experiment=experiment,
        n=1,
        pending_observations=get_pending_observation_features(experiment=experiment),
    )
    trial = experiment.new_trial(generator_run)
    # Mark the trial as 'RUNNING' so we can mark it 'COMPLETED' later
    trial.mark_running(no_runner_required=True)
    # Attach data for the new trial and mark it 'COMPLETED'
    experiment.attach_data(get_branin_data(trials=[trial]))
    trial.mark_completed()
    print(f"Completed trial #{trial.index}, suggested by {generator_run._model_key}.")
/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
warn("Encountered exception in computing model fit quality: " + str(e))
Completed trial #0, suggested by Sobol.
Completed trial #1, suggested by Sobol.
Completed trial #2, suggested by Sobol.
Completed trial #3, suggested by Sobol.
Completed trial #4, suggested by Sobol.
Completed trial #5, suggested by BoTorch.
Completed trial #6, suggested by BoTorch.
Completed trial #7, suggested by BoTorch.
Completed trial #8, suggested by BoTorch.
Completed trial #9, suggested by BoTorch.
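The notion of "pending" points from section 5A can be made concrete with a toy filter over trial statuses. This is only a sketch of the rule stated in the text (STAGED, RUNNING, and ABANDONED trials count as pending); `ToyStatus` and `pending_points` are hypothetical names, and the real get_pending_observation_features utility works on experiment objects and does more than this.

```python
from enum import Enum, auto
from typing import Dict, List, Tuple

Point = Dict[str, float]


class ToyStatus(Enum):
    CANDIDATE = auto()
    STAGED = auto()
    RUNNING = auto()
    COMPLETED = auto()
    ABANDONED = auto()


# Statuses whose points should not be suggested again, per the text above.
PENDING_STATUSES = {ToyStatus.STAGED, ToyStatus.RUNNING, ToyStatus.ABANDONED}


def pending_points(trials: List[Tuple[ToyStatus, Point]]) -> List[Point]:
    """Collect the points of all trials that count as pending."""
    return [point for status, point in trials if status in PENDING_STATUSES]


toy_trials = [
    (ToyStatus.COMPLETED, {"x1": 0.0, "x2": 1.0}),
    (ToyStatus.RUNNING, {"x1": 2.0, "x2": 3.0}),
    (ToyStatus.ABANDONED, {"x1": 4.0, "x2": 5.0}),
]
print(pending_points(toy_trials))
```

In the loop above every trial is completed before the next call to gen, so the pending set is empty each time; the argument matters once trials are still running or have been abandoned when new candidates are requested.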
Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:
exp_to_df(experiment)
| | trial_index | arm_name | trial_status | generation_method | branin | x1 | x2 |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 46.8544 | -0.222709 | 11.6039 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 24.7717 | 6.87838 | 3.77636 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 30.2015 | 2.86796 | 7.92453 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 157.219 | -3.27361 | 0.075043 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 5.52554 | -2.99403 | 9.68191 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 95.5524 | -5 | 8.08633 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 3.94499 | -2.44731 | 11.8195 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 5.73079 | -2.04632 | 9.44187 |
| 8 | 8 | 8_0 | COMPLETED | BoTorch | 73.0684 | 5.63191 | 8.57931 |
| 9 | 9 | 9_0 | COMPLETED | BoTorch | 7.62355 | 4.43278 | 2.00645 |
6. Customizing a Surrogate or Acquisition
We expect the base Surrogate and Acquisition classes to work with most BoTorch
components, but there could be a case where you would need to subclass one of the
aforementioned abstractions to handle a given BoTorch component. If you run into a case
like this, feel free to open an issue on our GitHub issues page; it would be very useful
for us to know.
One such example would be a need for a custom MCAcquisitionObjective or posterior
transform. To subclass Acquisition accordingly, one would override the
get_botorch_objective_and_transform method:
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
class CustomObjectiveAcquisition(Acquisition):
    def get_botorch_objective_and_transform(
        self,
        botorch_acqf_class: Type[AcquisitionFunction],
        model: Model,
        objective_weights: Tensor,
        objective_thresholds: Optional[Tensor] = None,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        X_observed: Optional[Tensor] = None,
        risk_measure: Optional[RiskMeasureMCObjective] = None,
    ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:
        ...  # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default
Then to use the new subclass in BoTorchModel, just specify acquisition_class
argument along with botorch_acqf_class (to BoTorchModel directly or to
Models.BOTORCH_MODULAR, which just passes the relevant arguments to BoTorchModel
under the hood, as discussed in section 4):
Models.BOTORCH_MODULAR(
experiment=experiment,
data=data,
acquisition_class=CustomObjectiveAcquisition,
botorch_acqf_class=MyAcquisitionFunctionClass,
)
[INFO 02-03 18:40:04] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.
TorchModelBridge(model=BoTorchModel)
To use a custom Surrogate subclass, pass the surrogate argument of that type (CustomSurrogate below stands for your own Surrogate subclass and is not defined in this tutorial):
Models.BOTORCH_MODULAR(
experiment=experiment,
data=data,
surrogate=CustomSurrogate(botorch_model_class=MyModelClass),
)
Appendix 1: Methods available on BoTorchModel
Note that usually all these methods are used through ModelBridge, a conversion and
transformation layer that adapts Ax abstractions to inputs required by the given model.
Core methods on BoTorchModel:
- fit selects a surrogate if needed and fits the surrogate model to data via Surrogate.fit,
- predict estimates metric values at a given point via Surrogate.predict,
- gen instantiates an acquisition function via Acquisition.__init__ and optimizes it to generate candidates.
Other methods on BoTorchModel:
- update updates the surrogate model with training data and optionally reoptimizes model parameters via Surrogate.update,
- cross_validate re-fits the surrogate model to a subset of training data and makes predictions for test data,
- evaluate_acquisition_function instantiates an acquisition function and evaluates it for a given point.
Appendix 2: Default surrogate models and acquisition functions
By default, the chosen surrogate model will be:
- if fidelity parameters are present in search space: SingleTaskMultiFidelityGP,
- if task parameters are present: a set of MultiTaskGP wrapped in a ModelListGP and each modeling one task,
- SingleTaskGP otherwise.
The chosen acquisition function will be:
- for multi-objective settings: qLogNoisyExpectedHypervolumeImprovement (as reported by the multi-objective example in section 4),
- for single-objective settings: qLogNoisyExpectedImprovement.
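The selection rules in this appendix can be written down as two small functions. This is a sketch, not Ax's code: the function names are hypothetical, class names are returned as plain strings so the snippet runs without Ax or BoTorch, the order of the checks is assumed to follow the order of the list, and the multi-objective name follows the output printed in section 4.

```python
def default_surrogate_name(has_fidelity_params: bool, has_task_params: bool) -> str:
    """Name of the surrogate chosen by default (assumed precedence: list order)."""
    if has_fidelity_params:
        return "SingleTaskMultiFidelityGP"
    if has_task_params:
        return "ModelListGP of MultiTaskGP"
    return "SingleTaskGP"


def default_acqf_name(is_multi_objective: bool) -> str:
    """Name of the acquisition function chosen by default."""
    if is_multi_objective:
        # Matches the class printed for the multi-objective model in section 4.
        return "qLogNoisyExpectedHypervolumeImprovement"
    return "qLogNoisyExpectedImprovement"


# Single-objective Branin setup without fidelity or task parameters.
print(default_surrogate_name(has_fidelity_params=False, has_task_params=False))
print(default_acqf_name(is_multi_objective=False))
```

For the single-objective Branin experiment used throughout this tutorial, these rules give the same pair that section 4 reports for the all-default model.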
Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in Ax
Attempting to store a generator run produced via a Models.BOTORCH_MODULAR instance that
included options without serialization logic will produce an error like:
"Object <SomeAcquisitionOption object> passed to 'object_to_json' (of type <class 'SomeAcquisitionOption'>) is not registered with a corresponding encoder in ENCODER_REGISTRY."
The two options for handling this error are:
- disabling storage of BoTorchModel's options by passing no_model_options_storage=True to the Models.BOTORCH_MODULAR(...) call; this will prevent model options from being stored on the generator run, so a generator run can be saved but cannot be used to restore the model that produced it,
- specifying serialization logic for a given object that needs to occur among the Model or AcquisitionFunction options. A tutorial for this is in the works, but in the meantime you can post an issue on the Ax GitHub to get help with this.
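Why an unregistered option object breaks storage, and what registering serialization logic changes, can be shown with a toy encoder registry built on the standard library. This is a conceptual sketch only and does not use or reproduce Ax's storage code: `TOY_ENCODERS`, `toy_object_to_json`, and this `SomeAcquisitionOption` class are hypothetical stand-ins.

```python
import json
from typing import Any, Callable, Dict


class SomeAcquisitionOption:
    """Hypothetical custom option object with no built-in JSON form."""

    def __init__(self, scale: float) -> None:
        self.scale = scale


# Maps a type to a function that turns an instance into JSON-friendly data.
TOY_ENCODERS: Dict[type, Callable[[Any], Dict[str, Any]]] = {}


def toy_object_to_json(obj: Any) -> Any:
    """Recursively convert `obj`, failing on types without a registered encoder."""
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: toy_object_to_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [toy_object_to_json(item) for item in obj]
    encoder = TOY_ENCODERS.get(type(obj))
    if encoder is None:
        raise ValueError(f"Object of type {type(obj)} has no registered encoder.")
    return {"__type": type(obj).__name__, **toy_object_to_json(encoder(obj))}


options = {"alpha": 1e-6, "custom": SomeAcquisitionOption(scale=2.0)}

try:
    toy_object_to_json(options)
except ValueError as error:
    print(f"Storage failed: {error}")

# Registering an encoder for the custom type makes the options storable.
TOY_ENCODERS[SomeAcquisitionOption] = lambda option: {"scale": option.scale}
print(json.dumps(toy_object_to_json(options)))
```

Skipping storage of the options avoids the failing conversion entirely, at the cost of not being able to rebuild the model later, while registering an encoder keeps the options on the stored record.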