Version: 1.0.0

⚠ INFO ⚠

This document discusses non-API components of Ax, which may be subject to backwards-compatibility-breaking changes between major library versions. This guide is primarily useful for researchers who intend to use custom BoTorch components for candidate generation in Ax. For most users, we recommend limiting customization to the options that are exposed in GenerationStrategyConfig.

Utilizing custom Generators via Modular BoTorch Interface

In Ax, we primarily use Bayesian optimization algorithms implemented in BoTorch for candidate generation. While Ax offers a user-friendly API for experiment creation & orchestration, BoTorch implements a series of surrogate models, acquisition functions, optimizers and other utilities that primarily operate on PyTorch Tensors. In a sense, Ax & BoTorch speak two different languages, and the Modular BoTorch Interface is the translation layer that allows them to communicate and operate together.

GenerationStrategy and the components of Ax's modeling layer

Before diving into the specifics of Modular BoTorch, it is useful to provide brief context on how candidate generation happens in Ax.

  • The GenerationStrategy is the top level abstraction that specifies
    • a series of GenerationNodes
    • and some rules (TransitionCriterion) for transitioning between them.
  • Each GenerationNode specifies a GeneratorSpec (it could specify multiple, but that is beyond the scope of this tutorial), which contains a Generators registry entry that specifies
    • an Adapter and Generator class to use,
    • as well as any additional options to customize these objects.
  • At a high level, the Adapter classes handle the translation between
    • the Ax data model (search space, trials, data)
    • and the Generator classes, which typically operate on simplified, fully-numerical spaces.
    • This is in part handled by the Transform classes, which can implement things like
      • converting a string valued ChoiceParameter into a numerical parameter,
      • log-transforming a log-scale RangeParameter,
      • or standardizing the observations for a given metric.
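To make the Transform bullets concrete, here is a minimal pure-Python sketch of the three kinds of conversions listed above. The helper names (`encode_choice`, `log_to_unit`, `standardize`) are hypothetical and chosen for illustration only; Ax's actual Transform classes operate on whole search spaces and observation sets, may split these steps across several transforms, and handle many more edge cases.

```python
import math
from statistics import mean, stdev


def encode_choice(value: str, choices: list[str]) -> int:
    """Map a string-valued choice to an integer index (an ordinal encoding).

    Hypothetical helper: Ax may use a different encoding (e.g. one-hot)
    depending on the parameter and transform configuration.
    """
    return choices.index(value)


def log_to_unit(value: float, lower: float, upper: float) -> float:
    """Map a log-scale range parameter onto [0, 1], measured in log space."""
    return (math.log(value) - math.log(lower)) / (math.log(upper) - math.log(lower))


def standardize(ys: list[float]) -> list[float]:
    """Shift and scale the observations of one metric to zero mean, unit std."""
    mu, sd = mean(ys), stdev(ys)
    return [(y - mu) / sd for y in ys]


print(encode_choice("adam", ["sgd", "adam", "rmsprop"]))
print(log_to_unit(1e-3, 1e-5, 1e-1))  # halfway between the bounds in log space
print(standardize([1.0, 2.0, 3.0]))
```

After these conversions every input is a plain number on a comparable scale, which is the form the BoTorch-side objects expect.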

In this tutorial:

  • We will consider a setup similar to the default GenerationStrategy, where we transition from CenterOfSearchSpace to Sobol to ModularBoTorch.
  • We will be using Generators.BOTORCH_MODULAR, which combines
    • the TorchAdapter and BoTorchGenerator classes (key component of the Modular BoTorch Interface),
    • and a set of default transforms to convert the Ax search space & observations into all-numerical valued inputs that are compatible with the BoTorch objects.
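The node sequence can be sketched as a simple function of how many trials are already on the experiment. This is a hypothetical, simplified model of the Center, Sobol, BoTorch setup used in this tutorial: the center node hands off after generating once, and the Sobol node hands off once the experiment has 5 trials (mirroring the `MinTrials(threshold=5, use_all_trials_in_exp=True)` criterion in the helper function). Real `TransitionCriterion` checks can also take trial status and other conditions into account, which this sketch ignores.

```python
def node_for_trial(num_existing_trials: int, sobol_threshold: int = 5) -> str:
    """Return the name of the node that would generate the next trial.

    Hypothetical sketch: the center node generates exactly one trial, Sobol runs
    until the experiment has `sobol_threshold` trials, and the BoTorch node
    handles everything after that.
    """
    if num_existing_trials == 0:
        return "CenterOfSearchSpace"
    if num_existing_trials < sobol_threshold:
        return "Sobol"
    return "Modular BoTorch"


# Which node generates each of the first 10 trials, one trial at a time?
print([node_for_trial(i) for i in range(10)])
```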

Let's define a helper function that will construct the GenerationStrategy from a given GeneratorSpec input that we will construct later in the tutorial.

from ax.generation_strategy.center_generation_node import CenterGenerationNode
from ax.generation_strategy.transition_criterion import MinTrials
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators

def construct_generation_strategy(
    generator_spec: GeneratorSpec, node_name: str,
) -> GenerationStrategy:
    """Constructs a Center + Sobol + Modular BoTorch `GenerationStrategy`
    using the provided `generator_spec` for the Modular BoTorch node.
    """
    botorch_node = GenerationNode(
        node_name=node_name,
        model_specs=[generator_spec],
    )
    sobol_node = GenerationNode(
        node_name="Sobol",
        model_specs=[
            GeneratorSpec(
                model_enum=Generators.SOBOL,
                # Let's use model_kwargs to set the random seed.
                model_kwargs={"seed": 0},
            ),
        ],
        transition_criteria=[
            # Transition to BoTorch node once there are 5 trials on the experiment.
            MinTrials(
                threshold=5,
                transition_to=botorch_node.node_name,
                use_all_trials_in_exp=True,
            )
        ]
    )
    # Center node is a customized node that uses simplified logic and has a
    # built-in transition criterion that transitions after generating once.
    center_node = CenterGenerationNode(next_node_name=sobol_node.node_name)
    return GenerationStrategy(
        name=f"Center+Sobol+{node_name}",
        nodes=[center_node, sobol_node, botorch_node]
    )

# Let's construct the simplest version with all defaults.
construct_generation_strategy(
    generator_spec=GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR),
    node_name="Modular BoTorch",
)
Output:
GenerationStrategy(name='Center+Sobol+Modular BoTorch', nodes=[CenterGenerationNode(next_node_name='Sobol'), GenerationNode(node_name='Sobol', model_specs=[GeneratorSpec(model_enum=Sobol, model_key_override=None)], transition_criteria=[MinTrials(transition_to='Modular BoTorch')]), GenerationNode(node_name='Modular BoTorch', model_specs=[GeneratorSpec(model_enum=BoTorch, model_key_override=None)], transition_criteria=[])])

The Modular BoTorch Generator

BoTorchGenerator is responsible for fitting surrogate models (including model selection), constructing acquisition functions, and optimizing the acquisition functions to generate candidates, using the inputs provided by TorchAdapter. BoTorchGenerator is a highly modular class that aims to balance user-friendliness with customizability. It implements dispatching logic in various places to select the appropriate surrogate model (single-task, multi-task or multi-fidelity GP), acquisition function (e.g., qLogNEI or qLogNEHVI) and optimizer, based on the properties of the (transformed) search space and optimization config. In this tutorial, we focus on the customizability aspect.
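The dispatching idea can be illustrated with a toy function that picks components from a few properties of the problem. This is not Ax's actual dispatch code: the rules below (a multi-fidelity or multi-task surrogate when fidelity or task features are present, qLogNEHVI when there is more than one objective) are a simplified assumption meant only to convey the shape of the logic, and the `ProblemSummary` container is hypothetical. The strings stand in for the corresponding BoTorch classes.

```python
from dataclasses import dataclass


@dataclass
class ProblemSummary:
    """Hypothetical summary of the transformed search space and optimization config."""
    num_objectives: int
    has_task_feature: bool = False
    has_fidelity_feature: bool = False


def choose_components(problem: ProblemSummary) -> tuple[str, str]:
    """Return (surrogate class name, acquisition function class name).

    Simplified, assumed rules for illustration only.
    """
    if problem.has_fidelity_feature:
        surrogate = "SingleTaskMultiFidelityGP"
    elif problem.has_task_feature:
        surrogate = "MultiTaskGP"
    else:
        surrogate = "SingleTaskGP"
    if problem.num_objectives > 1:
        acqf = "qLogNoisyExpectedHypervolumeImprovement"
    else:
        acqf = "qLogNoisyExpectedImprovement"
    return surrogate, acqf


print(choose_components(ProblemSummary(num_objectives=1)))
print(choose_components(ProblemSummary(num_objectives=2, has_task_feature=True)))
```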

The SurrogateSpec is a container of inputs that specifies which surrogate models to fit for which metrics, along with additional inputs to use when constructing these surrogate models. The ModelConfig container specifies one surrogate model class and any additional inputs for it. If multiple ModelConfigs are specified in a SurrogateSpec, each of the corresponding surrogate models will be fit to the training data, and the best one will be selected according to the specified criterion. This is a recent feature that is still under active development.
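To illustrate selecting a model by an evaluation criterion such as MSE, the sketch below compares two deliberately simple candidate "models" (a constant-mean predictor and a least-squares line) using leave-one-out cross-validation, and keeps the one with the lower mean squared error. The candidate models, the toy data and the use of leave-one-out are assumptions for illustration; Ax fits GPs, and its exact cross-validation procedure may differ.

```python
from statistics import mean


def fit_constant(xs, ys):
    """Candidate 1: predict the training mean everywhere."""
    mu = mean(ys)
    return lambda x: mu


def fit_line(xs, ys):
    """Candidate 2: ordinary least-squares line y = a + b * x."""
    x_bar, y_bar = mean(xs), mean(ys)
    sxx = sum((x - x_bar) ** 2 for x in xs)
    b = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys)) / sxx
    a = y_bar - b * x_bar
    return lambda x: a + b * x


def loo_mse(fit, xs, ys):
    """Leave-one-out CV: refit without point i, then score the held-out prediction."""
    errors = []
    for i in range(len(xs)):
        predict = fit(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:])
        errors.append((predict(xs[i]) - ys[i]) ** 2)
    return mean(errors)


def select_model(candidates, xs, ys):
    """Return the name of the candidate with the lowest LOO MSE, plus all scores."""
    scores = {name: loo_mse(fit, xs, ys) for name, fit in candidates.items()}
    return min(scores, key=scores.get), scores


xs = [0.0, 0.25, 0.5, 0.75, 1.0]
ys = [0.1, 0.6, 1.1, 1.4, 2.1]  # roughly linear toy data
best, scores = select_model({"constant": fit_constant, "line": fit_line}, xs, ys)
print(best, scores)
```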

Let's construct an example that uses model selection between a relatively vanilla GP and a fancier option, using the same ModelConfigs for all metrics. Later in the tutorial we will also demonstrate how to implement custom BoTorch models and acquisition functions and make them compatible with the Modular BoTorch Generator.

from gpytorch.kernels import MaternKernel
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Warp
from botorch.models.map_saas import AdditiveMapSaasSingleTaskGP
from ax.utils.stats.model_fit_stats import MSE
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec, ModelConfig

surrogate_spec = SurrogateSpec(
    model_configs=[
        # Select between two models:
        # An additive mixture of relatively strong SAAS priors with input Warping.
        # A relatively vanilla GP with a Matern kernel.
        ModelConfig(
            botorch_model_class=AdditiveMapSaasSingleTaskGP,
            input_transform_classes=[Warp],
            # Additional options for the model constructor. These need to be supported
            # by the input constructor. We will see that below.
            model_options={},
        ),
        ModelConfig(
            botorch_model_class=SingleTaskGP,
            covar_module_class=MaternKernel,
            covar_module_options={"nu": 2.5},
        ),
    ],
    eval_criterion=MSE,  # Select the model to use as the one that minimizes mean squared error.
    allow_batched_models=False,  # Forces each metric to be modeled with an independent BoTorch model.
    # If we wanted to specify different options for different metrics.
    # metric_to_model_configs: dict[str, list[ModelConfig]]
)

The surrogate model is one key component of Bayesian optimization; the other is the acquisition function. We can customize the acquisition function as well, which completes the Modular BoTorch Generator specification.

Note that we do not currently support manually selecting the acquisition function optimizer. Instead, dispatching logic selects the appropriate optimizer from BoTorch based on the properties of the (transformed) search space. However, we do support passing in options to customize the optimization budget and other inputs used by the optimizers.
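The optimizer options used in this tutorial control a multi-start, gradient-based search over the acquisition function. As a rough, hypothetical sketch of what a restart budget (`num_restarts`) and an iteration budget (`maxiter`) mean, the code below runs projected gradient ascent on a toy 1-D function from several random starting points in [0, 1] and keeps the best result. BoTorch's optimizer is considerably more sophisticated (it chooses starting points from raw samples and typically uses SciPy's L-BFGS-B, with `batch_limit` controlling how many restarts are optimized jointly), so treat the toy function, step size and update rule here as assumptions.

```python
import math
import random


def toy_acqf(x: float) -> float:
    """A multimodal toy function on [0, 1], standing in for an acquisition function."""
    return math.sin(12.0 * x) * x


def grad(f, x: float, eps: float = 1e-6) -> float:
    """Central finite-difference derivative."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def multistart_ascent(f, num_restarts: int, maxiter: int, lr: float = 0.01, seed: int = 0):
    """Run projected gradient ascent from `num_restarts` random starts; keep the best."""
    rng = random.Random(seed)
    best_x, best_val = None, -math.inf
    for _ in range(num_restarts):
        x = rng.random()
        for _ in range(maxiter):
            # Step uphill, then project back into the box [0, 1].
            x = min(1.0, max(0.0, x + lr * grad(f, x)))
        if f(x) > best_val:
            best_x, best_val = x, f(x)
    return best_x, best_val


x_star, val = multistart_ascent(toy_acqf, num_restarts=10, maxiter=200)
print(f"best x = {x_star:.3f}, value = {val:.3f}")
```

With a single restart, the search can end in whichever local maximum is nearest its starting point; more restarts make it more likely that one of them lands in the basin of the global maximum.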

from botorch.acquisition.logei import qLogNoisyExpectedImprovement

generator_spec = GeneratorSpec(
    model_enum=Generators.BOTORCH_MODULAR,
    model_kwargs={
        "surrogate_spec": surrogate_spec,
        "botorch_acqf_class": qLogNoisyExpectedImprovement,
        # Can be used for additional inputs that are not constructed
        # by default in Ax. We will demonstrate below.
        "acquisition_options": {},
    },
    # We can specify various options for the optimizer here.
    model_gen_kwargs={
        "model_gen_options": {
            "optimizer_kwargs": {
                "num_restarts": 20,
                "sequential": False,
                "options": {
                    "batch_limit": 5,
                    "maxiter": 200,
                },
            },
        },
    },
)

generation_strategy = construct_generation_strategy(
generator_spec=generator_spec,
node_name="BoTorch w/ Model Selection",
)
generation_strategy
Output:
GenerationStrategy(name='Center+Sobol+BoTorch w/ Model Selection', nodes=[CenterGenerationNode(next_node_name='Sobol'), GenerationNode(node_name='Sobol', model_specs=[GeneratorSpec(model_enum=Sobol, model_key_override=None)], transition_criteria=[MinTrials(transition_to='BoTorch w/ Model Selection')]), GenerationNode(node_name='BoTorch w/ Model Selection', model_specs=[GeneratorSpec(model_enum=BoTorch, model_key_override=None)], transition_criteria=[])])

Using the custom GenerationStrategy with Client

Using a custom GenerationStrategy, like much of the rest of this tutorial, is not considered part of the Ax API and does not come with API-level stability guarantees. However, we do expose some methods on Client to facilitate its use alongside the other API components and to support the advanced usage demonstrated here.

See the getting started tutorial to learn more about Client.

import numpy as np
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig

client = Client()

# Define two float parameters x1, x2 in unit hypercube.
range_parameters = [
RangeParameterConfig(
name="x1", parameter_type="float", bounds=(0, 1)
),
RangeParameterConfig(
name="x2", parameter_type="float", bounds=(0, 1)
)
]

client.configure_experiment(parameters=range_parameters)

metric_name = "test_metric" # this name is used during the optimization loop
objective = f"-{metric_name}" # minimization is specified by the negative sign

client.configure_optimization(objective=objective)


def test_function(x1, x2):
    # A made-up function.
    return x1 ** 2.0 - (x2 + 5.0) ** 0.75 / 4.0

Let's configure the client to use our custom GenerationStrategy.

client.set_generation_strategy(
generation_strategy=generation_strategy,
)

Run 10 trials to make sure it works.

for _ in range(10):
    trials = client.get_next_trials(max_trials=1)
    for index, parameters in trials.items():
        result = test_function(**parameters)
        client.complete_trial(trial_index=index, raw_data={"test_metric": result})
Output:
/opt/hostedtoolcache/Python/3.12.10/x64/lib/python3.12/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal
warnings.warn(

We can see that the trials were generated using the different GenerationNodes we have specified.

client.summarize()
trial_index  arm_name  trial_status  generation_node             test_metric  x1        x2
0            0_0       COMPLETED     CenterOfSearchSpace         -0.647867    0.5       0.5
1            1_0       COMPLETED     Sobol                       -0.683445    0.475107  0.592524
2            2_0       COMPLETED     Sobol                       -0.505609    0.578763  0.037122
3            3_0       COMPLETED     Sobol                       -0.038102    0.95067   0.862344
4            4_0       COMPLETED     Sobol                       -0.853987    0.120458  0.261442
5            5_0       COMPLETED     BoTorch w/ Model Selection  -0.958415    0         1
6            5_0       COMPLETED     BoTorch w/ Model Selection  -0.958415    0         1
7            5_0       COMPLETED     BoTorch w/ Model Selection  -0.958415    0         1
8            5_0       COMPLETED     BoTorch w/ Model Selection  -0.958415    0         1
9            5_0       COMPLETED     BoTorch w/ Model Selection  -0.958415    0         1
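All five BoTorch trials above reuse the same arm (`5_0`, at x1 = 0, x2 = 1). The made-up objective explains why: x1 enters as x1 ** 2, which is smallest at x1 = 0, and -(x2 + 5) ** 0.75 / 4 decreases as x2 grows, so on the unit square the minimum sits at the corner (0, 1). The sketch below confirms this with a coarse grid search; the function is re-defined as `f` so the snippet stands alone.

```python
def f(x1, x2):
    # Same made-up function as test_function above.
    return x1 ** 2.0 - (x2 + 5.0) ** 0.75 / 4.0


# Evaluate on a 101 x 101 grid over [0, 1]^2 and find the minimizer.
grid = [i / 100 for i in range(101)]
best = min(((x1, x2) for x1 in grid for x2 in grid), key=lambda p: f(*p))
print(best, round(f(*best), 6))
```

The minimum value, -6 ** 0.75 / 4, rounds to the -0.958415 reported in the table.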

Using custom models and acquisition functions

Many models and acquisition functions available in BoTorch implement input constructors that allow them to interface with the Modular BoTorch Generator. In this section, we demonstrate the steps necessary to make any custom classes you implement compatible.

Implementing a custom model

For this tutorial, we implement a very simple GPyTorch ExactGP model that uses an RBF kernel (with ARD) and infers a homoskedastic noise level.

Model definition is straightforward. Here we implement a model that inherits from both GPyTorch's ExactGP and BoTorch's GPyTorchModel; together these two superclasses provide all the API calls that BoTorch expects in its various modules.

For compatibility with Modular BoTorch Generator, the model class must implement construct_inputs, which is used to filter & extract the inputs necessary to construct the model from the inputs provided by Ax. The base Model class implements a very basic version of this, though a more capable implementation may be necessary depending on the specifics of the custom model class.

from typing import Optional

from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.datasets import SupervisedDataset
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor


class SimpleCustomGP(ExactGP, GPyTorchModel):

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # NOTE: This ignores train_Yvar and uses inferred noise instead.
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        # Depending on the experiment setup, additional arguments may be passed in here.
    ) -> dict[str, Tensor]:
        return {
            "train_X": training_data.X,
            "train_Y": training_data.Y,
            "train_Yvar": training_data.Yvar,
        }
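For reference, the covariance module `ScaleKernel(RBFKernel(ard_num_dims=d))` computes k(x, x') = s * exp(-0.5 * sum_d ((x_d - x'_d) / l_d) ** 2), with one lengthscale l_d per input dimension (ARD) and an outputscale s. The pure-Python sketch below evaluates that formula directly for fixed hyperparameter values; in the GPyTorch model these hyperparameters are learned when the marginal log likelihood is fit, and the values used here are arbitrary.

```python
import math


def scaled_rbf_ard(x, x_prime, lengthscales, outputscale):
    """Scaled RBF kernel with one lengthscale per input dimension (ARD)."""
    sq_dist = sum(((a - b) / ell) ** 2 for a, b, ell in zip(x, x_prime, lengthscales))
    return outputscale * math.exp(-0.5 * sq_dist)


lengthscales = [0.2, 10.0]  # dimension 2 has a very long lengthscale
outputscale = 1.5

print(scaled_rbf_ard([0.5, 0.5], [0.5, 0.5], lengthscales, outputscale))  # same point
print(scaled_rbf_ard([0.5, 0.0], [0.5, 1.0], lengthscales, outputscale))  # move along dim 2
print(scaled_rbf_ard([0.0, 0.5], [1.0, 0.5], lengthscales, outputscale))  # move along dim 1
```

Moving across the whole unit interval in the long-lengthscale dimension barely changes the covariance, while the same move in the short-lengthscale dimension drives it close to zero. This is how ARD lets the model learn that some inputs matter less than others.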

In most cases, implementing the construct_inputs method should be sufficient to support the custom model class. For more complicated cases, a dispatcher case for submodel_input_constructor can be registered, which will allow further customization. A very simple example is provided here to demonstrate.

from typing import Any
from botorch.models.model import Model
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch_modular.surrogate import Surrogate, submodel_input_constructor

@submodel_input_constructor.register(SimpleCustomGP)
def _submodel_input_constructor_test(
    botorch_model_class: type[Model],
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    surrogate: Surrogate,
) -> dict[str, Any]:
    return botorch_model_class.construct_inputs(
        training_data=dataset,
        **model_config.model_options,
    )
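Conceptually, registering a dispatcher case maps a model class to a handler, and a lookup falls back along the class hierarchy to the most specific registered handler. The sketch below implements that idea with a plain dictionary and Python's method resolution order. The names (`register`, `lookup` and the toy model classes) are hypothetical; Ax's `submodel_input_constructor` uses its own dispatch machinery rather than this code.

```python
from typing import Any, Callable

_REGISTRY: dict[type, Callable[..., dict[str, Any]]] = {}


def register(cls: type):
    """Decorator that registers a handler for `cls` (and, implicitly, its subclasses)."""
    def decorator(func):
        _REGISTRY[cls] = func
        return func
    return decorator


def lookup(model_class: type) -> Callable[..., dict[str, Any]]:
    """Find the handler for the most specific registered ancestor of `model_class`."""
    for klass in model_class.__mro__:
        if klass in _REGISTRY:
            return _REGISTRY[klass]
    raise NotImplementedError(f"No handler registered for {model_class.__name__}")


class BaseModel: ...
class CustomModel(BaseModel): ...
class DerivedCustomModel(CustomModel): ...
class OtherModel(BaseModel): ...


@register(BaseModel)
def _default_inputs(**kwargs):
    return {"handler": "default", **kwargs}


@register(CustomModel)
def _custom_inputs(**kwargs):
    return {"handler": "custom", **kwargs}


print(lookup(OtherModel)(foo=1))          # no own handler: uses the BaseModel one
print(lookup(DerivedCustomModel)(foo=1))  # inherits the CustomModel handler
```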

In some cases, the default model fitting logic may not be appropriate. For example, we may utilize a pre-trained surrogate model, in which case we may want to skip model fitting. Other cases may include models that do not utilize a MarginalLogLikelihood class from GPyTorch, in which case a custom model fitting routine may be registered. To customize, we can register a dispatcher case for the fit_botorch_model helper.

from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from ax.models.torch.botorch_modular.utils import fit_botorch_model

@fit_botorch_model.register(SimpleCustomGP)
def _fit_botorch_model_test(
    model: SimpleCustomGP,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fit a GPyTorch based BoTorch model."""
    mll_options = mll_options or {}
    mll = mll_class(likelihood=model.likelihood, model=model, **mll_options)
    fit_gpytorch_mll(mll)
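Since the registered fitting function receives a model instance as its first argument, the "skip fitting for a pre-trained model" case mentioned above amounts to registering a no-op handler for that model type. The sketch below shows the same pattern with Python's standard `functools.singledispatch`. The toy classes and the `fit_model` name are hypothetical, and no actual model fitting happens; a counter stands in for the fitting work.

```python
from functools import singledispatch


class ToyModel:
    def __init__(self):
        self.fit_calls = 0


class PretrainedToyModel(ToyModel):
    """A model whose parameters were loaded from elsewhere and should stay fixed."""


@singledispatch
def fit_model(model: ToyModel) -> None:
    # Default case: stand-in for optimizing the model's parameters.
    model.fit_calls += 1


@fit_model.register
def _(model: PretrainedToyModel) -> None:
    # Registered case: leave the pre-trained parameters untouched.
    return None


fresh, pretrained = ToyModel(), PretrainedToyModel()
fit_model(fresh)
fit_model(pretrained)
print(fresh.fit_calls, pretrained.fit_calls)
```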

Implementing a custom acquisition function

Since the author of the tutorial wasn't feeling particularly creative, we will demonstrate this using a "custom" simple regret acquisition function. Again, the key to compatibility with the Modular BoTorch Generator is an input constructor. Let's define the acquisition function first and then register an input constructor for it.

import torch
from botorch.acquisition.objective import PosteriorTransform, MCAcquisitionObjective
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.sampling.base import MCSampler

class CustomSimpleRegret(MCAcquisitionFunction):
    # See qSimpleRegret in BoTorch for a better implementation.
    # This is simplified from the original implementation.
    def __init__(
        self,
        model: Model,
        sampler: MCSampler | None = None,
        objective: MCAcquisitionObjective | None = None,
        posterior_transform: PosteriorTransform | None = None,
        X_pending: Tensor | None = None,
    ) -> None:
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            posterior_transform=posterior_transform,
            X_pending=X_pending,
        )

    def forward(self, X: Tensor) -> Tensor:
        samples, obj = self._get_samples_and_objectives(X=X)
        return torch.mean(torch.amax(obj, dim=-1), dim=0)
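The `forward` method above is a Monte Carlo estimate: for each posterior sample it takes the best objective value across the q candidates in the batch, then averages over samples, i.e. roughly (1 / N) * sum_n max_j obj_n(x_j). The sketch below performs the same reduction on plain Python lists, with a hypothetical hard-coded table of samples (rows are MC samples, columns are the q batch points) standing in for `obj` from a single t-batch.

```python
from statistics import mean


def mc_simple_regret(obj_samples: list[list[float]]) -> float:
    """Average over MC samples of the max objective across the q batch points.

    `obj_samples[n][j]` is the objective value of batch point j under sample n,
    mirroring torch.mean(torch.amax(obj, dim=-1), dim=0) for a single t-batch.
    """
    return mean([max(row) for row in obj_samples])


obj = [
    [1.0, 3.0],  # sample 0: point 1 is better
    [2.0, 0.0],  # sample 1: point 0 is better
    [5.0, 4.0],  # sample 2: point 0 is better
]
print(mc_simple_regret(obj))  # mean of the per-sample maxima 3, 2 and 5
```

Because the max is taken inside the average, a batch gets credit whenever any one of its points does well under a given sample.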

The input constructor's job is to extract & filter the arguments necessary to construct the acquisition function from the arguments provided by the Acquisition class (in Ax). Only a small number of arguments must be supported by any given input constructor; other arguments that are not handled by the input constructor will be ignored. Additional details and examples can be found in botorch/acquisition/input_constructors.py.

from botorch.acquisition.input_constructors import (
acqf_input_constructor,
construct_inputs_qSimpleRegret,
)
from typing import Callable

@acqf_input_constructor(CustomSimpleRegret)
def construct_inputs_custom_simple_regret(
    model: Model,
    objective: MCAcquisitionObjective | None = None,
    posterior_transform: PosteriorTransform | None = None,
    X_pending: Tensor | None = None,
    sampler: MCSampler | None = None,
    constraints: list[Callable[[Tensor], Tensor]] | None = None,
    X_baseline: Tensor | None = None,
) -> dict[str, Any]:
    return construct_inputs_qSimpleRegret(
        model=model,
        objective=objective,
        posterior_transform=posterior_transform,
        X_pending=X_pending,
        sampler=sampler,
        constraints=constraints,
        X_baseline=X_baseline
    )
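The idea that arguments an input constructor does not handle are simply ignored can be illustrated with a small helper that inspects a function's signature and forwards only the keyword arguments it accepts. This is a hypothetical sketch of the general idea: `call_with_supported_kwargs` and `toy_input_constructor` are not Ax or BoTorch functions, and BoTorch's own input-constructor machinery differs in its details.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(func: Callable[..., Any], **kwargs: Any) -> Any:
    """Call `func` with only the keyword arguments that appear in its signature."""
    params = inspect.signature(func).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return func(**accepted)


def toy_input_constructor(model, X_pending=None):
    # Only handles `model` and `X_pending`; anything else should be dropped.
    return {"model": model, "X_pending": X_pending}


inputs = call_with_supported_kwargs(
    toy_input_constructor,
    model="gp",
    X_pending=None,
    X_baseline=[[0.5, 0.5]],  # not in the signature, so it is ignored
    constraints=None,         # likewise ignored
)
print(inputs)
```

A fuller version would also need to handle functions that accept `**kwargs`, which this sketch does not treat specially.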

Let's use the custom model and acquisition function

We repeat the above example but with the custom model and acquisition function this time.

generation_strategy = construct_generation_strategy(
generator_spec=GeneratorSpec(
model_enum=Generators.BOTORCH_MODULAR,
model_kwargs={
"surrogate_spec": SurrogateSpec(
model_configs=[
ModelConfig(
botorch_model_class=SimpleCustomGP
)
]
),
"botorch_acqf_class": CustomSimpleRegret
}
),
node_name="BoTorch w/ Custom Components"
)

client = Client()
client.configure_experiment(parameters=range_parameters)
client.configure_optimization(objective=objective)
client.set_generation_strategy(generation_strategy=generation_strategy)

for _ in range(10):
    trials = client.get_next_trials(max_trials=1)
    for index, parameters in trials.items():
        result = test_function(**parameters)
        client.complete_trial(trial_index=index, raw_data={"test_metric": result})

client.summarize()

trial_index  arm_name  trial_status  generation_node               test_metric  x1        x2
0            0_0       COMPLETED     CenterOfSearchSpace           -0.647867    0.5       0.5
1            1_0       COMPLETED     Sobol                         -0.683445    0.475107  0.592524
2            2_0       COMPLETED     Sobol                         -0.505609    0.578763  0.037122
3            3_0       COMPLETED     Sobol                         -0.038102    0.95067   0.862344
4            4_0       COMPLETED     Sobol                         -0.853987    0.120458  0.261442
5            5_0       COMPLETED     BoTorch w/ Custom Components  -0.912228    0.214912  1
6            6_0       COMPLETED     BoTorch w/ Custom Components  -0.928303    0.173527  1
7            7_0       COMPLETED     BoTorch w/ Custom Components  -0.955345    0.055405  1
8            8_0       COMPLETED     BoTorch w/ Custom Components  -0.958414    0.000457  1
9            9_0       COMPLETED     BoTorch w/ Custom Components  -0.958415    0         1