Out of the box, Ax offers many options for candidate generation, most of which utilize Bayesian optimization algorithms built using BoTorch. For users who want to leverage Ax for experiment orchestration (via AxClient or Scheduler) and other features (e.g., early stopping), while relying on other methods for candidate generation, we introduced ExternalGenerationNode.
A GenerationNode is a building block of a GenerationStrategy. Nodes can be combined to utilize different methods for generating candidates at different stages of an experiment. ExternalGenerationNode exposes a lightweight interface that allows users to easily integrate their own methods into Ax, and to use them either standalone or together with other GenerationNodes in a GenerationStrategy.
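The interface really is lightweight: a subclass only needs to implement two methods. Below is a minimal, illustrative skeleton (MyGenerationNode and MyMethod are placeholder names, not part of Ax):

from typing import List

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.types import TParameterization
from ax.modelbridge.external_generation_node import ExternalGenerationNode


class MyGenerationNode(ExternalGenerationNode):
    def __init__(self) -> None:
        super().__init__(node_name="MyMethod")

    def update_generator_state(self, experiment: Experiment, data: Data) -> None:
        # Refresh any model or other internal state from the latest data.
        ...

    def get_next_candidate(
        self, pending_parameters: List[TParameterization]
    ) -> TParameterization:
        # Return a dict mapping parameter names to values for the next trial.
        ...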
In this tutorial, we will implement a simple generation node using RandomForestRegressor from sklearn, and combine it with Sobol (for initialization) to optimize the Hartmann6 problem.
NOTE: This is for illustration purposes only. We do not recommend using this strategy, as it typically does not perform well compared to Ax's default algorithms due to its overly greedy behavior.
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.parameter import RangeParameter
from ax.core.types import TParameterization
from ax.modelbridge.external_generation_node import ExternalGenerationNode
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.plot.trace import plot_objective_value_vs_trial_index
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import hartmann6
from sklearn.ensemble import RandomForestRegressor
class RandomForestGenerationNode(ExternalGenerationNode):
"""A generation node that uses the RandomForestRegressor
from sklearn to predict candidate performance and picks the
next point as the random sample that has the best prediction.
    To leverage external methods for candidate generation, the user must
    create a subclass that implements the ``update_generator_state`` and
    ``get_next_candidate`` methods. The node can then be used in a
    ``GenerationStrategy``, either standalone or together with other
    generation nodes, e.g., with a Sobol node for initialization.
"""
def __init__(self, num_samples: int, regressor_options: Dict[str, Any]) -> None:
"""Initialize the generation node.
Args:
            num_samples: Number of random samples from the search space
                used during candidate generation. The sample with the best
                prediction is recommended as the next candidate.
            regressor_options: Options to pass to the random forest regressor.
"""
t_init_start = time.monotonic()
super().__init__(node_name="RandomForest")
self.num_samples: int = num_samples
self.regressor: RandomForestRegressor = RandomForestRegressor(
**regressor_options
)
# We will set these later when updating the state.
# Alternatively, we could have required experiment as an input
# and extracted them here.
        self.parameters: Optional[Dict[str, RangeParameter]] = None
self.minimize: Optional[bool] = None
# Recording time spent in initializing the generator. This is
# used to compute the time spent in candidate generation.
self.fit_time_since_gen: float = time.monotonic() - t_init_start
def update_generator_state(self, experiment: Experiment, data: Data) -> None:
"""A method used to update the state of the generator. This includes any
models, predictors or any other custom state used by the generation node.
This method will be called with the up-to-date experiment and data before
``get_next_candidate`` is called to generate the next trial(s). Note
that ``get_next_candidate`` may be called multiple times (to generate
multiple candidates) after a call to ``update_generator_state``.
For this example, we will train the regressor using the latest data from
the experiment.
Args:
experiment: The ``Experiment`` object representing the current state of the
                experiment. The key properties include ``trials``, ``search_space``,
and ``optimization_config``. The data is provided as a separate arg.
data: The data / metrics collected on the experiment so far.
"""
search_space = experiment.search_space
parameter_names = list(search_space.parameters.keys())
metric_names = list(experiment.optimization_config.metrics.keys())
if any(
not isinstance(p, RangeParameter) for p in search_space.parameters.values()
):
raise NotImplementedError(
"This example only supports RangeParameters in the search space."
)
if search_space.parameter_constraints:
raise NotImplementedError(
"This example does not support parameter constraints."
)
if len(metric_names) != 1:
raise NotImplementedError(
"This example only supports single-objective optimization."
)
# Get the data for the completed trials.
num_completed_trials = len(experiment.trials_by_status[TrialStatus.COMPLETED])
x = np.zeros([num_completed_trials, len(parameter_names)])
y = np.zeros([num_completed_trials, 1])
for t_idx, trial in experiment.trials.items():
            if trial.status == TrialStatus.COMPLETED:
trial_parameters = trial.arm.parameters
x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
trial_df = data.df[data.df["trial_index"] == t_idx]
y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]][
"mean"
].item()
        # Train the regressor. Flatten y, since sklearn expects a 1d target array.
        self.regressor.fit(x, y.ravel())
# Update the attributes not set in __init__.
self.parameters = search_space.parameters
self.minimize = experiment.optimization_config.objective.minimize
def get_next_candidate(
self, pending_parameters: List[TParameterization]
) -> TParameterization:
"""Get the parameters for the next candidate configuration to evaluate.
We will draw ``self.num_samples`` random samples from the search space
and predict the objective value for each sample. We will then return
the sample with the best predicted value.
Args:
pending_parameters: A list of parameters of the candidates pending
evaluation. This is often used to avoid generating duplicate candidates.
We ignore this here for simplicity.
Returns:
A dictionary mapping parameter names to parameter values for the next
candidate suggested by the method.
"""
bounds = np.array([[p.lower, p.upper] for p in self.parameters.values()])
unit_samples = np.random.random_sample([self.num_samples, len(bounds)])
samples = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * unit_samples
# Predict the objective value for each sample.
y_pred = self.regressor.predict(samples)
# Find the best sample.
best_idx = np.argmin(y_pred) if self.minimize else np.argmax(y_pred)
best_sample = samples[best_idx, :]
# Convert the sample to a parameterization.
candidate = {
p_name: best_sample[i].item()
for i, p_name in enumerate(self.parameters.keys())
}
return candidate
We will use Sobol for the first 5 trials, then switch to the random forest node for the rest.
generation_strategy = GenerationStrategy(
name="Sobol+RandomForest",
nodes=[
GenerationNode(
node_name="Sobol",
model_specs=[ModelSpec(Models.SOBOL)],
transition_criteria=[
                MaxTrials(
                    # Maximum number of trials to generate from this node
                    # before transitioning to the node named in `transition_to`.
                    threshold=5,
                    block_transition_if_unmet=True,
                    transition_to="RandomForest",
                )
],
),
RandomForestGenerationNode(num_samples=128, regressor_options={}),
],
)
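For reference, this node-based strategy mirrors the step-based form commonly used with Ax's built-in models. A sketch of that form is below (illustrative only; GenerationStep works with models from the Ax registry, which is why the external method above goes through the node API instead):

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

# Sobol initialization followed by Ax's default BoTorch-based model;
# num_trials=-1 means the step can generate trials indefinitely.
botorch_strategy = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=5),
        GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1),
    ]
)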
ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(
name="hartmann_test_experiment",
parameters=[
{
"name": f"x{i}",
"type": "range",
"bounds": [0.0, 1.0],
"value_type": "float", # Optional, defaults to inference from type of "bounds".
}
for i in range(1, 7)
],
objectives={"hartmann6": ObjectiveProperties(minimize=True)},
)
def evaluate(parameterization: TParameterization) -> Dict[str, Tuple[float, float]]:
x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])
return {"hartmann6": (checked_cast(float, hartmann6(x)), 0.0)}
[INFO 11-12 07:16:54] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 11-12 07:16:54] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).
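Before running the optimization loop, we can sanity-check the evaluation function. The point below is the widely reported global minimizer of Hartmann6, with an optimal value of roughly -3.32237; it is quoted here for illustration, not computed in this tutorial:

x_star = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(evaluate({f"x{i + 1}": v for i, v in enumerate(x_star)}))
# Expect a value close to {'hartmann6': (-3.32237, 0.0)}.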
for i in range(15):
parameterization, trial_index = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=trial_index, raw_data=evaluate(parameterization)
)
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.929941, 'x2': 0.294402, 'x3': 0.440347, 'x4': 0.284143, 'x5': 0.843343, 'x6': 0.526552} using model Sobol.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.005002, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.20274, 'x2': 0.539496, 'x3': 0.757867, 'x4': 0.523255, 'x5': 0.175048, 'x6': 0.229133} using model Sobol.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-0.510071, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.265065, 'x2': 0.041091, 'x3': 0.126597, 'x4': 0.08588, 'x5': 0.499236, 'x6': 0.378261} using model Sobol.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.497129, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.617877, 'x2': 0.795981, 'x3': 0.573511, 'x4': 0.846522, 'x5': 0.519155, 'x6': 0.862091} using model Sobol.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.023786, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.715555, 'x2': 0.156917, 'x3': 0.936788, 'x4': 0.964074, 'x5': 0.682883, 'x6': 0.68288} using model Sobol.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.008878, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.653481, 'x2': 0.545283, 'x3': 0.734922, 'x4': 0.212491, 'x5': 0.013681, 'x6': 0.242432} using model RandomForest.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-0.112765, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.67499, 'x2': 0.476091, 'x3': 0.551435, 'x4': 0.681594, 'x5': 0.560001, 'x6': 0.215888} using model RandomForest.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.1547, 0.0)}.
[INFO 11-12 07:16:54] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.378627, 'x2': 0.492244, 'x3': 0.795518, 'x4': 0.901471, 'x5': 0.502481, 'x6': 0.487357} using model RandomForest.
[INFO 11-12 07:16:54] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-0.04581, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.327532, 'x2': 0.009476, 'x3': 0.818416, 'x4': 0.261737, 'x5': 0.659566, 'x6': 0.000275} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.018233, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.357713, 'x2': 0.615457, 'x3': 0.065214, 'x4': 0.566858, 'x5': 0.095674, 'x6': 0.199008} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-1.215855, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.042753, 'x2': 0.506317, 'x3': 0.997733, 'x4': 0.664924, 'x5': 0.674374, 'x6': 0.093085} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.097075, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.050608, 'x2': 0.86506, 'x3': 0.082335, 'x4': 0.857122, 'x5': 0.54335, 'x6': 0.916682} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.003662, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 0.610131, 'x2': 0.699291, 'x3': 0.570357, 'x4': 0.521161, 'x5': 0.569504, 'x6': 0.596988} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 12 with data: {'hartmann6': (-0.138122, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 0.178413, 'x2': 0.502151, 'x3': 0.008517, 'x4': 0.611965, 'x5': 0.450928, 'x6': 0.596716} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 13 with data: {'hartmann6': (-0.371504, 0.0)}.
[INFO 11-12 07:16:55] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.864776, 'x2': 0.270704, 'x3': 0.840438, 'x4': 0.843946, 'x5': 0.190392, 'x6': 0.072524} using model RandomForest.
[INFO 11-12 07:16:55] ax.service.ax_client: Completed trial 14 with data: {'hartmann6': (-0.003356, 0.0)}.
exp_df = exp_to_df(ax_client.experiment)
exp_df
| | trial_index | arm_name | trial_status | generation_method | hartmann6 | x1 | x2 | x3 | x4 | x5 | x6 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | -0.005002 | 0.929941 | 0.294402 | 0.440347 | 0.284143 | 0.843343 | 0.526552 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | -0.510071 | 0.202740 | 0.539496 | 0.757867 | 0.523255 | 0.175048 | 0.229133 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | -0.497129 | 0.265065 | 0.041091 | 0.126597 | 0.085880 | 0.499236 | 0.378261 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | -0.023786 | 0.617877 | 0.795981 | 0.573511 | 0.846522 | 0.519155 | 0.862091 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | -0.008878 | 0.715555 | 0.156917 | 0.936788 | 0.964074 | 0.682883 | 0.682880 |
| 5 | 5 | 5_0 | COMPLETED | RandomForest | -0.112765 | 0.653481 | 0.545283 | 0.734922 | 0.212491 | 0.013681 | 0.242432 |
| 6 | 6 | 6_0 | COMPLETED | RandomForest | -0.154700 | 0.674990 | 0.476091 | 0.551435 | 0.681594 | 0.560001 | 0.215888 |
| 7 | 7 | 7_0 | COMPLETED | RandomForest | -0.045810 | 0.378627 | 0.492244 | 0.795518 | 0.901471 | 0.502481 | 0.487357 |
| 8 | 8 | 8_0 | COMPLETED | RandomForest | -0.018233 | 0.327532 | 0.009476 | 0.818416 | 0.261737 | 0.659566 | 0.000275 |
| 9 | 9 | 9_0 | COMPLETED | RandomForest | -1.215855 | 0.357713 | 0.615457 | 0.065214 | 0.566858 | 0.095674 | 0.199008 |
| 10 | 10 | 10_0 | COMPLETED | RandomForest | -0.097075 | 0.042753 | 0.506317 | 0.997733 | 0.664924 | 0.674374 | 0.093085 |
| 11 | 11 | 11_0 | COMPLETED | RandomForest | -0.003662 | 0.050608 | 0.865060 | 0.082335 | 0.857122 | 0.543350 | 0.916682 |
| 12 | 12 | 12_0 | COMPLETED | RandomForest | -0.138122 | 0.610131 | 0.699291 | 0.570357 | 0.521161 | 0.569504 | 0.596988 |
| 13 | 13 | 13_0 | COMPLETED | RandomForest | -0.371504 | 0.178413 | 0.502151 | 0.008517 | 0.611965 | 0.450928 | 0.596716 |
| 14 | 14 | 14_0 | COMPLETED | RandomForest | -0.003356 | 0.864776 | 0.270704 | 0.840438 | 0.843946 | 0.190392 | 0.072524 |
plot_objective_value_vs_trial_index(
exp_df=exp_df,
metric_colname="hartmann6",
minimize=True,
title="Hartmann6 Objective Value vs. Trial Index",
)
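Finally, to extract the best configuration found, the Service API provides get_best_parameters (a quick sketch; the exact form of the returned prediction values can depend on the models in the generation strategy):

best_parameters, values = ax_client.get_best_parameters()
print(best_parameters)  # Parameterization of the best point found.
print(values)  # Associated metric means and covariances, if available.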
Total runtime of script: 7.1 seconds.