Ask-tell Optimization with Ax
Complex optimization problems where we wish to tune multiple parameters to improve metric performance, but the inter-parameter interactions are not fully understood, are common across various fields including machine learning, robotics, materials science, and chemistry. This category of problem is known as "black-box" optimization. The complexity of black-box optimization problems further increases if evaluations are expensive to conduct, time-consuming, or noisy.
We can use Ax to efficiently conduct an experiment in which we "ask" for candidate points to evaluate, "tell" Ax the results, and repeat. We'll use Ax's Client, a tool for managing the state of our experiment, and we'll learn how to define an optimization problem, configure an experiment, run trials, analyze results, and persist the experiment for later use.
Because Ax is a black box optimizer, we can use it to optimize any arbitrary function. In this example we will minimize the Hartmann6 function, a complicated 6-dimensional function with multiple local minima. Hartmann6 is a challenging benchmark for optimization algorithms commonly used in the global optimization literature -- it tests the algorithm's ability to identify the true global minimum, rather than mistakenly converging on a local minimum. Looking at its analytic form we can see that it would be incredibly challenging to efficiently find the global minimum either by manual trial-and-error or traditional design of experiments like grid-search or random-search.
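$$f(x) = -\sum_{i=1}^{4} \alpha_i \exp\left(-\sum_{j=1}^{6} A_{ij} (x_j - P_{ij})^2\right)$$
where $\alpha$ is a vector of four weights and $A$ and $P$ are constant $4 \times 6$ matrices; their values are listed in the Python implementation in Step 5.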
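To make the cost of grid search concrete, the short sketch below counts how many evaluations a full grid needs in six dimensions. The helper name grid_size is an illustrative assumption and is not part of Ax.

```python
def grid_size(points_per_dimension: int, num_dimensions: int) -> int:
    """Number of function evaluations required by a full factorial grid."""
    return points_per_dimension ** num_dimensions


# Hartmann6 has six inputs, so even a coarse grid grows very quickly.
for points in (5, 10, 20):
    print(f"{points} points per axis -> {grid_size(points, 6):,} evaluations")
```

Even ten points per axis already requires a million evaluations, which is far beyond the 30 trials used later in this tutorial.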
Learning Objectives
- Understand the basic concepts of black box optimization
- Learn how to define an optimization problem using Ax
- Configure and run an experiment using Ax's Client
- Analyze the results of the optimization
Prerequisites
- Familiarity with Python and basic programming concepts
- Understanding of adaptive experimentation and Bayesian optimization
Step 1: Import Necessary Modules
First, ensure you have all the necessary imports:
import numpy as np
from ax.api.client import Client
from ax.api.configs import (
    ExperimentConfig,
    RangeParameterConfig,
    ParameterType,
)
Step 2: Initialize the Client
Create an instance of the Client to manage the state of your experiment.
client = Client()
Step 3: Configure the Experiment
The Client instance can be configured with a series of Configs that define how the experiment will be run.
The Hartmann6 problem is usually evaluated on the unit hypercube $[0, 1]^6$, so we will define six identical RangeParameterConfigs with these bounds and add them to an ExperimentConfig along with other metadata about the experiment.
You may specify additional features like parameter constraints to further refine the search space and parameter scaling to help navigate parameters with nonuniform effects. For more on configuring experiments, see this recipe.
# Define six float parameters x1, x2, x3, ... for the Hartmann6 function
parameters = [
    RangeParameterConfig(
        name=f"x{i + 1}", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
    )
    for i in range(6)
]
# Create an experiment configuration
experiment_config = ExperimentConfig(
    name="hartmann6_experiment",
    parameters=parameters,
    # The following arguments are optional
    description="Optimization of the Hartmann6 function",
    owner="developer",
)
# Apply the experiment configuration to the client
client.configure_experiment(experiment_config=experiment_config)
Step 4: Configure Optimization
Now we must configure the objective for this optimization, which we do using Client.configure_optimization. This method expects a string objective: an expression containing either a single metric to maximize, a linear combination of metrics to maximize, or a tuple of multiple metrics to jointly maximize. These expressions are parsed using SymPy. For example:
- "score" would direct Ax to maximize a metric named score
- "-loss" would direct Ax to minimize a metric named loss
- "task_0 + 0.5 * task_1" would direct Ax to maximize the sum of two task scores, downweighting task_1 by a factor of 0.5
- "score, -flops" would direct Ax to simultaneously maximize score while minimizing flops
For more information on configuring objectives and outcome constraints, see this recipe.
metric_name = "hartmann6" # this name is used during the optimization loop in Step 5
objective = f"-{metric_name}" # minimization is specified by the negative sign
client.configure_optimization(objective=objective)
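The negative sign works because minimizing a metric is equivalent to maximizing its negation. The sketch below checks this on a few made-up observations; the numbers and variable names are illustrative assumptions, not output from Ax.

```python
# Hypothetical observed metric values, keyed by trial index (illustrative only).
observed = {0: -0.50, 1: -0.02, 2: -0.86}

# The trial with the smallest metric value ...
best_by_min = min(observed, key=lambda trial: observed[trial])
# ... is the same trial that maximizes the negated metric.
best_by_max_of_negation = max(observed, key=lambda trial: -observed[trial])

print(best_by_min, best_by_max_of_negation)
```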
Step 5: Run Trials
Here, we will configure the ask-tell loop.
We begin by defining the Hartmann6 function as written above. Remember, this is just an example problem and any Python function can be substituted here.
# Hartmann6 function
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array([
        [10, 3, 17, 3.5, 1.7, 8],
        [0.05, 10, 17, 0.1, 8, 14],
        [3, 3.5, 1.7, 10, 17, 8],
        [17, 8, 0.05, 10, 0.1, 14]
    ])
    P = 10**-4 * np.array([
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381]
    ])
    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j])**2
        outer += alpha[i] * np.exp(-inner)
    return -outer
hartmann6(0.1, 0.45, 0.8, 0.25, 0.552, 1.0)
np.float64(-0.4878737485613134)
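As a sanity check, the function can also be written without NumPy and evaluated at the minimizer commonly cited in the global optimization literature. The sketch below uses the same constants as above; the name hartmann6_pure and the coordinates in X_STAR are assumptions introduced here for illustration and do not come from this tutorial's code.

```python
import math

ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10, 3, 17, 3.5, 1.7, 8],
    [0.05, 10, 17, 0.1, 8, 14],
    [3, 3.5, 1.7, 10, 17, 8],
    [17, 8, 0.05, 10, 0.1, 14],
]
# Stored as integers; scaled by 1e-4 when used, as in the NumPy version.
P = [
    [1312, 1696, 5569, 124, 8283, 5886],
    [2329, 4135, 8307, 3736, 1004, 9991],
    [2348, 1451, 3522, 2883, 3047, 6650],
    [4047, 8828, 8732, 5743, 1091, 381],
]


def hartmann6_pure(*x):
    """Pure-Python Hartmann6: negative weighted sum of four Gaussian-like bumps."""
    total = 0.0
    for alpha_i, a_row, p_row in zip(ALPHA, A, P):
        inner = sum(a * (x_j - 1e-4 * p) ** 2 for a, x_j, p in zip(a_row, x, p_row))
        total += alpha_i * math.exp(-inner)
    return -total


# Commonly cited global minimizer (an assumption taken from the literature).
X_STAR = (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)

print(hartmann6_pure(0.1, 0.45, 0.8, 0.25, 0.552, 1.0))  # same point as above
print(hartmann6_pure(*X_STAR))
```

The first call should agree with the NumPy result shown above up to floating-point rounding, and the second should be close to the known minimum value of the function.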
Optimization Loop
We will iteratively call client.get_next_trials to "ask" Ax for a parameterization to evaluate, then call hartmann6 using those parameters, and finally "tell" Ax the result using client.complete_trial.
This loop will run multiple trials to optimize the function.
# Number of trials to run
num_trials = 30
# Run trials
for _ in range(num_trials):
    trials = client.get_next_trials(
        maximum_trials=1
    )  # We will request just one trial at a time in this example
    for trial_index, parameters in trials.items():
        x1 = parameters["x1"]
        x2 = parameters["x2"]
        x3 = parameters["x3"]
        x4 = parameters["x4"]
        x5 = parameters["x5"]
        x6 = parameters["x6"]
        result = hartmann6(x1, x2, x3, x4, x5, x6)
        # Set raw_data as a dictionary with metric names as keys and results as values
        raw_data = {metric_name: result}
        # Complete the trial with the result
        client.complete_trial(trial_index=trial_index, raw_data=raw_data)
        print(f"Completed trial {trial_index} with {raw_data=}")
Completed trial 0 with raw_data={'hartmann6': np.float64(-0.5019255343509779)}
Completed trial 1 with raw_data={'hartmann6': np.float64(-0.022448780563003885)}
Completed trial 2 with raw_data={'hartmann6': np.float64(-0.14519093834066654)}
Completed trial 3 with raw_data={'hartmann6': np.float64(-0.8604693413805216)}
Completed trial 4 with raw_data={'hartmann6': np.float64(-0.04463788653477111)}
Completed trial 5 with raw_data={'hartmann6': np.float64(-1.53747420257552)}
Completed trial 6 with raw_data={'hartmann6': np.float64(-0.7746285626008993)}
Completed trial 7 with raw_data={'hartmann6': np.float64(-0.5958143119835694)}
Completed trial 8 with raw_data={'hartmann6': np.float64(-0.9180207256799746)}
Completed trial 9 with raw_data={'hartmann6': np.float64(-1.4571186548903279)}
Completed trial 10 with raw_data={'hartmann6': np.float64(-0.5652020510015644)}
Completed trial 11 with raw_data={'hartmann6': np.float64(-1.9025798308645765)}
Completed trial 12 with raw_data={'hartmann6': np.float64(-2.4023918423159127)}
Completed trial 13 with raw_data={'hartmann6': np.float64(-2.667845217834489)}
Completed trial 14 with raw_data={'hartmann6': np.float64(-1.734596106152918)}
Completed trial 15 with raw_data={'hartmann6': np.float64(-2.203282961813707)}
Completed trial 16 with raw_data={'hartmann6': np.float64(-2.9661849718642617)}
Completed trial 17 with raw_data={'hartmann6': np.float64(-2.204289273292293)}
Completed trial 18 with raw_data={'hartmann6': np.float64(-2.453902532362328)}
Completed trial 19 with raw_data={'hartmann6': np.float64(-2.5879260020462516)}
Completed trial 20 with raw_data={'hartmann6': np.float64(-3.034029178842838)}
Completed trial 21 with raw_data={'hartmann6': np.float64(-0.5740518442706214)}
Completed trial 22 with raw_data={'hartmann6': np.float64(-3.253454881380015)}
Completed trial 23 with raw_data={'hartmann6': np.float64(-2.8468072715625494)}
Completed trial 24 with raw_data={'hartmann6': np.float64(-3.213619091963154)}
Completed trial 25 with raw_data={'hartmann6': np.float64(-3.2123294457451803)}
Completed trial 26 with raw_data={'hartmann6': np.float64(-3.2824154473972658)}
Completed trial 27 with raw_data={'hartmann6': np.float64(-3.1691049255409105)}
Completed trial 28 with raw_data={'hartmann6': np.float64(-3.299233158827855)}
Completed trial 29 with raw_data={'hartmann6': np.float64(-3.177834274334158)}
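The ask-tell pattern itself does not depend on Ax. The sketch below is a hypothetical stand-in that mimics the two method names used above but proposes points uniformly at random; it is not Ax, does not perform Bayesian optimization, and the class name, the "objective" metric key, and the toy objective are all assumptions made for illustration.

```python
import random


class RandomSearchClient:
    """Hypothetical ask-tell client that samples uniformly at random (not Ax)."""

    def __init__(self, parameter_names, bounds=(0.0, 1.0), seed=0):
        self._names = list(parameter_names)
        self._bounds = bounds
        self._rng = random.Random(seed)
        self._pending = {}    # trial_index -> parameters awaiting a result
        self._completed = {}  # trial_index -> (parameters, observed value)
        self._next_index = 0

    def get_next_trials(self, maximum_trials=1):
        """'Ask': return a dict mapping new trial indices to parameterizations."""
        low, high = self._bounds
        trials = {}
        for _ in range(maximum_trials):
            params = {name: self._rng.uniform(low, high) for name in self._names}
            trials[self._next_index] = params
            self._pending[self._next_index] = params
            self._next_index += 1
        return trials

    def complete_trial(self, trial_index, raw_data):
        """'Tell': record the observed value for a previously issued trial."""
        params = self._pending.pop(trial_index)
        self._completed[trial_index] = (params, raw_data["objective"])

    def best(self):
        """Return (trial_index, parameters, value) with the smallest observed value."""
        index = min(self._completed, key=lambda i: self._completed[i][1])
        params, value = self._completed[index]
        return index, params, value


def toy_objective(x1, x2):
    # Simple bowl with its minimum of 0 at (0.5, 0.5); stands in for any function.
    return (x1 - 0.5) ** 2 + (x2 - 0.5) ** 2


toy_client = RandomSearchClient(["x1", "x2"])
for _ in range(20):
    for index, params in toy_client.get_next_trials(maximum_trials=1).items():
        toy_client.complete_trial(index, {"objective": toy_objective(**params)})

best_index, best_params, best_value = toy_client.best()
print(best_index, best_params, best_value)
```

The loop has the same shape as the Ax loop above; what Ax adds is a model-based strategy for choosing which points to propose.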
Step 6: Analyze Results
After running trials, you can analyze the results. Most commonly this means extracting the parameterization from the best performing trial you conducted.
Hartmann6 has a known global minimum of approximately $f(x^*) = -3.322$ at $x^* \approx (0.202, 0.150, 0.477, 0.275, 0.312, 0.657)$. Ax is able to identify a point very near to this true optimum using just 30 evaluations. This is possible due to the sample efficiency of Bayesian optimization, the optimization method Ax uses under the hood.
best_parameters, prediction, index, name = client.get_best_parameterization()
print("Best Parameters:", best_parameters)
print("Prediction (mean, variance):", prediction)
Best Parameters: {'x1': 0.21188346451138274, 'x2': 0.16428117335904663, 'x3': 0.4341952205574083, 'x4': 0.25907492071763133, 'x5': 0.29839289904146926, 'x6': 0.6470518604430703}
Prediction (mean, variance): {'hartmann6': (np.float64(-3.2758915184784447), np.float64(0.00119121385765093))}
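One way to quantify "very near" is the Euclidean distance between the returned parameterization and the known minimizer. The sketch below copies the best parameters from the output above; the coordinates in x_star are the commonly cited minimizer from the literature and are an assumption, since they do not appear in the code of this tutorial.

```python
import math

# Best parameters copied from the output above.
best_parameters = {
    "x1": 0.21188346451138274,
    "x2": 0.16428117335904663,
    "x3": 0.4341952205574083,
    "x4": 0.25907492071763133,
    "x5": 0.29839289904146926,
    "x6": 0.6470518604430703,
}

# Commonly cited global minimizer of Hartmann6 (assumption from the literature).
x_star = (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)

distance = math.sqrt(
    sum((best_parameters[f"x{i + 1}"] - x_star[i]) ** 2 for i in range(6))
)
print(f"Distance to the known minimizer: {distance:.3f}")
```

For this run the distance comes out at roughly 0.05, which is small relative to the unit-length side of the search space.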
Step 7: Compute Analyses
Ax can also produce a number of analyses to help interpret the results of the experiment via client.compute_analyses. Users can manually select which analyses to run, or can allow Ax to select which would be most relevant. In this case Ax selects the following:
- Parallel Coordinates Plot shows which parameterizations were evaluated and what metric values were observed. This is useful for getting a high-level overview of how thoroughly the search space was explored and which regions tend to produce which outcomes
- Interaction Analysis Plot shows which parameters have the largest effect on the function and plots the most important parameters as 1- or 2-dimensional surfaces
- Summary lists all trials generated along with their parameterizations, observations, and miscellaneous metadata
# display=True instructs Ax to sort then render the resulting analyses
cards = client.compute_analyses(display=True)
Parallel Coordinates for hartmann6
The parallel coordinates plot displays multi-dimensional data by representing each parameter as a parallel axis. This plot helps in assessing how thoroughly the search space has been explored and in identifying patterns or clusterings associated with high-performing (good) or low-performing (bad) arms. By tracing lines across the axes, one can observe correlations and interactions between parameters, gaining insights into the relationships that contribute to the success or failure of different configurations within the experiment.
Summary for hartmann6_experiment
High-level summary of the Trials in this Experiment
| | trial_index | arm_name | trial_status | generation_node | hartmann6 | x1 | x2 | x3 | x4 | x5 | x6 |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0_0 | COMPLETED | Sobol | -0.501926 | 0.221352 | 0.505657 | 0.791436 | 0.888981 | 0.406419 | 0.956151 |
1 | 1 | 1_0 | COMPLETED | Sobol | -0.022449 | 0.704186 | 0.250006 | 0.379711 | 0.458323 | 0.961675 | 0.123115 |
2 | 2 | 2_0 | COMPLETED | Sobol | -0.145191 | 0.771154 | 0.755537 | 0.534144 | 0.589093 | 0.701101 | 0.255063 |
3 | 3 | 3_0 | COMPLETED | Sobol | -0.860469 | 0.287191 | 0.00037 | 0.137009 | 0.000216 | 0.176761 | 0.664214 |
4 | 4 | 4_0 | COMPLETED | Sobol | -0.044638 | 0.407901 | 0.976937 | 0.3023 | 0.195248 | 0.006709 | 0.822607 |
5 | 5 | 5_0 | COMPLETED | MBM | -1.53747 | 0.073155 | 0 | 0.25316 | 0.148463 | 0.20774 | 0.816667 |
6 | 6 | 6_0 | COMPLETED | MBM | -0.774629 | 0 | 0 | 0 | 0 | 0.371688 | 0.665134 |
7 | 7 | 7_0 | COMPLETED | MBM | -0.595814 | 0.196423 | 0 | 0.434822 | 0.166626 | 0.037395 | 0.885145 |
8 | 8 | 8_0 | COMPLETED | MBM | -0.918021 | 0.03533 | 0 | 0.405542 | 0.163929 | 0.234641 | 0.990631 |
9 | 9 | 9_0 | COMPLETED | MBM | -1.45712 | 0 | 0 | 0.218451 | 0.244235 | 0.313667 | 0.90227 |
10 | 10 | 10_0 | COMPLETED | MBM | -0.565202 | 0 | 0 | 0.240638 | 0 | 0.107605 | 0.607965 |
11 | 11 | 11_0 | COMPLETED | MBM | -1.90258 | 0 | 0 | 0.249127 | 0.236354 | 0.284169 | 0.819713 |
12 | 12 | 12_0 | COMPLETED | MBM | -2.40239 | 0.025407 | 0 | 0.336605 | 0.305614 | 0.29196 | 0.756491 |
13 | 13 | 13_0 | COMPLETED | MBM | -2.66784 | 0.119091 | 0 | 0.439416 | 0.379654 | 0.296443 | 0.695929 |
14 | 14 | 14_0 | COMPLETED | MBM | -1.7346 | 0 | 0 | 0.25926 | 0.45492 | 0.254297 | 0.652627 |
15 | 15 | 15_0 | COMPLETED | MBM | -2.20328 | 0.178457 | 0 | 0.505581 | 0.437471 | 0.369049 | 0.725332 |
16 | 16 | 16_0 | COMPLETED | MBM | -2.96618 | 0.257104 | 0 | 0.529468 | 0.302724 | 0.290807 | 0.647679 |
17 | 17 | 17_0 | COMPLETED | MBM | -2.20429 | 0 | 0 | 0.67244 | 0.317315 | 0.275811 | 0.636231 |
18 | 18 | 18_0 | COMPLETED | MBM | -2.4539 | 0.45743 | 0 | 0.43786 | 0.285296 | 0.308876 | 0.619494 |
19 | 19 | 19_0 | COMPLETED | MBM | -2.58793 | 0.240049 | 0 | 0.139718 | 0.307023 | 0.300896 | 0.655223 |
20 | 20 | 20_0 | COMPLETED | MBM | -3.03403 | 0.235173 | 0.299635 | 0.477795 | 0.299152 | 0.296647 | 0.676215 |
21 | 21 | 21_0 | COMPLETED | MBM | -0.574052 | 0.217013 | 0.86203 | 0.450361 | 0.284583 | 0.322264 | 0.614581 |
22 | 22 | 22_0 | COMPLETED | MBM | -3.25346 | 0.238937 | 0.161156 | 0.475236 | 0.299304 | 0.288184 | 0.678222 |
23 | 23 | 23_0 | COMPLETED | MBM | -2.84681 | 0.286002 | 0.17272 | 0.520515 | 0.332247 | 0.247339 | 0.725811 |
24 | 24 | 24_0 | COMPLETED | MBM | -3.21362 | 0.203698 | 0.155476 | 0.41977 | 0.273532 | 0.325993 | 0.606466 |
25 | 25 | 25_0 | COMPLETED | MBM | -3.21233 | 0.225211 | 0.15683 | 0.464645 | 0.253081 | 0.350249 | 0.679414 |
26 | 26 | 26_0 | COMPLETED | MBM | -3.28241 | 0.211883 | 0.164281 | 0.434195 | 0.259075 | 0.298393 | 0.647052 |
27 | 27 | 27_0 | COMPLETED | MBM | -3.16911 | 0.227873 | 0.18671 | 0.538931 | 0.231134 | 0.302479 | 0.621384 |
28 | 28 | 28_0 | COMPLETED | MBM | -3.29923 | 0.169304 | 0.174556 | 0.466938 | 0.270202 | 0.316357 | 0.657593 |
29 | 29 | 29_0 | COMPLETED | MBM | -3.17783 | 0.165182 | 0.209001 | 0.399035 | 0.246805 | 0.303844 | 0.662241 |
Sensitivity Analysis for hartmann6
Understand how each parameter affects hartmann6 according to a second-order sensitivity analysis.
x1 vs. hartmann6
The slice plot provides a one-dimensional view of predicted outcomes for hartmann6 as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.
x4 vs. hartmann6
The slice plot provides a one-dimensional view of predicted outcomes for hartmann6 as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.
x2, x4 vs. hartmann6
The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.
x5, x6 vs. hartmann6
The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.
x2, x5 vs. hartmann6
The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.
Cross Validation for hartmann6
The cross-validation plot displays the model fit for each metric in the experiment. It employs a leave-one-out approach, where the model is trained on all data except one sample, which is used for validation. The plot shows the predicted outcome for the validation set on the y-axis against its actual value on the x-axis. Points that align closely with the dotted diagonal line indicate a strong model fit, signifying accurate predictions. Additionally, the plot includes 95% confidence intervals that provide insight into the noise in observations and the uncertainty in model predictions. A horizontal, flat line of predictions indicates that the model has not picked up on sufficient signal in the data, and instead is just predicting the mean.
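The leave-one-out procedure described above can be sketched in a few lines. The example below is a simplified illustration and not how Ax implements it: Ax cross-validates its own surrogate model, whereas this sketch uses two toy predictors (a nearest-neighbor rule and a constant mean) on made-up one-dimensional data. All function names here are assumptions.

```python
def leave_one_out_predictions(xs, ys, fit_predict):
    """For each point, fit on all other points and predict the held-out one."""
    predictions = []
    for i in range(len(xs)):
        train_x = xs[:i] + xs[i + 1:]
        train_y = ys[:i] + ys[i + 1:]
        predictions.append(fit_predict(train_x, train_y, xs[i]))
    return predictions


def nearest_neighbor(train_x, train_y, query):
    # Predict the outcome of the closest training point.
    j = min(range(len(train_x)), key=lambda k: abs(train_x[k] - query))
    return train_y[j]


def mean_predictor(train_x, train_y, query):
    # Ignore the input entirely: this produces the "flat line" of predictions.
    return sum(train_y) / len(train_y)


def mean_squared_error(predicted, actual):
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)


# Toy data: y = x^2 on an evenly spaced grid.
xs = [i / 10 for i in range(11)]
ys = [x ** 2 for x in xs]

mse_nn = mean_squared_error(leave_one_out_predictions(xs, ys, nearest_neighbor), ys)
mse_mean = mean_squared_error(leave_one_out_predictions(xs, ys, mean_predictor), ys)
print(f"nearest neighbor MSE: {mse_nn:.4f}, mean predictor MSE: {mse_mean:.4f}")
```

The mean predictor corresponds to the flat line mentioned above: it has picked up no signal from the inputs, so its held-out error is larger than that of a model that uses them.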
Conclusion
This tutorial demonstrates how to use Ax's Client for ask-tell optimization of Python functions using the Hartmann6 function as an example. You can adjust the function and parameters to suit your specific optimization problem.