Skip to main content
Version: 0.5.0

Using external methods for candidate generation in Ax

Out of the box, Ax offers many options for candidate generation, most of which utilize Bayesian optimization algorithms built using BoTorch. For users who want to leverage Ax for experiment orchestration (via AxClient or Scheduler) and other features (e.g., early stopping), while relying on other methods for candidate generation, we introduced ExternalGenerationNode.

A GenerationNode is a building block of a GenerationStrategy. GenerationNodes can be combined to use different methods for generating candidates at different stages of an experiment. ExternalGenerationNode exposes a lightweight interface that lets users easily integrate their own methods into Ax, either standalone or together with other GenerationNodes in a GenerationStrategy.

In this tutorial, we will implement a simple generation node using RandomForestRegressor from sklearn, and combine it with Sobol (for initialization) to optimize the Hartmann6 problem.

NOTE: This is for illustration purposes only. We do not recommend using this strategy, as it typically performs poorly compared to Ax's default algorithms due to its overly greedy behavior.

import sys
import plotly.io as pio
# Notebook setup: only when running inside Google Colab (detected by the
# 'google.colab' module being loaded).
if 'google.colab' in sys.modules:
    # Use Colab's plotly renderer so figures display inline.
    pio.renderers.default = "colab"
    # IPython line magic (not plain Python): installs Ax into the notebook kernel.
    %pip install ax-platform
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.parameter import RangeParameter
from ax.core.types import TParameterization
from ax.modelbridge.external_generation_node import ExternalGenerationNode
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.plot.trace import plot_objective_value_vs_trial_index
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.measurement.synthetic_functions import hartmann6
from sklearn.ensemble import RandomForestRegressor
from pyre_extensions import assert_is_instance


class RandomForestGenerationNode(ExternalGenerationNode):
    """A generation node that uses the RandomForestRegressor
    from sklearn to predict candidate performance and picks the
    next point as the random sample that has the best prediction.

    To leverage external methods for candidate generation, the user must
    create a subclass that implements ``update_generator_state`` and
    ``get_next_candidate`` methods. This can then be provided
    as a node into a ``GenerationStrategy``, either as standalone or as
    part of a larger generation strategy with other generation nodes,
    e.g., with a Sobol node for initialization.
    """

    def __init__(self, num_samples: int, regressor_options: Dict[str, Any]) -> None:
        """Initialize the generation node.

        Args:
            num_samples: Number of random samples from the search space
                used during candidate generation. The sample with the best
                prediction is recommended as the next candidate.
            regressor_options: Options to pass to the random forest regressor.
        """
        t_init_start = time.monotonic()
        super().__init__(node_name="RandomForest")
        self.num_samples: int = num_samples
        self.regressor: RandomForestRegressor = RandomForestRegressor(
            **regressor_options
        )
        # We will set these later when updating the state.
        # Alternatively, we could have required experiment as an input
        # and extracted them here.
        # ``parameters`` holds the search space's name -> RangeParameter mapping;
        # its key order defines the feature columns used for fitting/prediction.
        self.parameters: Optional[Dict[str, RangeParameter]] = None
        self.minimize: Optional[bool] = None
        # Recording time spent in initializing the generator. This is
        # used to compute the time spent in candidate generation.
        self.fit_time_since_gen: float = time.monotonic() - t_init_start

    def update_generator_state(self, experiment: Experiment, data: Data) -> None:
        """A method used to update the state of the generator. This includes any
        models, predictors or any other custom state used by the generation node.
        This method will be called with the up-to-date experiment and data before
        ``get_next_candidate`` is called to generate the next trial(s). Note
        that ``get_next_candidate`` may be called multiple times (to generate
        multiple candidates) after a call to ``update_generator_state``.

        For this example, we will train the regressor using the latest data from
        the completed trials of the experiment.

        Args:
            experiment: The ``Experiment`` object representing the current state of the
                experiment. The key properties include ``trials``, ``search_space``,
                and ``optimization_config``. The data is provided as a separate arg.
            data: The data / metrics collected on the experiment so far.

        Raises:
            NotImplementedError: If the search space has non-range parameters or
                parameter constraints, or if the optimization config does not
                have exactly one metric.
        """
        search_space = experiment.search_space
        parameter_names = list(search_space.parameters.keys())
        metric_names = list(experiment.optimization_config.metrics.keys())
        if any(
            not isinstance(p, RangeParameter) for p in search_space.parameters.values()
        ):
            raise NotImplementedError(
                "This example only supports RangeParameters in the search space."
            )
        if search_space.parameter_constraints:
            raise NotImplementedError(
                "This example does not support parameter constraints."
            )
        if len(metric_names) != 1:
            raise NotImplementedError(
                "This example only supports single-objective optimization."
            )
        # Collect the completed trials. ``trial.status`` is a ``TrialStatus``
        # enum, so it must be compared against the enum member, not a string
        # (a string comparison never matches and leaves the training set empty).
        completed = [
            (t_idx, trial)
            for t_idx, trial in experiment.trials.items()
            if trial.status == TrialStatus.COMPLETED
        ]
        x = np.zeros([len(completed), len(parameter_names)])
        # sklearn expects a 1-D target of shape (n_samples,).
        y = np.zeros(len(completed))
        # Fill rows positionally rather than by trial index, so non-completed
        # trials interleaved with completed ones can neither overflow the arrays
        # nor leave spurious all-zero rows.
        for row, (t_idx, trial) in enumerate(completed):
            trial_parameters = trial.arm.parameters
            x[row, :] = np.array([trial_parameters[p] for p in parameter_names])
            trial_df = data.df[data.df["trial_index"] == t_idx]
            y[row] = trial_df[trial_df["metric_name"] == metric_names[0]][
                "mean"
            ].item()

        # Train the regressor.
        self.regressor.fit(x, y)
        # Update the attributes not set in __init__.
        self.parameters = search_space.parameters
        self.minimize = experiment.optimization_config.objective.minimize

    def get_next_candidate(
        self, pending_parameters: List[TParameterization]
    ) -> TParameterization:
        """Get the parameters for the next candidate configuration to evaluate.

        We will draw ``self.num_samples`` random samples from the search space
        and predict the objective value for each sample. We will then return
        the sample with the best predicted value.

        Args:
            pending_parameters: A list of parameters of the candidates pending
                evaluation. This is often used to avoid generating duplicate candidates.
                We ignore this here for simplicity.

        Returns:
            A dictionary mapping parameter names to parameter values for the next
            candidate suggested by the method.

        Raises:
            RuntimeError: If called before ``update_generator_state``.
        """
        if self.parameters is None or self.minimize is None:
            raise RuntimeError(
                "update_generator_state must be called before get_next_candidate."
            )
        bounds = np.array([[p.lower, p.upper] for p in self.parameters.values()])
        # Uniform samples in the unit cube, rescaled to each parameter's bounds.
        unit_samples = np.random.random_sample([self.num_samples, len(bounds)])
        samples = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * unit_samples
        # Predict the objective value for each sample.
        y_pred = self.regressor.predict(samples)
        # Find the best sample according to the objective direction.
        best_idx = np.argmin(y_pred) if self.minimize else np.argmax(y_pred)
        best_sample = samples[best_idx, :]
        # Convert the sample to a parameterization (native Python floats).
        candidate = {
            p_name: best_sample[i].item()
            for i, p_name in enumerate(self.parameters.keys())
        }
        return candidate

Construct the GenerationStrategy

We will use Sobol for the first 5 trials and defer to random forest for the rest.

# Stage 1: quasi-random Sobol initialization. The MaxTrials criterion caps this
# node at 5 trials; ``transition_to`` names the node that takes over afterwards.
sobol_node = GenerationNode(
    node_name="Sobol",
    model_specs=[ModelSpec(Models.SOBOL)],
    transition_criteria=[
        MaxTrials(
            threshold=5,
            block_transition_if_unmet=True,
            transition_to="RandomForest",
        )
    ],
)
# Stage 2: the custom random-forest node; it generates the remaining trials.
random_forest_node = RandomForestGenerationNode(num_samples=128, regressor_options={})

generation_strategy = GenerationStrategy(
    name="Sobol+RandomForest",
    nodes=[sobol_node, random_forest_node],
)

Run a simple experiment using AxClient

More details on how to use AxClient can be found in the AxClient (Service API) tutorial.

ax_client = AxClient(generation_strategy=generation_strategy)

# Six float range parameters x1..x6, each bounded to [0, 1].
# ("value_type" is optional; by default it is inferred from the type of "bounds".)
hartmann_parameters = [
    {
        "name": f"x{dim}",
        "type": "range",
        "bounds": [0.0, 1.0],
        "value_type": "float",
    }
    for dim in range(1, 7)
]

ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=hartmann_parameters,
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
)


def evaluate(parameterization: TParameterization) -> Dict[str, Tuple[float, float]]:
    """Evaluate the Hartmann6 function at a parameterization.

    Args:
        parameterization: Mapping containing values for keys ``"x1"`` .. ``"x6"``.

    Returns:
        ``{"hartmann6": (mean, sem)}`` with the SEM fixed at 0.0.

    Raises:
        KeyError: If any of ``"x1"`` .. ``"x6"`` is missing.
    """
    # Index directly (not ``.get``) so a missing parameter fails fast with a
    # KeyError instead of putting ``None`` into the input array.
    x = np.array([parameterization[f"x{i}"] for i in range(1, 7)])
    return {"hartmann6": (assert_is_instance(hartmann6(x), float), 0.0)}
Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the verbose_logging argument to False. Note that float values in the logs are rounded to 6 decimal points.

Out:

[INFO 02-03 20:35:20] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).

Run the optimization loop

# Total trial budget. The generation strategy decides which node produces
# each trial (Sobol first, then the random-forest node).
num_trials = 15
for _ in range(num_trials):
    # Ask Ax for the next candidate, evaluate it, and report the result back.
    parameterization, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameterization)
    )
Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

warn("Encountered exception in computing model fit quality: " + str(e))

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.63831, 'x2': 0.140672, 'x3': 0.17142, 'x4': 0.05316, 'x5': 0.613935, 'x6': 0.33025} using model Sobol.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.084001, 0.0)}.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

warn("Encountered exception in computing model fit quality: " + str(e))

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.385221, 'x2': 0.886413, 'x3': 0.65499, 'x4': 0.783199, 'x5': 0.344269, 'x6': 0.788216} using model Sobol.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-0.057529, 0.0)}.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

warn("Encountered exception in computing model fit quality: " + str(e))

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.134921, 'x2': 0.286339, 'x3': 0.444643, 'x4': 0.433479, 'x5': 0.231935, 'x6': 0.514665} using model Sobol.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-1.875527, 0.0)}.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

warn("Encountered exception in computing model fit quality: " + str(e))

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.888668, 'x2': 0.561121, 'x3': 0.99065, 'x4': 0.668399, 'x5': 0.993786, 'x6': 0.119184} using model Sobol.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.020125, 0.0)}.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

warn("Encountered exception in computing model fit quality: " + str(e))

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.757716, 'x2': 0.380667, 'x3': 0.56912, 'x4': 0.500669, 'x5': 0.85785, 'x6': 0.375725} using model Sobol.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.016265, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.299728, 'x2': 0.566086, 'x3': 0.287149, 'x4': 0.507838, 'x5': 0.501386, 'x6': 0.82669} using model RandomForest.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-0.44245, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.038023, 'x2': 0.569057, 'x3': 0.06513, 'x4': 0.709504, 'x5': 0.896278, 'x6': 0.031761} using model RandomForest.

Out:

[INFO 02-03 20:35:20] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.112306, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.743513, 'x2': 0.641311, 'x3': 0.167111, 'x4': 0.505774, 'x5': 0.187773, 'x6': 0.095034} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-0.273633, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.227091, 'x2': 0.617702, 'x3': 0.22959, 'x4': 0.454016, 'x5': 0.182462, 'x6': 0.32985} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.607305, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.995698, 'x2': 0.901448, 'x3': 0.831931, 'x4': 0.742936, 'x5': 0.126695, 'x6': 0.101116} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.006273, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.978118, 'x2': 0.821285, 'x3': 0.802858, 'x4': 0.621627, 'x5': 0.736022, 'x6': 0.481513} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.001785, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.22562, 'x2': 0.438048, 'x3': 0.334638, 'x4': 0.285988, 'x5': 0.232312, 'x6': 0.811747} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-1.808843, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 0.22676, 'x2': 0.920744, 'x3': 0.793768, 'x4': 0.696917, 'x5': 0.098601, 'x6': 0.801295} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 12 with data: {'hartmann6': (-0.075857, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 0.500975, 'x2': 0.082213, 'x3': 0.404258, 'x4': 0.317887, 'x5': 0.163446, 'x6': 0.706106} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 13 with data: {'hartmann6': (-1.71524, 0.0)}.

Out:

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/sklearn/base.py:1389: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().

return fit_method(estimator, *args, **kwargs)

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.592487, 'x2': 0.478587, 'x3': 0.083944, 'x4': 0.047016, 'x5': 0.671112, 'x6': 0.841802} using model RandomForest.

Out:

[INFO 02-03 20:35:21] ax.service.ax_client: Completed trial 14 with data: {'hartmann6': (-0.055547, 0.0)}.

View the trials generated during optimization

# Summarize the experiment as a DataFrame (one row per trial arm, with trial
# status, generation method, the observed objective and the parameter values).
exp_df = exp_to_df(ax_client.experiment)
# Bare expression as the cell's last line so the notebook displays the table.
exp_df
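|    | trial_index | arm_name | trial_status | generation_method | hartmann6 | x1 | x2 | x3 | x4 | x5 | x6 |
|---:|---:|:---|:---|:---|---:|---:|---:|---:|---:|---:|---:|
| 0 | 0 | 0_0 | COMPLETED | Sobol | -0.084001 | 0.63831 | 0.140672 | 0.17142 | 0.05316 | 0.613935 | 0.33025 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | -0.057529 | 0.385221 | 0.886413 | 0.65499 | 0.783199 | 0.344269 | 0.788216 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | -1.87553 | 0.134921 | 0.286339 | 0.444643 | 0.433479 | 0.231935 | 0.514665 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | -0.020125 | 0.888668 | 0.561121 | 0.99065 | 0.668399 | 0.993786 | 0.119184 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | -0.016265 | 0.757716 | 0.380667 | 0.56912 | 0.500669 | 0.85785 | 0.375725 |
| 5 | 5 | 5_0 | COMPLETED | RandomForest | -0.44245 | 0.299728 | 0.566086 | 0.287149 | 0.507838 | 0.501386 | 0.82669 |
| 6 | 6 | 6_0 | COMPLETED | RandomForest | -0.112306 | 0.038023 | 0.569057 | 0.06513 | 0.709504 | 0.896278 | 0.031761 |
| 7 | 7 | 7_0 | COMPLETED | RandomForest | -0.273633 | 0.743513 | 0.641311 | 0.167111 | 0.505774 | 0.187773 | 0.095034 |
| 8 | 8 | 8_0 | COMPLETED | RandomForest | -0.607305 | 0.227091 | 0.617702 | 0.22959 | 0.454016 | 0.182462 | 0.32985 |
| 9 | 9 | 9_0 | COMPLETED | RandomForest | -0.006273 | 0.995698 | 0.901448 | 0.831931 | 0.742936 | 0.126695 | 0.101116 |
| 10 | 10 | 10_0 | COMPLETED | RandomForest | -0.001785 | 0.978118 | 0.821285 | 0.802858 | 0.621627 | 0.736022 | 0.481513 |
| 11 | 11 | 11_0 | COMPLETED | RandomForest | -1.80884 | 0.22562 | 0.438048 | 0.334638 | 0.285988 | 0.232312 | 0.811747 |
| 12 | 12 | 12_0 | COMPLETED | RandomForest | -0.075857 | 0.22676 | 0.920744 | 0.793768 | 0.696917 | 0.098601 | 0.801295 |
| 13 | 13 | 13_0 | COMPLETED | RandomForest | -1.71524 | 0.500975 | 0.082213 | 0.404258 | 0.317887 | 0.163446 | 0.706106 |
| 14 | 14 | 14_0 | COMPLETED | RandomForest | -0.055547 | 0.592487 | 0.478587 | 0.083944 | 0.047016 | 0.671112 | 0.841802 |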
# Plot the observed "hartmann6" value for each trial in exp_df.
# minimize=True matches the experiment's objective direction. Left as the
# cell's final expression so the notebook displays the returned plot.
plot_objective_value_vs_trial_index(
    exp_df=exp_df,
    metric_colname="hartmann6",
    minimize=True,
    title="Hartmann6 Objective Value vs. Trial Index",
)
Out:

/home/runner/work/Ax/Ax/ax/plot/trace.py:873: FutureWarning:

DataFrame.fillna with 'method' is deprecated and will raise in a future version. Use obj.ffill() or obj.bfill() instead.

loading...