The Developer API is suitable when the user wants maximal customization of the optimization loop. This tutorial demonstrates optimization of the Hartmann6 function using the Experiment construct. In this example, trials will be evaluated synchronously.
from ax import (
ComparisonOp,
ParameterType,
RangeParameter,
ChoiceParameter,
FixedParameter,
SearchSpace,
Experiment,
OutcomeConstraint,
OrderConstraint,
SumConstraint,
OptimizationConfig,
Objective,
Metric,
)
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()
First, we define a search space, which specifies the type and allowed range of each parameter.
hartmann_search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name=f"x{i}", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
        for i in range(6)
    ]
)
Note that there are two other parameter classes, FixedParameter and ChoiceParameter. Although we won't use these in this example, you can create them as follows.
# Setting is_ordered and sort_values explicitly avoids the UserWarnings Ax emits otherwise.
choice_param = ChoiceParameter(name="choice", values=["foo", "bar"], parameter_type=ParameterType.STRING, is_ordered=False, sort_values=False)
# A FixedParameter takes a single value, not a list.
fixed_param = FixedParameter(name="fixed", value=True, parameter_type=ParameterType.BOOL)
Sum constraints enforce that the sum of a set of parameters is at most (or at least) some bound, and order constraints enforce that one parameter is less than or equal to another. We won't use these either, but here are two examples.
# x0 + x1 <= 5.0 (is_upper_bound=True makes the bound an upper bound)
sum_constraint = SumConstraint(
    parameters=[hartmann_search_space.parameters["x0"], hartmann_search_space.parameters["x1"]],
    is_upper_bound=True,
    bound=5.0,
)
# x0 <= x1
order_constraint = OrderConstraint(
    lower_parameter=hartmann_search_space.parameters["x0"],
    upper_parameter=hartmann_search_space.parameters["x1"],
)
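To make the semantics of these constraints concrete, here is a dependency-free sketch of the checks they express. It assumes non-strict inequalities: x0 + x1 <= 5.0 for the sum constraint and x0 <= x1 for the order constraint. The helper names in_ranges, satisfies_sum and satisfies_order are hypothetical and are not part of Ax. Because x0 and x1 both lie in [0, 1], their sum is at most 2, so the example bound of 5.0 can never actually be violated.

```python
# Hypothetical helpers illustrating the constraint semantics; not Ax API.
BOUNDS = {f"x{i}": (0.0, 1.0) for i in range(6)}  # same ranges as hartmann_search_space


def in_ranges(params: dict) -> bool:
    """True if every parameter lies inside its [lower, upper] range."""
    return all(lo <= params[name] <= hi for name, (lo, hi) in BOUNDS.items())


def satisfies_sum(params: dict, names: list, bound: float, is_upper_bound: bool = True) -> bool:
    """Sum constraint: sum of the named parameters <= bound (>= bound if not an upper bound)."""
    total = sum(params[name] for name in names)
    return total <= bound if is_upper_bound else total >= bound


def satisfies_order(params: dict, lower: str, upper: str) -> bool:
    """Order constraint: params[lower] <= params[upper]."""
    return params[lower] <= params[upper]


point = {"x0": 0.2, "x1": 0.7, "x2": 0.5, "x3": 0.1, "x4": 0.9, "x5": 0.3}
print(in_ranges(point))                               # True
print(satisfies_sum(point, ["x0", "x1"], bound=5.0))  # True
print(satisfies_order(point, "x0", "x1"))             # True
```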
Second, we define the optimization_config with an objective and outcome_constraints.
When doing the optimization, we will find points that minimize the objective while obeying the constraints (which in this case means l2norm <= 1.25).
Note: we are using Hartmann6Metric and L2NormMetric here, which have built-in evaluation functions for testing. For creating your own custom metrics, see 8. Defining custom metrics.
from ax.metrics.l2norm import L2NormMetric
from ax.metrics.hartmann6 import Hartmann6Metric
param_names = [f"x{i}" for i in range(6)]
optimization_config = OptimizationConfig(
    objective=Objective(
        metric=Hartmann6Metric(name="hartmann6", param_names=param_names),
        minimize=True,
    ),
    outcome_constraints=[
        OutcomeConstraint(
            metric=L2NormMetric(
                name="l2norm", param_names=param_names, noise_sd=0.2
            ),
            op=ComparisonOp.LEQ,
            bound=1.25,
            relative=False,
        )
    ],
)
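To build intuition for what the two metrics compute, here is a plain-Python sketch of the standard six-dimensional Hartmann function and the Euclidean (L2) norm. We assume that Hartmann6Metric evaluates this standard definition and that L2NormMetric adds zero-mean Gaussian noise with standard deviation noise_sd. The function names are our own. The standard global minimum is about -3.32237 at the point x_star below, which is the same optimum used in the plotting step later in this tutorial. That point also satisfies the l2norm <= 1.25 constraint.

```python
import math
import random

# Standard Hartmann6 constants: 4 terms, each with weight ALPHA[i],
# curvature row A[i] and centre P[i] in [0, 1]^6.
ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
]
P = [
    [1e-4 * v for v in row]
    for row in [
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381],
    ]
]


def hartmann6(x):
    """Noise-free Hartmann6 value at a point x in [0, 1]^6 (lower is better)."""
    return -sum(
        ALPHA[i] * math.exp(-sum(A[i][j] * (x[j] - P[i][j]) ** 2 for j in range(6)))
        for i in range(4)
    )


def l2norm(x, noise_sd=0.0, rng=random):
    """Euclidean norm of x, plus Gaussian noise if noise_sd > 0 (assumed noise model)."""
    value = math.sqrt(sum(v * v for v in x))
    return value + rng.gauss(0.0, noise_sd) if noise_sd > 0 else value


x_star = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(hartmann6(x_star))  # approximately -3.32237
print(l2norm(x_star))     # about 0.946, inside the 1.25 bound
```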
Before an experiment can collect data, it must have a Runner attached. A runner handles the deployment of trials. A trial must be "run" before it can be evaluated.
Here, we have a dummy runner that does nothing. In practice, a runner might be in charge of pushing an experiment to production.
The only method that needs to be defined for runner subclasses is run, which performs any necessary deployment logic and returns a dictionary of resulting metadata. This metadata can later be accessed through the trial's run_metadata property.
from ax import Runner

class MyRunner(Runner):
    def run(self, trial):
        # A real runner would deploy the trial here; this dummy only records the trial index.
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata
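The contract described above can be illustrated with toy stand-in classes. A runner's run returns a metadata dictionary, the trial keeps that dictionary as run_metadata, and the trial moves from being run to being completed. ToyTrial, ToyRunner and Status are hypothetical and only mimic the behaviour this tutorial describes. In particular, the rule that a trial must be running before it can be completed is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from enum import Enum


class Status(Enum):
    CANDIDATE = "CANDIDATE"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


class ToyRunner:
    def run(self, trial):
        # Deployment logic would go here; return metadata about the deployment.
        return {"name": str(trial.index)}


@dataclass
class ToyTrial:
    index: int
    status: Status = Status.CANDIDATE
    run_metadata: dict = field(default_factory=dict)

    def run(self, runner):
        # Delegate deployment to the runner and keep whatever metadata it returns.
        self.run_metadata = runner.run(self)
        self.status = Status.RUNNING
        return self

    def mark_completed(self):
        # In this sketch, only a running trial may be completed.
        if self.status is not Status.RUNNING:
            raise ValueError("trial must be run before it can be completed")
        self.status = Status.COMPLETED
        return self


toy_trial = ToyTrial(index=0).run(ToyRunner()).mark_completed()
print(toy_trial.status, toy_trial.run_metadata)  # Status.COMPLETED {'name': '0'}
```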
Next, we make an Experiment with our search space, runner, and optimization config.
exp = Experiment(
name="test_hartmann",
search_space=hartmann_search_space,
optimization_config=optimization_config,
runner=MyRunner(),
)
Run the optimization using the settings defined on the experiment. We will create 5 quasi-random Sobol points for exploration, followed by 15 points generated using the GP+EI optimizer (Models.BOTORCH).
Instead of a member of the Models enum, users can leverage a GenerationStrategy to produce generator runs. See the Generation Strategy Tutorial for more info.
from ax.modelbridge.registry import Models

NUM_SOBOL_TRIALS = 5
NUM_BOTORCH_TRIALS = 15

print("Running Sobol initialization trials...")
sobol = Models.SOBOL(search_space=exp.search_space)

for _ in range(NUM_SOBOL_TRIALS):
    # Produce a GeneratorRun from the model, which contains proposed arm(s) and other metadata
    generator_run = sobol.gen(n=1)
    # Add generator run to a trial to make it part of the experiment and evaluate arm(s) in it
    trial = exp.new_trial(generator_run=generator_run)
    # Start trial run to evaluate arm(s) in the trial
    trial.run()
    # Mark trial as completed to record when a trial run is completed
    # and enable fetching of data for metrics on the experiment
    # (by default, trials must be completed before metrics can fetch their data,
    # unless a metric is explicitly configured otherwise)
    trial.mark_completed()

for i in range(NUM_BOTORCH_TRIALS):
    print(
        f"Running GP+EI optimization trial {i + NUM_SOBOL_TRIALS + 1}/{NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS}..."
    )
    # Reinitialize GP+EI model at each step with updated data.
    gpei = Models.BOTORCH(experiment=exp, data=exp.fetch_data())
    generator_run = gpei.gen(n=1)
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()

print("Done!")
Running Sobol initialization trials...
Running GP+EI optimization trial 6/20...
Running GP+EI optimization trial 7/20...
Running GP+EI optimization trial 8/20...
Running GP+EI optimization trial 9/20...
Running GP+EI optimization trial 10/20...
Running GP+EI optimization trial 11/20...
Running GP+EI optimization trial 12/20...
Running GP+EI optimization trial 13/20...
Running GP+EI optimization trial 14/20...
Running GP+EI optimization trial 15/20...
Running GP+EI optimization trial 16/20...
Running GP+EI optimization trial 17/20...
Running GP+EI optimization trial 18/20...
Running GP+EI optimization trial 19/20...
Running GP+EI optimization trial 20/20...
Done!
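The loop above follows a simple pattern. It starts with a few space-filling initial points, then repeatedly fits to all data collected so far and proposes one new point. The sketch below reproduces only that control flow in plain Python. It deliberately swaps in much simpler stand-ins:

- uniform random sampling instead of Sobol,
- a "perturb the best point so far" heuristic instead of GP+EI,
- a hypothetical quadratic objective instead of Hartmann6.

```python
import random

NUM_INIT_TRIALS = 5    # plays the role of NUM_SOBOL_TRIALS
NUM_MODEL_TRIALS = 15  # plays the role of NUM_BOTORCH_TRIALS
DIM = 6


def objective(x):
    """Hypothetical stand-in objective to minimize (not Hartmann6)."""
    return sum((v - 0.3) ** 2 for v in x)


def propose_initial(rng):
    """Uniform random point in [0, 1]^DIM (a stand-in for a Sobol point)."""
    return [rng.random() for _ in range(DIM)]


def propose_from_history(history, rng, step=0.1):
    """Stand-in for GP+EI: Gaussian perturbation of the best point so far, clipped to [0, 1]."""
    best_x, _ = min(history, key=lambda pair: pair[1])
    return [min(1.0, max(0.0, v + rng.gauss(0.0, step))) for v in best_x]


def run_loop(seed=0):
    rng = random.Random(seed)
    history = []  # (x, y) pairs, analogous to the data attached to the experiment
    for _ in range(NUM_INIT_TRIALS):
        x = propose_initial(rng)
        history.append((x, objective(x)))
    for _ in range(NUM_MODEL_TRIALS):
        # Like refitting the model each step, the proposal sees all data gathered so far.
        x = propose_from_history(history, rng)
        history.append((x, objective(x)))
    return history


history = run_loop()
print(f"{len(history)} evaluations, best value {min(y for _, y in history):.4f}")
```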
Now we can inspect the Experiment's data by calling fetch_data(), which retrieves evaluation data for all trials of the experiment.
To fetch a trial's data, we need to run it and mark it completed. For most metrics in Ax, data is only available once the status of the trial is COMPLETED, since in real-world scenarios, metrics can typically only be fetched after the trial has finished running.
NOTE: Metric classes may implement the is_available_while_running method. When this method returns True, data is available when trials are either RUNNING or COMPLETED. This can be used to obtain intermediate results from A/B test trials and other online experiments, or when metric values are available immediately, as in the case of synthetic problem metrics.
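The availability rule described in the note can be summarized as a small predicate. This is a hypothetical helper for illustration only, with statuses written as plain strings. It is not the check Ax performs internally.

```python
def data_available(trial_status: str, available_while_running: bool = False) -> bool:
    """Whether a metric's data can be fetched for a trial in the given status."""
    if trial_status == "COMPLETED":
        return True  # data is fetchable once the trial is completed
    # Running trials expose data only for metrics whose
    # is_available_while_running() returns True.
    return trial_status == "RUNNING" and available_while_running


print(data_available("RUNNING"))                                # False
print(data_available("RUNNING", available_while_running=True))  # True
```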
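We can also use the fetch_trials_data method to get evaluation data for specific trials in the experiment. Here we fetch only the last trial, whose index is NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS - 1 = 5 + 15 - 1 = 19 (trial indices start at 0):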
trial_data = exp.fetch_trials_data([NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS - 1])
trial_data.df
index | arm_name | metric_name | mean | sem | trial_index | n | frac_nonnull |
---|---|---|---|---|---|---|---|
0 | 19_0 | l2norm | 1.180978 | 0.2 | 19 | 10000 | 1.180978 |
1 | 19_0 | hartmann6 | -2.368136 | 0.0 | 19 | 10000 | -2.368136 |
The call to exp.fetch_data() below also attaches data to the last trial. Because the loop in 5. Perform Optimization calls exp.fetch_data() only before generating each new trial, the final trial's data is never fetched inside the loop and would otherwise not be attached. This is necessary to get objective_means in 7. Plot results.
exp.fetch_data().df
index | arm_name | metric_name | mean | sem | trial_index | n | frac_nonnull |
---|---|---|---|---|---|---|---|
0 | 0_0 | l2norm | 1.301572 | 0.2 | 0 | 10000 | 1.301572 |
1 | 1_0 | l2norm | 1.380915 | 0.2 | 1 | 10000 | 1.380915 |
2 | 2_0 | l2norm | 1.326524 | 0.2 | 2 | 10000 | 1.326524 |
3 | 3_0 | l2norm | 1.959614 | 0.2 | 3 | 10000 | 1.959614 |
4 | 4_0 | l2norm | 2.030366 | 0.2 | 4 | 10000 | 2.030366 |
5 | 5_0 | l2norm | 0.951150 | 0.2 | 5 | 10000 | 0.951150 |
6 | 6_0 | l2norm | 0.170662 | 0.2 | 6 | 10000 | 0.170662 |
7 | 7_0 | l2norm | 0.512136 | 0.2 | 7 | 10000 | 0.512136 |
8 | 8_0 | l2norm | 0.708374 | 0.2 | 8 | 10000 | 0.708374 |
9 | 9_0 | l2norm | 0.818011 | 0.2 | 9 | 10000 | 0.818011 |
10 | 10_0 | l2norm | 0.548412 | 0.2 | 10 | 10000 | 0.548412 |
11 | 11_0 | l2norm | 1.011652 | 0.2 | 11 | 10000 | 1.011652 |
12 | 12_0 | l2norm | 1.194851 | 0.2 | 12 | 10000 | 1.194851 |
13 | 13_0 | l2norm | 1.004335 | 0.2 | 13 | 10000 | 1.004335 |
14 | 14_0 | l2norm | 1.047134 | 0.2 | 14 | 10000 | 1.047134 |
15 | 15_0 | l2norm | 1.051354 | 0.2 | 15 | 10000 | 1.051354 |
16 | 16_0 | l2norm | 0.836373 | 0.2 | 16 | 10000 | 0.836373 |
17 | 17_0 | l2norm | 0.627839 | 0.2 | 17 | 10000 | 0.627839 |
18 | 18_0 | l2norm | 1.190575 | 0.2 | 18 | 10000 | 1.190575 |
19 | 19_0 | l2norm | 1.014486 | 0.2 | 19 | 10000 | 1.014486 |
20 | 0_0 | hartmann6 | -0.328943 | 0.0 | 0 | 10000 | -0.328943 |
21 | 1_0 | hartmann6 | -0.032830 | 0.0 | 1 | 10000 | -0.032830 |
22 | 2_0 | hartmann6 | -0.282132 | 0.0 | 2 | 10000 | -0.282132 |
23 | 3_0 | hartmann6 | -0.068417 | 0.0 | 3 | 10000 | -0.068417 |
24 | 4_0 | hartmann6 | -0.154105 | 0.0 | 4 | 10000 | -0.154105 |
25 | 5_0 | hartmann6 | -0.693260 | 0.0 | 5 | 10000 | -0.693260 |
26 | 6_0 | hartmann6 | -0.740225 | 0.0 | 6 | 10000 | -0.740225 |
27 | 7_0 | hartmann6 | -1.015359 | 0.0 | 7 | 10000 | -1.015359 |
28 | 8_0 | hartmann6 | -1.106793 | 0.0 | 8 | 10000 | -1.106793 |
29 | 9_0 | hartmann6 | -1.154429 | 0.0 | 9 | 10000 | -1.154429 |
30 | 10_0 | hartmann6 | -1.243009 | 0.0 | 10 | 10000 | -1.243009 |
31 | 11_0 | hartmann6 | -0.492035 | 0.0 | 11 | 10000 | -0.492035 |
32 | 12_0 | hartmann6 | -1.524105 | 0.0 | 12 | 10000 | -1.524105 |
33 | 13_0 | hartmann6 | -1.771871 | 0.0 | 13 | 10000 | -1.771871 |
34 | 14_0 | hartmann6 | -2.023056 | 0.0 | 14 | 10000 | -2.023056 |
35 | 15_0 | hartmann6 | -2.230798 | 0.0 | 15 | 10000 | -2.230798 |
36 | 16_0 | hartmann6 | -2.148667 | 0.0 | 16 | 10000 | -2.148667 |
37 | 17_0 | hartmann6 | -2.218789 | 0.0 | 17 | 10000 | -2.218789 |
38 | 18_0 | hartmann6 | -1.986113 | 0.0 | 18 | 10000 | -1.986113 |
39 | 19_0 | hartmann6 | -2.368136 | 0.0 | 19 | 10000 | -2.368136 |
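The data comes back in long format, with one row per (arm, metric) pair. To pick the best arm that satisfies the outcome constraint, it helps to regroup the rows by arm. With pandas one would typically pivot Data.df. The dependency-free sketch below does the same on a few rows copied from the table above (trials 3, 4 and 15 to 19). It judges feasibility on the observed l2norm mean, which is itself noisy (sem 0.2), so treat the result as a point estimate. The helper names are hypothetical.

```python
# (arm_name, metric_name, mean) rows copied from the exp.fetch_data().df table above.
ROWS = [
    ("3_0", "l2norm", 1.959614), ("3_0", "hartmann6", -0.068417),
    ("4_0", "l2norm", 2.030366), ("4_0", "hartmann6", -0.154105),
    ("15_0", "l2norm", 1.051354), ("15_0", "hartmann6", -2.230798),
    ("16_0", "l2norm", 0.836373), ("16_0", "hartmann6", -2.148667),
    ("17_0", "l2norm", 0.627839), ("17_0", "hartmann6", -2.218789),
    ("18_0", "l2norm", 1.190575), ("18_0", "hartmann6", -1.986113),
    ("19_0", "l2norm", 1.014486), ("19_0", "hartmann6", -2.368136),
]
L2NORM_BOUND = 1.25  # outcome constraint: l2norm <= 1.25


def by_arm(rows):
    """Regroup long-format rows into {arm_name: {metric_name: mean}}."""
    wide = {}
    for arm_name, metric_name, mean in rows:
        wide.setdefault(arm_name, {})[metric_name] = mean
    return wide


def best_feasible_arm(rows, bound=L2NORM_BOUND):
    """Arm with the lowest hartmann6 mean among arms whose l2norm mean is <= bound."""
    feasible = {arm: m for arm, m in by_arm(rows).items() if m["l2norm"] <= bound}
    return min(feasible, key=lambda arm: feasible[arm]["hartmann6"])


print(best_feasible_arm(ROWS))  # 19_0
```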
Now we can plot the results of our optimization:
import numpy as np
from ax.plot.trace import optimization_trace_single_method
# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means from multiple
# optimization runs, so we wrap our array of objective means in another array.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
y=np.minimum.accumulate(objective_means, axis=1),
optimum=-3.32237, # Known minimum objective for Hartmann6 function.
)
render(best_objective_plot)
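The trace plotted above shows the best-so-far objective value after each trial, which np.minimum.accumulate computes along the trial axis. Here is the same computation in plain Python on the 20 hartmann6 means from the exp.fetch_data() table. It also computes the remaining gap to the known optimum of -3.32237.

```python
from itertools import accumulate

# hartmann6 means for trials 0..19, copied from the exp.fetch_data().df table above.
HARTMANN6_MEANS = [
    -0.328943, -0.032830, -0.282132, -0.068417, -0.154105,
    -0.693260, -0.740225, -1.015359, -1.106793, -1.154429,
    -1.243009, -0.492035, -1.524105, -1.771871, -2.023056,
    -2.230798, -2.148667, -2.218789, -1.986113, -2.368136,
]
KNOWN_OPTIMUM = -3.32237  # same value passed as `optimum` to the plot

# Running minimum: element k is the best value observed in trials 0..k
# (the plain-Python analogue of np.minimum.accumulate).
best_so_far = list(accumulate(HARTMANN6_MEANS, min))
gap = best_so_far[-1] - KNOWN_OPTIMUM

print(best_so_far[-1])  # -2.368136
print(round(gap, 6))    # 0.954234
```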
In order to perform an optimization, we also need to define an optimization config for the experiment. An optimization config is composed of an objective metric to be minimized or maximized in the experiment, and optionally a set of outcome constraints that place restrictions on how other metrics can be moved by the experiment.
In order to define an objective or outcome constraint, we first need to subclass Metric. Metrics are used to evaluate trials, which are individual steps of the experiment sequence. Each trial contains one or more arms for which we will collect data at the same time.
Our custom metric(s) will determine how, given a trial, to compute the mean and SEM of each of the trial's arms.
The only method that needs to be defined for most metric subclasses is fetch_trial_data, which defines how a single trial is evaluated and returns its data (in the example below, an Ax Data object wrapping a pandas DataFrame).
The is_available_while_running method is optional and returns a boolean specifying whether the trial data can be fetched before the trial is complete. See 6. Inspect trials' data for more details.
from ax import Data
import pandas as pd

class BoothMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            records.append({
                "arm_name": arm_name,
                "metric_name": self.name,
                "trial_index": trial.index,
                # in practice, the mean and sem will be looked up based on trial metadata
                # but for this tutorial we will calculate them (Booth function of x1, x2)
                "mean": (params["x1"] + 2 * params["x2"] - 7) ** 2 + (2 * params["x1"] + params["x2"] - 5) ** 2,
                "sem": 0.0,
            })
        return Data(df=pd.DataFrame.from_records(records))

    # A classmethod, matching how the base Metric class defines it
    @classmethod
    def is_available_while_running(cls) -> bool:
        return True
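The mean computed in fetch_trial_data is the Booth function, f(x1, x2) = (x1 + 2*x2 - 7)^2 + (2*x1 + x2 - 5)^2. The sketch below separates that formula from the record-building step so both can be checked without Ax. The function names are our own, and the arm names and parameter values in the usage line are made up for illustration.

```python
def booth(x1: float, x2: float) -> float:
    """Booth function; its global minimum is 0 at (x1, x2) = (1, 3)."""
    return (x1 + 2 * x2 - 7) ** 2 + (2 * x1 + x2 - 5) ** 2


def booth_records(trial_index: int, arms: dict, metric_name: str = "booth") -> list:
    """One record per arm, with the columns BoothMetric puts into its DataFrame."""
    return [
        {
            "arm_name": arm_name,
            "metric_name": metric_name,
            "trial_index": trial_index,
            "mean": booth(params["x1"], params["x2"]),
            "sem": 0.0,  # noiseless, as in BoothMetric
        }
        for arm_name, params in arms.items()
    ]


# Hypothetical arms; only x1 and x2 matter to the Booth function.
records = booth_records(0, {"0_0": {"x1": 1.0, "x2": 1.0}, "0_1": {"x1": 0.0, "x2": 0.0}})
print([r["mean"] for r in records])  # [20.0, 74.0]
```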
At any point, we can also save our experiment to a JSON file. To ensure that our custom metrics and runner are saved properly, we first need to register them.
from ax.storage.metric_registry import register_metric
from ax.storage.runner_registry import register_runner
register_metric(BoothMetric)
register_metric(L2NormMetric)
register_metric(Hartmann6Metric)
register_runner(MyRunner)
from ax.storage.json_store.load import load_experiment
from ax.storage.json_store.save import save_experiment
save_experiment(exp, "experiment.json")
loaded_experiment = load_experiment("experiment.json")
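Registration is needed because a JSON file can only store plain data, such as a class name, and loading has to map that name back to a Python class. The sketch below illustrates that idea with the standard json module. The registry, the __type key and the ToyMetric class are hypothetical and much simpler than Ax's actual encoder and decoder.

```python
import json

REGISTRY = {}  # class name -> class, filled in by register()


def register(cls):
    """Make cls known to decode(); usable as a decorator."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyMetric:
    def __init__(self, name):
        self.name = name


def encode(obj) -> str:
    # Only the class name and constructor argument survive serialization.
    return json.dumps({"__type": type(obj).__name__, "name": obj.name})


def decode(text: str):
    payload = json.loads(text)
    cls = REGISTRY[payload["__type"]]  # KeyError for an unregistered class
    return cls(payload["name"])


restored = decode(encode(ToyMetric("booth")))
print(type(restored).__name__, restored.name)  # ToyMetric booth
```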
To save our experiment to SQL, we must first specify a connection to a database and create all necessary tables.
from ax.storage.sqa_store.db import init_engine_and_session_factory, get_engine, create_all_tables
from ax.storage.sqa_store.load import load_experiment
from ax.storage.sqa_store.save import save_experiment
init_engine_and_session_factory(url='sqlite:///foo3.db')
engine = get_engine()
create_all_tables(engine)
exp.name = "new"
save_experiment(exp)
load_experiment(exp.name)
Experiment(new)
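Under SQLAlchemy's URL convention, sqlite:///foo3.db refers to a SQLite file named foo3.db relative to the current working directory. As the load_experiment(exp.name) call shows, experiments are then loaded back by name. The sketch below mimics that save-then-load-by-name flow with the standard sqlite3 module and an in-memory database. The one-table schema and the save_by_name/load_by_name helpers are hypothetical and far simpler than the tables Ax creates.

```python
import json
import sqlite3


def save_by_name(conn, name, payload):
    """Insert or overwrite one experiment row, keyed by its unique name."""
    conn.execute(
        "CREATE TABLE IF NOT EXISTS experiments (name TEXT PRIMARY KEY, body TEXT NOT NULL)"
    )
    conn.execute(
        "INSERT OR REPLACE INTO experiments (name, body) VALUES (?, ?)",
        (name, json.dumps(payload)),
    )
    conn.commit()


def load_by_name(conn, name):
    """Return the stored payload for name, or raise if nothing was saved under it."""
    row = conn.execute("SELECT body FROM experiments WHERE name = ?", (name,)).fetchone()
    if row is None:
        raise ValueError(f"no experiment named {name!r}")
    return json.loads(row[0])


conn = sqlite3.connect(":memory:")  # a file path such as "foo3.db" would persist instead
save_by_name(conn, "new", {"num_trials": 20})
print(load_by_name(conn, "new"))  # {'num_trials': 20}
```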