
Ask-tell experimentation with trial-level early stopping

Trial-level early stopping aims to monitor the results of expensive evaluations with timeseries-like data and terminate, prior to completion, those that are unlikely to produce promising results. This reduces computational waste and enables the same amount of resources to explore more configurations. Early stopping is useful for expensive-to-evaluate problems where stepwise information is available on the way to the final measurement.

Like the ask-tell tutorial we'll be minimizing the Hartmann6 function, but this time we've modified it to incorporate a new parameter t which allows the function to produce timeseries-like data where the value returned is closer and closer to Hartmann6's true value as t increases. At t = 100 the function will simply return Hartmann6's unaltered value.

f(x, t) = \mathrm{hartmann6}(x) - \log_2(t / 100)
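
To make the role of t concrete, here is a small illustrative sketch (not part of the tutorial code itself) that prints the additive penalty at a few progressions; the observed curve sits above the true Hartmann6 value by exactly this amount and meets it at t = 100.

# Illustrative sketch only: the additive penalty -log2(t / 100) that separates
# the observed curve from the true Hartmann6 value. It shrinks as t grows and
# vanishes at t = 100.
import math

for t in (10, 25, 50, 100):
    print(f"t={t:3d}  penalty={-math.log2(t / 100):.3f}")
# t= 10  penalty=3.322
# t= 25  penalty=2.000
# t= 50  penalty=1.000
# t=100  penalty=0.000  (curve equals the true Hartmann6 value)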

While the function is synthetic, the workflow captures the intended principles for this tutorial and is similar to the process of training typical machine learning models.

Learning Objectives

  • Understand when time-series-like data can be used in an optimization experiment
  • Run a simple optimization experiment with early stopping
  • Configure details of an early stopping strategy
  • Analyze the results of the optimization

Prerequisites

Step 1: Import Necessary Modules

First, ensure you have all the necessary imports:

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.preview.api.client import Client
from ax.preview.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig

Step 2: Initialize the Client

Create an instance of the Client to manage the state of your experiment.

client = Client()

Step 3: Configure the Experiment

The Client instance can be configured with a series of Configs that define how the experiment will be run.

The Hartmann6 problem is usually evaluated on the hypercube $x_i \in (0, 1)$, so we will define six identical RangeParameterConfigs with these bounds and add these to an ExperimentConfig along with other metadata about the experiment.

You may specify additional features like parameter constraints to further refine the search space and parameter scaling to help navigate parameters with nonuniform effects. For more on configuring experiments, see this recipe.

# Define six float parameters for the Hartmann6 function
parameters = [
    RangeParameterConfig(
        name=f"x{i + 1}", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
    )
    for i in range(6)
]

# Create an experiment configuration
experiment_config = ExperimentConfig(
    name="hartmann6_experiment",
    parameters=parameters,
    # The following arguments are optional
    description="Optimization of the Hartmann6 function",
    owner="developer",
)

# Apply the experiment configuration to the client
client.configure_experiment(experiment_config=experiment_config)
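
As a hedged illustration of the parameter constraints mentioned above (this variant is not used in the rest of the tutorial, and the parameter_constraints argument and its inequality-string syntax are assumptions to verify against the linked recipe), a constrained configuration might look like this:

# Hypothetical sketch: the same six parameters, plus a linear constraint
# restricting the sum of the first two. The parameter_constraints argument
# and its string format are assumptions; consult the recipe for the exact options.
constrained_config = ExperimentConfig(
    name="hartmann6_constrained_experiment",
    parameters=[
        RangeParameterConfig(
            name=f"x{i + 1}", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
        )
        for i in range(6)
    ],
    parameter_constraints=["x1 + x2 <= 1.0"],  # assumed constraint syntax
    description="Hartmann6 with a parameter constraint (illustration only)",
)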

Step 4: Configure Optimization

Now, we must configure the objective for this optimization, which we do using Client.configure_optimization. This method expects a string objective, an expression containing either a single metric to maximize, a linear combination of metrics to maximize, or a tuple of multiple metrics to jointly maximize. These expressions are parsed using SymPy. For example:

  • "score" would direct Ax to maximize a metric named score
  • "-loss" would direct Ax to Ax to minimize a metric named loss
  • "task_0 + 0.5 * task_1" would direct Ax to maximize the sum of two task scores, downweighting task_1 by a factor of 0.5
  • "score, -flops" would direct Ax to simultaneously maximize score while minimizing flops

For more information on configuring objectives and outcome constraints, see this recipe.

client.configure_optimization(objective="-hartmann6")
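
If the problem also involved outcome constraints (for example, keeping a cost metric below a budget), they could in principle be supplied alongside the objective. The sketch below is illustrative only: the outcome_constraints argument, its string syntax, and the flops metric are assumptions not used elsewhere in this tutorial; see the recipe linked above for the supported options.

# Hypothetical alternative (not used in this tutorial): minimize hartmann6
# while constraining an assumed "flops" metric below a budget.
# The outcome_constraints argument and its string format are assumptions.
client.configure_optimization(
    objective="-hartmann6",
    outcome_constraints=["flops <= 1000000"],
)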

Step 5: Run Trials with early stopping

Here, we will configure the ask-tell loop.

We begin by defining our Hartmann6 function as written above. Remember, this is just an example problem and any Python function can be substituted here.

Then we will iteratively do the following:

  • Call client.get_next_trials to "ask" Ax for a parameterization to evaluate
  • Evaluate hartmann6_curve using those parameters in an inner loop to simulate the generation of timeseries data
  • "Tell" Ax the partial result using client.attach_data
  • Query whether the trial should be stopped via client.should_stop_trial_early
  • Stop the underperforming trial and report back to Ax that it has been stopped

This loop will run multiple trials to optimize the function.

Ax will configure an EarlyStoppingStrategy the first time should_stop_trial_early is called. By default Ax uses a Percentile early stopping strategy, which will terminate a trial early if its performance falls below a percentile threshold when compared to other trials at the same step. To prevent premature stopping, early stopping can only occur after a minimum number of progressions, and only once a minimum number of trials have completed with curve data; these completed trials establish the baseline against which running trials are compared.
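
To adjust these defaults, the strategy can be constructed explicitly (this is why PercentileEarlyStoppingStrategy was imported in Step 1). The sketch below is hedged: the constructor arguments and the Client method for attaching the strategy reflect the strategy's documented options, but exact names may vary across Ax versions, so verify them against your installation.

# Optional, hedged sketch: configure the percentile strategy explicitly rather
# than relying on the default created on the first should_stop_trial_early call.
# Argument and method names below are assumptions to verify for your Ax version.
early_stopping_strategy = PercentileEarlyStoppingStrategy(
    # Stop a trial whose objective is worse than the 50th percentile of
    # comparable trials at the same progression.
    percentile_threshold=50.0,
    # Do not consider stopping before this many progressions have been observed.
    min_progression=10,
    # Require this many trials with curve data before stopping anything,
    # so that a baseline exists for comparison.
    min_curves=5,
)
client.set_early_stopping_strategy(early_stopping_strategy=early_stopping_strategy)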

# Hartmann6 function
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array(
        [
            [10, 3, 17, 3.5, 1.7, 8],
            [0.05, 10, 17, 0.1, 8, 14],
            [3, 3.5, 1.7, 10, 17, 8],
            [17, 8, 0.05, 10, 0.1, 14],
        ]
    )
    P = 10**-4 * np.array(
        [
            [1312, 1696, 5569, 124, 8283, 5886],
            [2329, 4135, 8307, 3736, 1004, 9991],
            [2348, 1451, 3522, 2883, 3047, 6650],
            [4047, 8828, 8732, 5743, 1091, 381],
        ]
    )

    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j]) ** 2
        outer += alpha[i] * np.exp(-inner)
    return -outer


# Hartmann6 function with additional t term such that
# hartmann6(X) == hartmann6_curve(X, t=100)
def hartmann6_curve(x1, x2, x3, x4, x5, x6, t):
    return hartmann6(x1, x2, x3, x4, x5, x6) - np.log2(t / 100)


(
    hartmann6(0.1, 0.45, 0.8, 0.25, 0.552, 1.0),
    hartmann6_curve(0.1, 0.45, 0.8, 0.25, 0.552, 1.0, 100),
)
Out:

(-0.4878737485613134, -0.4878737485613134)

maximum_progressions = 100  # Observe hartmann6_curve over 100 progressions

for _ in range(30):  # Run 30 trials
    trials = client.get_next_trials(maximum_trials=1)
    for trial_index, parameters in trials.items():
        for t in range(1, maximum_progressions + 1):
            raw_data = {"hartmann6": hartmann6_curve(t=t, **parameters)}

            # On the final reading call complete_trial and break, else call attach_data
            if t == maximum_progressions:
                client.complete_trial(
                    trial_index=trial_index, raw_data=raw_data, progression=t
                )
                break

            client.attach_data(
                trial_index=trial_index, raw_data=raw_data, progression=t
            )

            # If the trial is underperforming, stop it
            if client.should_stop_trial_early(trial_index=trial_index):
                client.mark_trial_early_stopped(trial_index=trial_index)
                break
Out:

[INFO 03-14 05:09:31] ax.early_stopping.strategies.percentile: Early stoppinging trial 18: Trial objective value 2.3523120805459214 is worse than 50.0-th percentile (2.2222933918958923) across comparable trials..
[INFO 03-14 05:09:33] ax.early_stopping.strategies.percentile: Early stoppinging trial 19: Trial objective value 2.348788737582197 is worse than 50.0-th percentile (2.2855410647390446) across comparable trials..
[INFO 03-14 05:09:34] ax.early_stopping.strategies.percentile: Early stoppinging trial 20: Trial objective value 2.183045056882433 is worse than 50.0-th percentile (2.0847898681459576) across comparable trials..
[INFO 03-14 05:09:36] ax.early_stopping.strategies.percentile: Early stoppinging trial 21: Trial objective value 2.3445778143344795 is worse than 50.0-th percentile (2.332563197483424) across comparable trials..
[INFO 03-14 05:09:38] ax.early_stopping.strategies.percentile: Early stoppinging trial 22: Trial objective value 2.34746345690617 is worse than 50.0-th percentile (2.3445778143344795) across comparable trials..
[INFO 03-14 05:09:39] ax.early_stopping.strategies.percentile: Early stoppinging trial 23: Trial objective value 2.1656337631409786 is worse than 50.0-th percentile (2.125211815643468) across comparable trials..
[INFO 03-14 05:09:41] ax.early_stopping.strategies.percentile: Early stoppinging trial 24: Trial objective value 2.366761630772938 is worse than 50.0-th percentile (2.3445778143344795) across comparable trials..
[INFO 03-14 05:09:42] ax.early_stopping.strategies.percentile: Early stoppinging trial 25: Trial objective value 2.0343078723589256 is worse than 50.0-th percentile (1.9592589860620986) across comparable trials..
[INFO 03-14 05:09:44] ax.early_stopping.strategies.percentile: Early stoppinging trial 26: Trial objective value 2.19863836545895 is worse than 50.0-th percentile (2.1627362587918815) across comparable trials..
[INFO 03-14 05:09:46] ax.early_stopping.strategies.percentile: Early stoppinging trial 27: Trial objective value 2.3568249094213334 is worse than 50.0-th percentile (2.340359851771682) across comparable trials..
[INFO 03-14 05:09:47] ax.early_stopping.strategies.percentile: Early stoppinging trial 28: Trial objective value 2.1681458509119547 is worse than 50.0-th percentile (2.1656337631409786) across comparable trials..
[INFO 03-14 05:09:48] ax.early_stopping.strategies.percentile: Early stoppinging trial 29: Trial objective value 2.360109465443389 is worse than 50.0-th percentile (2.340359851771682) across comparable trials..

Step 6: Analyze Results

After running trials, you can analyze the results. Most commonly this means extracting the parameterization from the best performing trial you conducted.

best_parameters, prediction, index, name = client.get_best_parameterization()
print("Best Parameters:", best_parameters)
print("Prediction (mean, variance):", prediction)
Out:

Best Parameters: {'x1': 0.44143969602748834, 'x2': 0.2132962520317021, 'x3': 0.3326155761025044, 'x4': 0.2712898299798851, 'x5': 0.2751418364985259, 'x6': 0.6328874129133936}

Prediction (mean, variance): {'hartmann6': (-2.5748502808245615, 0.0032479046786129126)}

Step 7: Compute Analyses

Ax can also produce a number of analyses to help interpret the results of the experiment via client.compute_analyses. Users can manually select which analyses to run, or can allow Ax to select which would be most relevant. In this case Ax selects the following:

  • Parallel Coordinates Plot shows which parameterizations were evaluated and what metric values were observed -- this is useful for getting a high level overview of how thoroughly the search space was explored and which regions tend to produce which outcomes
  • Interaction Analysis Plot shows which parameters have the largest effect on the function and plots the most important parameters as 1 or 2 dimensional surfaces
  • Summary lists all trials generated along with their parameterizations, observations, and miscellaneous metadata

client.compute_analyses(display=True)  # By default Ax will display the AnalysisCards produced by compute_analyses

Parallel Coordinates for hartmann6

View arm parameterizations with their respective metric values


hartmann6 by progression

Observe how the metric changes as each trial progresses


Interaction Analysis for hartmann6

Understand an Experiment's data as one- or two-dimensional additive components with sparsity. Important components are visualized through slice or contour plots


Summary for hartmann6_experiment

High-level summary of the Trials in this Experiment

| trial_index | arm_name | trial_status | generation_method | generation_node | hartmann6 | x1 | x2 | x3 | x4 | x5 | x6 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0_0 | COMPLETED | Sobol | Sobol | -0.106925 | 0.177323 | 0.7554 | 0.742666 | 0.440512 | 0.705899 | 0.80327 |
| 1 | 1_0 | COMPLETED | Sobol | Sobol | -0.132358 | 0.969821 | 0.049505 | 0.141375 | 0.533487 | 0.279822 | 0.352548 |
| 2 | 2_0 | COMPLETED | Sobol | Sobol | -0.028089 | 0.735593 | 0.500031 | 0.98977 | 0.100572 | 0.160783 | 0.057986 |
| 3 | 3_0 | COMPLETED | Sobol | Sobol | -0.021849 | 0.411818 | 0.30395 | 0.403563 | 0.882461 | 0.852518 | 0.598567 |
| 4 | 4_0 | COMPLETED | Sobol | Sobol | -0.000315 | 0.316204 | 0.739374 | 0.055304 | 0.786278 | 0.948584 | 0.885846 |
| 5 | 5_0 | COMPLETED | BoTorch | MBM | -1.24719 | 0.65442 | 0.14301 | 0.435167 | 0.466066 | 0.350256 | 0.645672 |
| 6 | 6_0 | COMPLETED | BoTorch | MBM | -1.3499 | 0.632291 | 0.063884 | 0.43263 | 0.421793 | 0.284825 | 0.789439 |
| 7 | 7_0 | COMPLETED | BoTorch | MBM | -0.565801 | 0.643751 | 0.035596 | 0.154409 | 0.376237 | 0.417006 | 0.974714 |
| 8 | 8_0 | COMPLETED | BoTorch | MBM | -1.43145 | 0.640835 | 0.026037 | 0.480317 | 0.229587 | 0.31064 | 0.805952 |
| 9 | 9_0 | COMPLETED | BoTorch | MBM | -0.427434 | 0.585275 | 0.27551 | 0.497422 | 0 | 0.198016 | 1 |
| 10 | 10_0 | COMPLETED | BoTorch | MBM | -0.572103 | 0.701303 | 0 | 0.642484 | 0 | 0.275841 | 0.725476 |
| 11 | 11_0 | COMPLETED | BoTorch | MBM | -1.09963 | 0.662206 | 0 | 0.509523 | 0.327381 | 0.448667 | 0.656137 |
| 12 | 12_0 | COMPLETED | BoTorch | MBM | -2.19419 | 0.494814 | 0 | 0.493684 | 0.325783 | 0.318969 | 0.726258 |
| 13 | 13_0 | COMPLETED | BoTorch | MBM | -1.18138 | 0.457048 | 0 | 0.507847 | 0.454661 | 0.356154 | 0.890672 |
| 14 | 14_0 | COMPLETED | BoTorch | MBM | -2.42008 | 0.459385 | 0.002023 | 0.468699 | 0.286651 | 0.270283 | 0.653674 |
| 15 | 15_0 | COMPLETED | BoTorch | MBM | -1.41663 | 0.425329 | 0 | 0.904664 | 0.298363 | 0.250027 | 0.608034 |
| 16 | 16_0 | COMPLETED | BoTorch | MBM | -2.60479 | 0.44144 | 0.213296 | 0.332616 | 0.27129 | 0.275142 | 0.632887 |
| 17 | 17_0 | COMPLETED | BoTorch | MBM | -2.2387 | 0.429377 | 0 | 0.149733 | 0.276908 | 0.262049 | 0.633281 |
| 18 | 18_0 | EARLY_STOPPED | BoTorch | MBM | 2.35231 | 0.431879 | 0.687592 | 0.383725 | 0.279572 | 0.275458 | 0.600176 |
| 19 | 19_0 | EARLY_STOPPED | BoTorch | MBM | 2.34879 | 0.431663 | 0.687103 | 0.382516 | 0.278803 | 0.276868 | 0.600813 |
| 20 | 20_0 | EARLY_STOPPED | BoTorch | MBM | 2.18304 | 0.431801 | 0.67869 | 0.383534 | 0.279341 | 0.27536 | 0.599902 |
| 21 | 21_0 | EARLY_STOPPED | BoTorch | MBM | 2.34458 | 0.432071 | 0.685054 | 0.383612 | 0.2805 | 0.274564 | 0.599833 |
| 22 | 22_0 | EARLY_STOPPED | BoTorch | MBM | 2.34746 | 0.431716 | 0.686642 | 0.382629 | 0.279018 | 0.276622 | 0.600718 |
| 23 | 23_0 | EARLY_STOPPED | BoTorch | MBM | 2.16563 | 0.432042 | 0.673404 | 0.384113 | 0.280238 | 0.274141 | 0.599344 |
| 24 | 24_0 | EARLY_STOPPED | BoTorch | MBM | 2.36676 | 0.432329 | 0.691018 | 0.383611 | 0.281585 | 0.273867 | 0.599837 |
| 25 | 25_0 | EARLY_STOPPED | BoTorch | MBM | 2.03431 | 0.43152 | 0.672934 | 0.382822 | 0.27799 | 0.276972 | 0.600532 |
| 26 | 26_0 | EARLY_STOPPED | BoTorch | MBM | 2.19864 | 0.431871 | 0.682976 | 0.383442 | 0.279683 | 0.275262 | 0.599991 |
| 27 | 27_0 | EARLY_STOPPED | BoTorch | MBM | 2.35683 | 0.431864 | 0.689052 | 0.38281 | 0.279662 | 0.27604 | 0.600567 |
| 28 | 28_0 | EARLY_STOPPED | BoTorch | MBM | 2.16815 | 0.431823 | 0.674669 | 0.383397 | 0.279309 | 0.27558 | 0.600015 |
| 29 | 29_0 | EARLY_STOPPED | BoTorch | MBM | 2.36011 | 0.431921 | 0.689622 | 0.383413 | 0.280011 | 0.27507 | 0.60003 |

Cross Validation for hartmann6

Out-of-sample predictions using leave-one-out CV

Out:

[<ax.analysis.plotly.plotly_analysis.PlotlyAnalysisCard at 0x7fd1cc793110>,
 <ax.analysis.plotly.plotly_analysis.PlotlyAnalysisCard at 0x7fd1c3f0cc20>,
 <ax.analysis.plotly.plotly_analysis.PlotlyAnalysisCard at 0x7fd1c3707710>,
 <ax.analysis.plotly.plotly_analysis.PlotlyAnalysisCard at 0x7fd1c37738f0>,
 <ax.analysis.analysis.AnalysisCard at 0x7fd1c3ce1010>]

Conclusion

This tutorial demonstrates Ax's early stopping capabilities, which utilize timeseries-like data to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations. This can be used in a number of applications, and is especially useful in machine learning contexts.