The Ax Service API is designed to allow the user to control scheduling of trials and data computation while having an easy-to-use interface with Ax.
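The user iteratively:
- queries Ax for candidate parameterizations (trials),
- schedules and evaluates them however they choose,
- reports the resulting data back to Ax,
- and repeats until the optimization is done.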
from ax.service.ax_client import AxClient
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()
[INFO 12-26 18:17:49] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
Create a client object to interface with Ax APIs. By default this runs locally without storage.
ax_client = AxClient()
[INFO 12-26 18:17:49] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
An experiment consists of a search space (parameters and parameter constraints) and an optimization configuration (objective name, minimization setting, and outcome constraints). Note that:
- The name, parameters, and objective_name arguments are required.
- Each entry in parameters must have the following keys: "name" (parameter name), "type" (parameter type: "range", "choice", or "fixed"), "bounds" for range parameters, "values" for choice parameters, and "value" for fixed parameters.
- Entries in parameters can optionally include "value_type" ("int", "float", "bool", or "str"), a "log_scale" flag for range parameters, and an "is_ordered" flag for choice parameters.
- parameter_constraints should be a list of strings of the form "p1 >= p2" or "p1 + p2 <= some_bound".
- outcome_constraints should be a list of strings of the form "constrained_metric <= some_bound".
ax_client.create_experiment(
name="hartmann_test_experiment",
parameters=[
{
"name": "x1",
"type": "range",
"bounds": [0.0, 1.0],
"value_type": "float", # Optional, defaults to inference from type of "bounds".
"log_scale": False, # Optional, defaults to False.
},
{
"name": "x2",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x3",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x4",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x5",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x6",
"type": "range",
"bounds": [0.0, 1.0],
},
],
objective_name="hartmann6",
minimize=True, # Optional, defaults to False.
parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
outcome_constraints=["l2norm <= 1.25"], # Optional.
)
[INFO 12-26 18:17:49] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 6 arms, GPEI for subsequent arms], generated 0 arm(s) so far). Iterations after 6 will take longer to generate due to model-fitting.
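The constraint strings above are plain linear inequalities over parameter (or metric) names. To illustrate what a string such as "x1 + x2 <= 2.0" expresses, here is a minimal stdlib sketch that parses only the two forms listed above and checks a dictionary of values against them. The helpers `parse_constraint` and `satisfies` are hypothetical and not part of Ax, which does its own parsing and validation.

```python
import operator

# Hypothetical helpers, not part of Ax. They handle only the forms listed
# above: "a + b + ... <= bound" and "p1 >= p2" (or "p1 <= p2").
_OPS = {"<=": operator.le, ">=": operator.ge}


def parse_constraint(text):
    """Split a constraint string into (left-hand names, comparison, right-hand side)."""
    for symbol, op in _OPS.items():
        if symbol in text:
            lhs, rhs = (side.strip() for side in text.split(symbol))
            names = [name.strip() for name in lhs.split("+")]
            return names, op, rhs
    raise ValueError(f"Unsupported constraint: {text!r}")


def satisfies(values, text):
    """Return True if `values` (a dict of name -> number) meets the constraint."""
    names, op, rhs = parse_constraint(text)
    total = sum(values[name] for name in names)
    # The right-hand side is either a numeric bound or another name.
    try:
        bound = float(rhs)
    except ValueError:
        bound = values[rhs]
    return op(total, bound)


print(satisfies({"x1": 0.4, "x2": 0.9}, "x1 + x2 <= 2.0"))  # True
print(satisfies({"x1": 0.4, "x2": 0.9}, "x1 >= x2"))  # False
```

The same helper reads an outcome constraint if it is given observed metric values instead, e.g. `satisfies({"l2norm": 1.19}, "l2norm <= 1.25")`.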
When using Ax as a service, evaluation of parameterizations suggested by Ax is done either locally or, more commonly, using an external scheduler. Below is a dummy evaluation function that outputs data for two metrics, "hartmann6" and "l2norm". Note that all returned metrics correspond to either the objective_name set on experiment creation or the metric names mentioned in outcome_constraints.
import numpy as np
def evaluate(parameters):
x = np.array([parameters.get(f"x{i+1}") for i in range(6)])
# In our case, standard error is 0, since we are computing a synthetic function.
return {"hartmann6": (hartmann6(x), 0.0), "l2norm": (np.sqrt((x ** 2).sum()), 0.0)}
The result of the evaluation should generally be a mapping of the format {metric_name -> (mean, SEM)}. If there is only one metric in the experiment (the objective), the evaluation function can return a single tuple of mean and SEM, in which case Ax will assume that the evaluation corresponds to the objective. It can also return only the mean as a float, in which case Ax will treat the SEM as unknown and use a model that can infer it.
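To make the three accepted return shapes concrete, the sketch below converts each of them into the full {metric_name -> (mean, SEM)} mapping. The function `normalize_raw_data` and the use of `None` to mark an unknown SEM are illustrative assumptions, not Ax internals.

```python
from numbers import Number


def normalize_raw_data(raw, objective_name):
    """Hypothetical helper: coerce an evaluation result to {metric: (mean, sem)}.

    In this sketch, a SEM of None stands for "unknown".
    """
    if isinstance(raw, dict):
        # Already a mapping of metric name -> result.
        return {name: _to_pair(value) for name, value in raw.items()}
    # A bare tuple or float is assumed to refer to the objective.
    return {objective_name: _to_pair(raw)}


def _to_pair(value):
    if isinstance(value, tuple):
        mean, sem = value
        return float(mean), sem
    if isinstance(value, Number):
        # Mean only: SEM is unknown.
        return float(value), None
    raise TypeError(f"Unsupported evaluation result: {value!r}")


print(normalize_raw_data({"hartmann6": (-0.1, 0.0), "l2norm": (1.25, 0.0)}, "hartmann6"))
print(normalize_raw_data((-0.1, 0.0), "hartmann6"))  # {'hartmann6': (-0.1, 0.0)}
print(normalize_raw_data(-0.1, "hartmann6"))  # {'hartmann6': (-0.1, None)}
```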
For more details on the evaluation function, refer to the "Trial Evaluation" section in the Ax docs at ax.dev.
With the experiment set up, we can start the optimization loop.
At each step, the user queries the client for a new trial then submits the evaluation of that trial back to the client.
Note that Ax auto-selects an appropriate optimization algorithm based on the search space. For more advanced use cases that require a specific optimization algorithm, pass a generation_strategy argument into the AxClient constructor. Note that when Bayesian Optimization is used, generating new trials may take a few minutes.
for i in range(25):
parameters, trial_index = ax_client.get_next_trial()
# Local evaluation here can be replaced with deployment to external system.
ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.89, 'x2': 0.5, 'x3': 0.56, 'x4': 0.35, 'x5': 0.18, 'x6': 0.25}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.1, 0.0), 'l2norm': (1.25, 0.0)}.
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.12, 'x2': 0.6, 'x3': 0.9, 'x4': 0.96, 'x5': 0.34, 'x6': 0.05}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-0.1, 0.0), 'l2norm': (1.49, 0.0)}.
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.5, 'x2': 0.02, 'x3': 0.91, 'x4': 0.52, 'x5': 0.81, 'x6': 0.06}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.01, 0.0), 'l2norm': (1.42, 0.0)}.
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.71, 'x2': 0.3, 'x3': 0.81, 'x4': 0.36, 'x5': 0.22, 'x6': 0.04}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.06, 0.0), 'l2norm': (1.2, 0.0)}.
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.76, 'x2': 0.78, 'x3': 0.13, 'x4': 0.38, 'x5': 0.57, 'x6': 0.07}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.23, 0.0), 'l2norm': (1.29, 0.0)}.
[INFO 12-26 18:17:49] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.68, 'x2': 0.23, 'x3': 0.6, 'x4': 0.37, 'x5': 0.94, 'x6': 0.13}.
[INFO 12-26 18:17:49] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-0.02, 0.0), 'l2norm': (1.39, 0.0)}.
[INFO 12-26 18:17:54] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.78, 'x2': 0.72, 'x3': 0.23, 'x4': 0.36, 'x5': 0.4, 'x6': 0.07}.
[INFO 12-26 18:17:54] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.16, 0.0), 'l2norm': (1.22, 0.0)}.
[INFO 12-26 18:17:59] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.77, 'x2': 0.74, 'x3': 0.19, 'x4': 0.37, 'x5': 0.46, 'x6': 0.07}.
[INFO 12-26 18:17:59] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-0.18, 0.0), 'l2norm': (1.24, 0.0)}.
[INFO 12-26 18:18:04] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.84, 'x2': 0.83, 'x3': 0.18, 'x4': 0.31, 'x5': 0.52, 'x6': 0.01}.
[INFO 12-26 18:18:04] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.06, 0.0), 'l2norm': (1.34, 0.0)}.
[INFO 12-26 18:18:09] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.74, 'x2': 0.74, 'x3': 0.16, 'x4': 0.42, 'x5': 0.47, 'x6': 0.09}.
[INFO 12-26 18:18:09] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.3, 0.0), 'l2norm': (1.23, 0.0)}.
[INFO 12-26 18:18:14] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.72, 'x2': 0.74, 'x3': 0.13, 'x4': 0.45, 'x5': 0.47, 'x6': 0.1}.
[INFO 12-26 18:18:14] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.4, 0.0), 'l2norm': (1.23, 0.0)}.
[INFO 12-26 18:18:19] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.7, 'x2': 0.73, 'x3': 0.07, 'x4': 0.52, 'x5': 0.45, 'x6': 0.12}.
[INFO 12-26 18:18:19] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.52, 0.0), 'l2norm': (1.23, 0.0)}.
[INFO 12-26 18:18:23] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 0.68, 'x2': 0.71, 'x3': 0.02, 'x4': 0.58, 'x5': 0.44, 'x6': 0.13}.
[INFO 12-26 18:18:23] ax.service.ax_client: Completed trial 12 with data: {'hartmann6': (-0.62, 0.0), 'l2norm': (1.23, 0.0)}.
[INFO 12-26 18:18:28] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 0.66, 'x2': 0.69, 'x3': 0.0, 'x4': 0.66, 'x5': 0.42, 'x6': 0.13}.
[INFO 12-26 18:18:28] ax.service.ax_client: Completed trial 13 with data: {'hartmann6': (-0.63, 0.0), 'l2norm': (1.24, 0.0)}.
[INFO 12-26 18:18:34] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.62, 'x2': 0.69, 'x3': 0.0, 'x4': 0.62, 'x5': 0.46, 'x6': 0.12}.
[INFO 12-26 18:18:34] ax.service.ax_client: Completed trial 14 with data: {'hartmann6': (-0.96, 0.0), 'l2norm': (1.21, 0.0)}.
[INFO 12-26 18:18:40] ax.service.ax_client: Generated new trial 15 with parameters {'x1': 0.56, 'x2': 0.71, 'x3': 0.0, 'x4': 0.54, 'x5': 0.51, 'x6': 0.11}.
[INFO 12-26 18:18:40] ax.service.ax_client: Completed trial 15 with data: {'hartmann6': (-1.47, 0.0), 'l2norm': (1.17, 0.0)}.
[INFO 12-26 18:18:45] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 0.48, 'x2': 0.74, 'x3': 0.0, 'x4': 0.45, 'x5': 0.55, 'x6': 0.11}.
[INFO 12-26 18:18:45] ax.service.ax_client: Completed trial 16 with data: {'hartmann6': (-1.85, 0.0), 'l2norm': (1.14, 0.0)}.
[INFO 12-26 18:18:49] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 0.43, 'x2': 0.79, 'x3': 0.0, 'x4': 0.43, 'x5': 0.57, 'x6': 0.17}.
[INFO 12-26 18:18:49] ax.service.ax_client: Completed trial 17 with data: {'hartmann6': (-1.81, 0.0), 'l2norm': (1.16, 0.0)}.
[INFO 12-26 18:18:53] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 0.44, 'x2': 0.74, 'x3': 0.0, 'x4': 0.39, 'x5': 0.54, 'x6': 0.03}.
[INFO 12-26 18:18:53] ax.service.ax_client: Completed trial 18 with data: {'hartmann6': (-1.8, 0.0), 'l2norm': (1.08, 0.0)}.
[INFO 12-26 18:18:57] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 0.46, 'x2': 0.69, 'x3': 0.0, 'x4': 0.34, 'x5': 0.58, 'x6': 0.15}.
[INFO 12-26 18:18:57] ax.service.ax_client: Completed trial 19 with data: {'hartmann6': (-1.06, 0.0), 'l2norm': (1.08, 0.0)}.
[INFO 12-26 18:19:01] ax.service.ax_client: Generated new trial 20 with parameters {'x1': 0.44, 'x2': 0.79, 'x3': 0.0, 'x4': 0.53, 'x5': 0.53, 'x6': 0.05}.
[INFO 12-26 18:19:01] ax.service.ax_client: Completed trial 20 with data: {'hartmann6': (-2.74, 0.0), 'l2norm': (1.18, 0.0)}.
[INFO 12-26 18:19:07] ax.service.ax_client: Generated new trial 21 with parameters {'x1': 0.4, 'x2': 0.84, 'x3': 0.0, 'x4': 0.6, 'x5': 0.53, 'x6': 0.01}.
[INFO 12-26 18:19:07] ax.service.ax_client: Completed trial 21 with data: {'hartmann6': (-2.95, 0.0), 'l2norm': (1.23, 0.0)}.
[INFO 12-26 18:19:12] ax.service.ax_client: Generated new trial 22 with parameters {'x1': 0.41, 'x2': 0.86, 'x3': 0.0, 'x4': 0.58, 'x5': 0.41, 'x6': 0.02}.
[INFO 12-26 18:19:12] ax.service.ax_client: Completed trial 22 with data: {'hartmann6': (-3.03, 0.0), 'l2norm': (1.19, 0.0)}.
[INFO 12-26 18:19:15] ax.service.ax_client: Generated new trial 23 with parameters {'x1': 0.43, 'x2': 0.91, 'x3': 0.0, 'x4': 0.57, 'x5': 0.44, 'x6': 0.0}.
[INFO 12-26 18:19:15] ax.service.ax_client: Completed trial 23 with data: {'hartmann6': (-2.94, 0.0), 'l2norm': (1.24, 0.0)}.
[INFO 12-26 18:19:16] ax.service.ax_client: Generated new trial 24 with parameters {'x1': 0.7, 'x2': 0.29, 'x3': 0.51, 'x4': 0.15, 'x5': 0.31, 'x6': 0.34}.
[INFO 12-26 18:19:16] ax.service.ax_client: Completed trial 24 with data: {'hartmann6': (-0.5, 0.0), 'l2norm': (1.03, 0.0)}.
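Because the loop only relies on get_next_trial and complete_trial, the evaluation step can live anywhere. To illustrate that ask/tell contract without Ax installed, here is a self-contained stand-in. `RandomSearchClient` and `toy_evaluate` are hypothetical: the client proposes uniform random points instead of using the Sobol+GPEI strategy, and only the shape of the loop mirrors the code above.

```python
import random


class RandomSearchClient:
    """Hypothetical stand-in for an ask/tell client (random search only)."""

    def __init__(self, parameter_names, seed=0):
        self._names = parameter_names
        self._rng = random.Random(seed)
        self.trials = {}  # trial_index -> {"parameters": ..., "data": ...}

    def get_next_trial(self):
        # "Ask": propose a point in the unit hypercube and register it.
        parameters = {name: self._rng.uniform(0.0, 1.0) for name in self._names}
        trial_index = len(self.trials)
        self.trials[trial_index] = {"parameters": parameters, "data": None}
        return parameters, trial_index

    def complete_trial(self, trial_index, raw_data):
        # "Tell": attach the evaluation result to the trial.
        self.trials[trial_index]["data"] = raw_data


def toy_evaluate(parameters):
    # Toy objective (sum of squares) in the {metric: (mean, sem)} format.
    value = sum(v ** 2 for v in parameters.values())
    return {"objective": (value, 0.0)}


client = RandomSearchClient([f"x{i + 1}" for i in range(6)])
for _ in range(25):
    parameters, trial_index = client.get_next_trial()
    # In practice this step could be dispatched to an external scheduler.
    client.complete_trial(trial_index=trial_index, raw_data=toy_evaluate(parameters))

best = min(client.trials.values(), key=lambda t: t["data"]["objective"][0])
print(best["data"])
```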
To view all trials in a data frame at any point during optimization:
ax_client.get_trials_data_frame().sort_values('trial_index')
| | arm_name | hartmann6 | l2norm | trial_index | x1 | x2 | x3 | x4 | x5 | x6 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0_0 | -0.100419 | 1.24753 | 0 | 0.887652 | 0.495854 | 5.555373e-01 | 0.345750 | 0.175586 | 0.252082 |
| 3 | 1_0 | -0.0953492 | 1.49125 | 1 | 0.115817 | 0.602286 | 9.023022e-01 | 0.957067 | 0.339319 | 0.048896 |
| 15 | 2_0 | -0.0080322 | 1.41665 | 2 | 0.503165 | 0.017719 | 9.111743e-01 | 0.518051 | 0.806991 | 0.059571 |
| 18 | 3_0 | -0.0619467 | 1.19852 | 3 | 0.712777 | 0.295396 | 8.126358e-01 | 0.363529 | 0.216660 | 0.040941 |
| 19 | 4_0 | -0.233402 | 1.29349 | 4 | 0.758820 | 0.779445 | 1.318030e-01 | 0.384193 | 0.565871 | 0.067813 |
| 20 | 5_0 | -0.0208241 | 1.38577 | 5 | 0.682562 | 0.230131 | 5.993323e-01 | 0.367103 | 0.944028 | 0.127885 |
| 21 | 6_0 | -0.159645 | 1.2158 | 6 | 0.780880 | 0.718473 | 2.313192e-01 | 0.361868 | 0.403791 | 0.068411 |
| 22 | 7_0 | -0.182643 | 1.23813 | 7 | 0.773039 | 0.743114 | 1.873161e-01 | 0.368272 | 0.455619 | 0.069631 |
| 23 | 8_0 | -0.0609675 | 1.33768 | 8 | 0.839785 | 0.826986 | 1.770778e-01 | 0.306892 | 0.523914 | 0.014429 |
| 24 | 9_0 | -0.303642 | 1.23193 | 9 | 0.740044 | 0.736663 | 1.632148e-01 | 0.416059 | 0.469310 | 0.085548 |
| 1 | 10_0 | -0.395784 | 1.23348 | 10 | 0.722736 | 0.737833 | 1.344941e-01 | 0.450890 | 0.473603 | 0.095084 |
| 2 | 11_0 | -0.519869 | 1.22843 | 11 | 0.700737 | 0.726829 | 7.330071e-02 | 0.515190 | 0.453391 | 0.115631 |
| 4 | 12_0 | -0.622343 | 1.23082 | 12 | 0.677505 | 0.710436 | 1.771767e-02 | 0.584862 | 0.439295 | 0.125830 |
| 5 | 13_0 | -0.633734 | 1.24245 | 13 | 0.661389 | 0.694065 | 1.095771e-15 | 0.657023 | 0.419963 | 0.128307 |
| 6 | 14_0 | -0.959757 | 1.20988 | 14 | 0.617318 | 0.692469 | 4.966538e-15 | 0.616368 | 0.457510 | 0.118292 |
| 7 | 15_0 | -1.47252 | 1.17271 | 15 | 0.557028 | 0.707680 | 5.115012e-16 | 0.537624 | 0.512310 | 0.112530 |
| 8 | 16_0 | -1.84599 | 1.13893 | 16 | 0.482599 | 0.739137 | 1.764231e-06 | 0.449377 | 0.550937 | 0.111671 |
| 9 | 17_0 | -1.80516 | 1.16089 | 17 | 0.425241 | 0.786617 | 0.000000e+00 | 0.434792 | 0.574326 | 0.170798 |
| 10 | 18_0 | -1.8014 | 1.08415 | 18 | 0.435372 | 0.738600 | 8.933797e-14 | 0.389759 | 0.536066 | 0.032109 |
| 11 | 19_0 | -1.06182 | 1.07718 | 19 | 0.460217 | 0.688024 | 5.389991e-16 | 0.344796 | 0.576953 | 0.152906 |
| 12 | 20_0 | -2.7407 | 1.17841 | 20 | 0.437280 | 0.793417 | 2.384596e-16 | 0.531873 | 0.531882 | 0.046214 |
| 13 | 21_0 | -2.94922 | 1.23056 | 21 | 0.397572 | 0.844773 | 1.071509e-05 | 0.598931 | 0.532572 | 0.014508 |
| 14 | 22_0 | -3.02696 | 1.19036 | 22 | 0.406847 | 0.861775 | 7.754041e-18 | 0.581838 | 0.412194 | 0.018122 |
| 16 | 23_0 | -2.9374 | 1.24071 | 23 | 0.429691 | 0.914801 | 0.000000e+00 | 0.570761 | 0.438299 | 0.001608 |
| 17 | 24_0 | -0.501497 | 1.03135 | 24 | 0.701115 | 0.287082 | 5.055002e-01 | 0.149397 | 0.313778 | 0.336751 |
Once it's complete, we can access the best parameters found, as well as the corresponding metric values.
best_parameters, values = ax_client.get_best_parameters()
best_parameters
{'x1': 0.40684737996405934, 'x2': 0.8617753067377881, 'x3': 7.754041030229636e-18, 'x4': 0.581837661504816, 'x5': 0.41219421506871223, 'x6': 0.01812220975403133}
means, covariances = values
means
{'hartmann6': -3.026962199619188, 'l2norm': 1.1903567380968083}
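The reported best point matches trial 22 in the data frame above. One way to see why: among trials whose observed l2norm satisfies the outcome constraint l2norm <= 1.25, trial 22 has the lowest observed hartmann6. The sketch below reproduces that selection from the (hartmann6, l2norm) columns of the table. This is a simplification: get_best_parameters may rely on model predictions rather than raw observations, so its choice need not always coincide with this raw-data rule.

```python
# (hartmann6, l2norm) per trial index 0-24, copied from the data frame above.
observations = [
    (-0.100419, 1.24753), (-0.0953492, 1.49125), (-0.0080322, 1.41665),
    (-0.0619467, 1.19852), (-0.233402, 1.29349), (-0.0208241, 1.38577),
    (-0.159645, 1.2158), (-0.182643, 1.23813), (-0.0609675, 1.33768),
    (-0.303642, 1.23193), (-0.395784, 1.23348), (-0.519869, 1.22843),
    (-0.622343, 1.23082), (-0.633734, 1.24245), (-0.959757, 1.20988),
    (-1.47252, 1.17271), (-1.84599, 1.13893), (-1.80516, 1.16089),
    (-1.8014, 1.08415), (-1.06182, 1.07718), (-2.7407, 1.17841),
    (-2.94922, 1.23056), (-3.02696, 1.19036), (-2.9374, 1.24071),
    (-0.501497, 1.03135),
]

L2NORM_BOUND = 1.25  # from outcome_constraints=["l2norm <= 1.25"]

# Keep trials whose observed l2norm satisfies the outcome constraint, then
# pick the smallest hartmann6, since the objective is minimized.
feasible = [
    (trial_index, objective)
    for trial_index, (objective, l2norm) in enumerate(observations)
    if l2norm <= L2NORM_BOUND
]
best_trial, best_objective = min(feasible, key=lambda pair: pair[1])
print(best_trial, best_objective)  # 22 -3.02696
```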
For comparison, the known global minimum of Hartmann6:
hartmann6.fmin
-3.32237
With the default arguments, the contour plot uses "x1" and "x2" as the two parameters to plot and the objective, "hartmann6", as the metric (as the log message below confirms).
render(ax_client.get_contour_plot())
[INFO 12-26 18:19:16] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'hartmann6'. Ramaining parameters are affixed to the middle of their range.
We can also retrieve a contour plot for the other metric, "l2norm"; say we are interested in seeing the response surface for parameters "x3" and "x4" for this one.
render(ax_client.get_contour_plot(param_x="x3", param_y="x4", metric_name="l2norm"))
[INFO 12-26 18:19:18] ax.service.ax_client: Retrieving contour plot with parameter 'x3' on X-axis and 'x4' on Y-axis, for metric 'l2norm'. Ramaining parameters are affixed to the middle of their range.
Here we plot the optimization trace, showing the progression of finding the point with the optimal objective:
render(ax_client.get_optimization_trace(objective_optimum=hartmann6.fmin))  # objective_optimum is optional.
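The optimization trace tracks the best objective value observed so far as trials accumulate. Assuming that reading, a running minimum over the hartmann6 column of the data frame (in trial order) reproduces the trace's values, and the last entry gives the remaining gap to hartmann6.fmin. Ax's plot may differ in details, for example in how it treats trials that violate outcome constraints.

```python
from itertools import accumulate

# hartmann6 values in trial order (trials 0-24), copied from the data frame above.
hartmann6_values = [
    -0.100419, -0.0953492, -0.0080322, -0.0619467, -0.233402,
    -0.0208241, -0.159645, -0.182643, -0.0609675, -0.303642,
    -0.395784, -0.519869, -0.622343, -0.633734, -0.959757,
    -1.47252, -1.84599, -1.80516, -1.8014, -1.06182,
    -2.7407, -2.94922, -3.02696, -2.9374, -0.501497,
]
FMIN = -3.32237  # hartmann6.fmin, as printed above

# Running minimum: best objective seen after each trial (we are minimizing).
trace = list(accumulate(hartmann6_values, min))
gap = trace[-1] - FMIN

print(trace[-1])  # -3.02696
print(round(gap, 5))  # 0.29541
```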
We can serialize the state of optimization to JSON and save it to a .json file, or save it to the SQL backend. For the former:
ax_client.save_to_json_file() # For custom filepath, pass `filepath` argument.
[INFO 12-26 18:19:20] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.
restored_ax_client = AxClient.load_from_json_file() # For custom filepath, pass `filepath` argument.
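Conceptually, the JSON snapshot is a round trip: the state needed to resume is written to a file and read back into a new client. The toy sketch below illustrates only that idea with the standard json module; the snapshot layout (`experiment_name`, `trials`) is a made-up assumption and does not reflect Ax's actual JSON format.

```python
import json
import os
import tempfile

# Hypothetical, heavily simplified optimization state (not Ax's real format).
state = {
    "experiment_name": "hartmann_test_experiment",
    "trials": {
        "0": {"parameters": {"x1": 0.89, "x2": 0.5}, "data": {"hartmann6": [-0.1, 0.0]}},
    },
}

path = os.path.join(tempfile.mkdtemp(), "snapshot.json")
with open(path, "w") as f:
    json.dump(state, f)

with open(path) as f:
    restored = json.load(f)

# JSON object keys are always strings and tuples come back as lists, which is
# why this sketch stores trial indices as "0" and (mean, SEM) as a list.
print(restored == state)  # True
```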
[INFO 12-26 18:19:21] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
To store the state of optimization in an SQL backend, first follow the setup instructions on the Ax website.
Having set up the SQL backend, pass DBSettings to AxClient on instantiation (note that the SQLAlchemy dependency will have to be installed; for installation, refer to the optional dependencies on the Ax website):
from ax.storage.sqa_store.structs import DBSettings
# URL is of the form "dialect+driver://username:password@host:port/database".
db_settings = DBSettings(url="postgresql+psycopg2://sarah:c82i94d@localhost:5432/foobar")
# Instead of URL, can provide a `creator function`; can specify custom encoders/decoders if necessary.
new_ax = AxClient(db_settings=db_settings)
[INFO 12-26 18:19:21] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
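The URL in the DBSettings example follows the "dialect+driver://username:password@host:port/database" form given in the comment. To see how the example URL maps onto those parts, the standard library's urllib.parse can split it. This is purely illustrative: SQLAlchemy has its own URL parser, and nothing here is required by Ax.

```python
from urllib.parse import urlsplit

url = "postgresql+psycopg2://sarah:c82i94d@localhost:5432/foobar"
parts = urlsplit(url)

# The scheme carries "dialect+driver".
dialect, driver = parts.scheme.split("+")
print(dialect, driver)  # postgresql psycopg2
print(parts.username, parts.password)  # sarah c82i94d
print(parts.hostname, parts.port)  # localhost 5432
print(parts.path.lstrip("/"))  # foobar
```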
When valid DBSettings are passed into AxClient, a unique experiment name (name) is a required argument to ax_client.create_experiment. The state of the optimization is auto-saved any time it changes (e.g., when a new trial is added or completed).
To reload an optimization state later, instantiate AxClient with the same DBSettings and use ax_client.load_experiment_from_database(experiment_name="my_experiment").
Evaluation failure: should any optimization iterations fail during evaluation, log_trial_failure will ensure that the same trial is not proposed again.
_, trial_index = ax_client.get_next_trial()
ax_client.log_trial_failure(trial_index=trial_index)
[INFO 12-26 18:19:23] ax.service.ax_client: Generated new trial 25 with parameters {'x1': 0.39, 'x2': 0.87, 'x3': 0.11, 'x4': 0.58, 'x5': 0.4, 'x6': 0.02}.
[INFO 12-26 18:19:23] ax.service.ax_client: Registered failure of trial 25.
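In a real loop, log_trial_failure would typically be called when evaluation raises. A minimal sketch of that pattern follows. `StubClient` is a hypothetical stand-in that only records outcomes and mirrors the get_next_trial / complete_trial / log_trial_failure method names used in this tutorial; `flaky_evaluate` is a made-up evaluation that fails on purpose.

```python
class StubClient:
    """Hypothetical stand-in that records outcomes instead of optimizing."""

    def __init__(self, candidates):
        self._candidates = list(candidates)
        self.completed = {}
        self.failed = set()

    def get_next_trial(self):
        trial_index = len(self.completed) + len(self.failed)
        return self._candidates[trial_index], trial_index

    def complete_trial(self, trial_index, raw_data):
        self.completed[trial_index] = raw_data

    def log_trial_failure(self, trial_index):
        self.failed.add(trial_index)


def flaky_evaluate(parameters):
    # Made-up evaluation: pretend the external system crashes for x1 > 0.5.
    if parameters["x1"] > 0.5:
        raise RuntimeError("evaluation crashed")
    return {"objective": (parameters["x1"] ** 2, 0.0)}


client = StubClient([{"x1": 0.2}, {"x1": 0.9}, {"x1": 0.4}])
for _ in range(3):
    parameters, trial_index = client.get_next_trial()
    try:
        raw_data = flaky_evaluate(parameters)
    except RuntimeError:
        # Mark the trial as failed instead of leaving it without data.
        client.log_trial_failure(trial_index=trial_index)
    else:
        client.complete_trial(trial_index=trial_index, raw_data=raw_data)

print(sorted(client.completed), sorted(client.failed))  # [0, 2] [1]
```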
Adding custom trials: should there be a need to evaluate a specific parameterization, attach_trial will add it to the experiment.
ax_client.attach_trial(parameters={"x1": 9.0, "x2": 9.0, "x3": 9.0, "x4": 9.0, "x5": 9.0, "x6": 9.0})
[INFO 12-26 18:19:23] ax.service.ax_client: Attached custom parameterization {'x1': 9.0, 'x2': 9.0, 'x3': 9.0, 'x4': 9.0, 'x5': 9.0, 'x6': 9.0} as trial 26.
({'x1': 9.0, 'x2': 9.0, 'x3': 9.0, 'x4': 9.0, 'x5': 9.0, 'x6': 9.0}, 26)
Running many trials in parallel: for optimal results and optimization efficiency, we strongly recommend sequential optimization (generating a few trials, then waiting for them to be completed with evaluation data). However, if your use case needs to dispatch many trials in parallel before they are updated with data, and you are running into the "All trials for current model have been generated, but not enough data has been observed to fit next model" error, instantiate AxClient as AxClient(enforce_sequential_optimization=False).
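For the parallel case, the general pattern is to request a batch of trials up front, evaluate them concurrently, and report each result as it arrives. The sketch below shows that pattern with concurrent.futures. `BatchStubClient` and `toy_evaluate` are hypothetical stand-ins that hand out random points; with the real AxClient, a large batch like this is exactly the situation where enforce_sequential_optimization=False may be needed, as described above.

```python
import random
from concurrent.futures import ThreadPoolExecutor, as_completed


class BatchStubClient:
    """Hypothetical stand-in for an ask/tell client; proposes random points."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self.pending = {}
        self.completed = {}

    def get_next_trial(self):
        trial_index = len(self.pending) + len(self.completed)
        parameters = {"x1": self._rng.random(), "x2": self._rng.random()}
        self.pending[trial_index] = parameters
        return parameters, trial_index

    def complete_trial(self, trial_index, raw_data):
        self.pending.pop(trial_index)
        self.completed[trial_index] = raw_data


def toy_evaluate(parameters):
    return {"objective": (parameters["x1"] + parameters["x2"], 0.0)}


client = BatchStubClient()
# Ask for a whole batch before any data comes back.
batch = [client.get_next_trial() for _ in range(8)]

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = {pool.submit(toy_evaluate, params): index for params, index in batch}
    for future in as_completed(futures):
        # Results arrive in completion order, not trial order; reporting
        # happens here in the main thread, so the client needs no locking.
        client.complete_trial(trial_index=futures[future], raw_data=future.result())

print(len(client.completed), len(client.pending))  # 8 0
```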