Trial Evaluation
There are three paradigms for evaluating trials in Ax. Note: before proceeding to trial evaluation, make sure you are using the appropriate type of trial for your experiment.
[RECOMMENDED] Service API
The Service API AxClient exposes get_next_trial, as well as complete_trial. The user is responsible for evaluating the trial parameters and passing the results to complete_trial.
...
for i in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    raw_data = evaluate_trial(parameters)
    ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)
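The evaluate_trial helper above is user-supplied and not part of Ax. A minimal sketch of what it might look like, assuming the Hartmann6 synthetic function used in the examples below is the objective:

import numpy as np
from ax.utils.measurement.synthetic_functions import hartmann6

# Hypothetical helper: evaluates one trial's parameters and returns raw_data
# in the dictionary-of-metrics form described in the next section.
def evaluate_trial(parameters):
    x = np.array([parameters.get(f"x{i + 1}") for i in range(6)])
    # SEM is 0.0 because the synthetic function is evaluated without noise.
    return {"hartmann6": (hartmann6(x), 0.0)}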
Evaluating Trial Parameters
In the Service API, the complete_trial method requires raw_data evaluated from the parameters suggested by get_next_trial.
The data can be in the form of:
- A dictionary of metric names to (mean, SEM) tuples
- A single (mean, SEM) tuple
- A single mean
In the second case, Ax will assume that the mean and the SEM are for the experiment objective (if the evaluations are noiseless, simply provide an SEM of 0.0). In the third case, Ax will assume that observations are corrupted by Gaussian noise with zero mean and unknown SEM, and infer the SEM from the data (this is equivalent to specifying an SEM of None). Note that if the observation noise is non-zero (either provided or inferred), the "best arm" suggested by Ax may not always be the one whose evaluation returned the best observed value (as the "best arm" is selected based on the model-predicted mean).
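For illustration, each of these forms could be passed to complete_trial as follows (the metric names and values here are placeholders):

# 1. Dictionary of metric names to (mean, SEM) tuples:
ax_client.complete_trial(
    trial_index=trial_index,
    raw_data={"objective": (1.23, 0.05), "constraint_metric": (0.4, 0.0)},
)

# 2. A single (mean, SEM) tuple, interpreted as the objective:
ax_client.complete_trial(trial_index=trial_index, raw_data=(1.23, 0.05))

# 3. A single mean; the SEM is inferred from the data (equivalent to SEM=None):
ax_client.complete_trial(trial_index=trial_index, raw_data=1.23)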
For example, this evaluation function computes mean and SEM for the Hartmann6 function and for the L2-norm. We return 0.0 for the SEM since the observations are noiseless:
import numpy as np
from ax.utils.measurement.synthetic_functions import hartmann6

def hartmann_evaluation_function(parameterization):
    x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])
    # Standard error is 0 since we are computing a synthetic function.
    return {"hartmann6": (hartmann6(x), 0.0), "l2norm": (np.sqrt((x ** 2).sum()), 0.0)}
This function computes just the objective mean and SEM, assuming the Branin function is the objective of the experiment:
from ax.utils.measurement.synthetic_functions import branin

def branin_evaluation_function(parameterization):
    # Standard error is 0 since we are computing a synthetic function.
    return (branin(parameterization.get("x1"), parameterization.get("x2")), 0.0)
Alternatively, if the SEM is unknown, we could use the following form:
lambda parameterization: branin(parameterization.get("x1"), parameterization.get("x2"))
This is equivalent to returning None for the SEM:
from ax.utils.measurement.synthetic_functions import branin

def branin_evaluation_function_unknown_sem(parameterization):
    return (branin(parameterization.get("x1"), parameterization.get("x2")), None)
Loop API
The optimize function requires an evaluation_function, which accepts parameters and returns raw data in the format described above. It can also accept a weight parameter, a nullable float representing the fraction of available data on which the parameterization should be evaluated. For example, this could be a downsampling rate in case of hyperparameter optimization (what portion of data the ML model should be trained on for evaluation) or the percentage of users exposed to a given configuration in A/B testing. This weight is not used in unweighted experiments and defaults to None.
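A minimal sketch of the Loop API, using the Branin evaluation function defined above (the parameter bounds and trial count are illustrative):

from ax import optimize
from ax.utils.measurement.synthetic_functions import branin

best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ],
    # Returns (mean, SEM); SEM is 0.0 since the synthetic function is noiseless.
    evaluation_function=lambda p: (branin(p.get("x1"), p.get("x2")), 0.0),
    objective_name="branin",
    minimize=True,
    total_trials=20,
)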
Developer API
The Developer API is supported by the Experiment class. In this paradigm, the user specifies:
- Runner: Defines how to deploy the experiment.
- List of Metrics: Each defines how to compute/fetch data for a given objective or outcome.
The experiment requires a generator_run to create a new trial or batch trial. A generator run can be generated by a model. The trial then has its own run and mark_completed methods.
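A minimal sketch of how such an experiment might be assembled, assuming the MyMetric and FooRunner classes defined later in this section are in scope (the search space and metric name are illustrative):

from ax import (
    Experiment,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    SearchSpace,
)

search_space = SearchSpace(
    parameters=[
        RangeParameter(name="x1", parameter_type=ParameterType.FLOAT, lower=-5.0, upper=10.0),
        RangeParameter(name="x2", parameter_type=ParameterType.FLOAT, lower=0.0, upper=15.0),
    ]
)

exp = Experiment(
    name="my_experiment",
    search_space=search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(metric=MyMetric(name="my_metric"), minimize=True),
    ),
    runner=FooRunner(foo_param="foo"),
)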
...
# Initialize the experiment with 5 quasi-random Sobol trials.
sobol = Models.SOBOL(exp.search_space)
for i in range(5):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()

# Run 15 Bayesian optimization trials, refitting the model on all data so far.
for i in range(15):
    gpei = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())
    generator_run = gpei.gen(1)
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()
Custom Metrics
Similar to a trial evaluation in the Service API, a custom metric computes a mean and SEM for each arm of a trial. However, the metric's fetch_trial_data method will be called automatically by the experiment's fetch_data method. If there are multiple objectives or outcomes that need to be optimized for, each needs its own metric.
import pandas as pd
from ax import Data, Metric

class MyMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            records.append({
                "arm_name": arm_name,
                "metric_name": self.name,
                # self.foo stands in for however this metric computes its value.
                "mean": self.foo(params["x1"], params["x2"]),
                "sem": 0.0,
                "trial_index": trial.index,
            })
        return Data(df=pd.DataFrame.from_records(records))
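Once the metric is part of the experiment, fetch_data aggregates its per-trial results. A short usage sketch, assuming the exp experiment from the Developer API example above (the metric name is illustrative):

# Track the custom metric on the experiment; it could equally be used as the
# objective inside an OptimizationConfig.
exp.add_tracking_metric(MyMetric(name="my_metric"))

# fetch_data() calls MyMetric.fetch_trial_data for each trial and returns
# an Ax Data object wrapping a pandas DataFrame.
df = exp.fetch_data().df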
Adding Your Own Runner
In order to control how the experiment is deployed, you can add your own runner.
To do so, subclass Runner and implement the run method and the staging_required property.
The run method accepts a Trial and returns a JSON-serializable dictionary of any necessary tracking info needed to fetch data later from this external system. A unique identifier or name for this trial in the external system should be stored in this dictionary with the key "name", and this can later be accessed via trial.deployed_name.
The staging_required property indicates whether the trial requires an intermediate staging period before evaluation begins. This property returns False by default.
An example implementation is given below:
from foo_system import deploy_to_foo
from ax import Runner

class FooRunner(Runner):
    def __init__(self, foo_param):
        self.foo_param = foo_param

    def run(self, trial):
        name_to_params = {
            arm.name: arm.parameters for arm in trial.arms
        }
        run_metadata = deploy_to_foo(self.foo_param, name_to_params)
        return run_metadata

    @property
    def staging_required(self):
        return False
This is then invoked by calling:
exp = Experiment(...)
exp.runner = FooRunner(foo_param="foo")
trial = exp.new_batch_trial()
# This calls runner's run method and stores metadata output
# in the trial.run_metadata field
trial.run()
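After run() returns, the dictionary returned by the runner is stored on the trial. Assuming deploy_to_foo included a "name" key in its result, it can be read back like this:

# Metadata dictionary returned by FooRunner.run for this trial.
print(trial.run_metadata)

# Reads the "name" key from run_metadata (assumes the runner stored one).
print(trial.deployed_name)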