Trial-level early stopping

Trial-level early stopping monitors expensive evaluations that produce timeseries-like data and terminates those that are unlikely to produce promising results before the evaluation completes. This reduces computational waste and lets the same amount of resources explore more configurations. Early stopping is useful for expensive-to-evaluate problems where stepwise information is available on the way to the final measurement.

Like the Getting Started tutorial, we'll be minimizing the Hartmann6 function, but this time we've modified it to incorporate a new parameter $t$, which allows the function to produce timeseries-like data: the value returned gets closer and closer to Hartmann6's true value as $t$ increases. At $t = 100$ the function simply returns Hartmann6's unaltered value.

$$f(x, t) = \text{hartmann6}(x) - \log_2(t/100)$$
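
To see how quickly the synthetic curve approaches its final value, the gap $f(x, t) - \text{hartmann6}(x) = -\log_2(t/100)$ can be evaluated directly. The helper `curve_offset` below is hypothetical and only restates the formula above, written equivalently as $\log_2(100/t)$ so that the value at $t = 100$ is exactly zero.

```python
import math


def curve_offset(t: float, horizon: float = 100.0) -> float:
    """Gap between f(x, t) and hartmann6(x).

    Mathematically equal to -log2(t / horizon); written as log2(horizon / t)
    so the gap at t == horizon is exactly 0.0.
    """
    return math.log2(horizon / t)


# The gap shrinks toward zero as t approaches 100. Since hartmann6 is being
# minimized, early readings look worse than the final value.
for t in (1, 25, 50, 100):
    print(f"t={t:>3}: f(x, t) - hartmann6(x) = {curve_offset(t):.4f}")
```

Because the offset depends only on $t$, every trial is shifted by the same amount at a given progression, so comparing trials at the same step preserves their final ranking. That property makes this function a convenient test case for early stopping.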

While the function is synthetic, the workflow captures the intended principles for this tutorial and is similar to the process of training typical machine learning models.

Learning Objectives

  • Understand when time-series-like data can be used in an optimization experiment
  • Run a simple optimization experiment with early stopping
  • Configure details of an early stopping strategy
  • Analyze the results of the optimization

Prerequisites

Step 1: Import Necessary Modules

First, ensure you have all the necessary imports:

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.api.client import Client
from ax.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig

Step 2: Initialize the Client

Create an instance of the Client to manage the state of your experiment.

client = Client()

Step 3: Configure the Experiment

The Client instance can be configured with a series of Configs that define how the experiment will be run.

The Hartmann6 problem is usually evaluated on the hypercube $x_i \in (0, 1)$, so we will define six identical RangeParameterConfigs with these bounds and add these to an ExperimentConfig along with other metadata about the experiment.

You may specify additional features like parameter constraints to further refine the search space and parameter scaling to help navigate parameters with nonuniform effects.
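
As an illustration of why scaling helps, consider a hypothetical parameter such as a learning rate bounded by $[10^{-4}, 10^{-2}]$, where changes in order of magnitude matter more than absolute differences. The sketch below maps such a parameter to the unit interval on a log scale; the helper names are hypothetical and are not part of the Ax API. On a log scale the midpoint of the range lands on the geometric mean of the bounds rather than near the upper bound.

```python
import math


def to_unit_log(value: float, lower: float, upper: float) -> float:
    """Map value in [lower, upper] to [0, 1] on a log10 scale (hypothetical helper)."""
    lo, hi = math.log10(lower), math.log10(upper)
    return (math.log10(value) - lo) / (hi - lo)


def from_unit_log(u: float, lower: float, upper: float) -> float:
    """Inverse of to_unit_log: map u in [0, 1] back to [lower, upper]."""
    lo, hi = math.log10(lower), math.log10(upper)
    return 10 ** (lo + u * (hi - lo))


lower, upper = 1e-4, 1e-2
# The log-scale midpoint is the geometric mean of the bounds (1e-3) ...
print(from_unit_log(0.5, lower, upper))
# ... whereas the linear midpoint sits close to the upper bound.
print((lower + upper) / 2)
```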

# Define six float parameters for the Hartmann6 function
parameters = [
    RangeParameterConfig(
        name=f"x{i + 1}", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
    )
    for i in range(6)
]

# Create an experiment configuration
experiment_config = ExperimentConfig(
    name="hartmann6_experiment",
    parameters=parameters,
    # The following arguments are optional
    description="Optimization of the Hartmann6 function",
    owner="developer",
)

# Apply the experiment configuration to the client
client.configure_experiment(experiment_config=experiment_config)

Step 4: Configure Optimization

Now, we must configure the objective for this optimization, which we do using Client.configure_optimization. This method expects a string objective, an expression containing either a single metric to maximize, a linear combination of metrics to maximize, or a tuple of multiple metrics to jointly maximize. These expressions are parsed using SymPy. For example:

  • "score" would direct Ax to maximize a metric named score
  • "-loss" would direct Ax to minimize a metric named loss
  • "task_0 + 0.5 * task_1" would direct Ax to maximize the sum of two task scores, downweighting task_1 by a factor of 0.5
  • "score, -flops" would direct Ax to simultaneously maximize score while minimizing flops
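
To make the sign and weighting conventions concrete, the sketch below hand-translates a few of the single-objective expressions above into weight dictionaries and picks the hypothetical candidate with the largest weighted sum. This only illustrates the semantics: Ax itself parses the objective string with SymPy, and the weight dictionaries and metric values here are made up. Multi-objective strings such as "score, -flops" are not reduced to a single sum, so they are left out.

```python
def scalarize(weights: dict[str, float], metrics: dict[str, float]) -> float:
    """Weighted sum of metric values; larger is better."""
    return sum(w * metrics[name] for name, w in weights.items())


# Hand-translated objective strings (illustrative, not produced by Ax).
objectives = {
    "score": {"score": 1.0},
    "-loss": {"loss": -1.0},
    "task_0 + 0.5 * task_1": {"task_0": 1.0, "task_1": 0.5},
}

# Two hypothetical candidates with made-up metric values.
candidates = {
    "a": {"score": 0.9, "loss": 0.4, "task_0": 1.0, "task_1": 2.0},
    "b": {"score": 0.7, "loss": 0.2, "task_0": 1.5, "task_1": 0.0},
}

for expr, weights in objectives.items():
    best = max(candidates, key=lambda c: scalarize(weights, candidates[c]))
    print(f"{expr!r}: best candidate is {best}")
```

Note how negating "loss" turns maximization into minimization: candidate "b" wins under "-loss" because its loss is lower.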

See these recipes for more information on configuring objectives and outcome constraints.

client.configure_optimization(objective="-hartmann6")

Step 5: Run Trials with early stopping

Here, we will configure the ask-tell loop.

We begin by defining our Hartmann6 function as written above. Remember, this is just an example problem and any Python function can be substituted here.

Then we will iteratively do the following:

  • Call client.get_next_trials to "ask" Ax for a parameterization to evaluate
  • Evaluate hartmann6_curve using those parameters in an inner loop to simulate the generation of timeseries data
  • "Tell" Ax the partial result using client.attach_data
  • Query whether the trial should be stopped via client.should_stop_trial_early
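  • Stop the underperforming trial and report back to Ax that it has been stopped via client.mark_trial_early_stopped
  • On the final progression, report the last reading with client.complete_trial instead of client.attach_data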

This loop will run multiple trials to optimize the function.

Ax will configure an EarlyStoppingStrategy when should_stop_trial_early is called for the first time. By default, Ax uses a percentile early stopping strategy, which terminates a trial early if its performance falls below a percentile threshold when compared to other trials at the same step. To prevent premature stopping, early stopping can only occur after a minimum number of progressions and once a minimum number of completed trials with curve data exist. These requirements ensure that enough data has been gathered to make a decision and that the completed trials establish a baseline for comparison.
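
The percentile rule described above can be sketched in a few lines. This is a simplified, hypothetical re-implementation for intuition, not Ax's PercentileEarlyStoppingStrategy: the function names, the min_progression and min_curves values, and the linear-interpolation percentile are all assumptions. Only the 50th-percentile threshold matches what the logs in this tutorial report. Because hartmann6 is being minimized, "worse" means a value above the threshold.

```python
import math


def percentile(values: list[float], q: float) -> float:
    """Linear-interpolation percentile (q in [0, 100]) of a non-empty list."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def should_stop(
    value: float,
    comparable_values: list[float],
    progression: int,
    *,
    threshold_percentile: float = 50.0,
    min_progression: int = 10,  # illustrative assumption, not Ax's default
    min_curves: int = 5,  # illustrative assumption, not Ax's default
    minimize: bool = True,
) -> bool:
    """Stop if value is worse than the percentile of comparable trials at this step."""
    # Too early in the curve to judge this trial.
    if progression < min_progression:
        return False
    # Too few comparable trials to form a baseline.
    if len(comparable_values) < min_curves:
        return False
    threshold = percentile(comparable_values, threshold_percentile)
    return value > threshold if minimize else value < threshold


peers = [1.0, 2.0, 3.0, 4.0, 5.0]  # hypothetical values of other trials at the same step
print(should_stop(3.5, peers, progression=20))  # above the median, so stop
print(should_stop(2.5, peers, progression=20))  # below the median, so continue
print(should_stop(3.5, peers, progression=5))  # too early to decide
```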

# Hartmann6 function
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array(
        [
            [10, 3, 17, 3.5, 1.7, 8],
            [0.05, 10, 17, 0.1, 8, 14],
            [3, 3.5, 1.7, 10, 17, 8],
            [17, 8, 0.05, 10, 0.1, 14],
        ]
    )
    P = 10**-4 * np.array(
        [
            [1312, 1696, 5569, 124, 8283, 5886],
            [2329, 4135, 8307, 3736, 1004, 9991],
            [2348, 1451, 3522, 2883, 3047, 6650],
            [4047, 8828, 8732, 5743, 1091, 381],
        ]
    )

    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j]) ** 2
        outer += alpha[i] * np.exp(-inner)
    return -outer


# Hartmann6 function with additional t term such that
# hartmann6(X) == hartmann6_curve(X, t=100)
def hartmann6_curve(x1, x2, x3, x4, x5, x6, t):
    return hartmann6(x1, x2, x3, x4, x5, x6) - np.log2(t / 100)


(
    hartmann6(0.1, 0.45, 0.8, 0.25, 0.552, 1.0),
    hartmann6_curve(0.1, 0.45, 0.8, 0.25, 0.552, 1.0, 100),
)
Output:
(np.float64(-0.4878737485613134), np.float64(-0.4878737485613134))
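
For reference when reading the results later, Hartmann6 has a well-known global minimum of approximately $-3.32237$ at $x^* \approx (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)$. The sketch below is a stdlib-only port of hartmann6 (the name hartmann6_py is hypothetical) that mirrors the NumPy version above, so it can be checked in isolation: it reproduces the output shown for the example point and evaluates the function at the known optimum.

```python
import math

ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10, 3, 17, 3.5, 1.7, 8],
    [0.05, 10, 17, 0.1, 8, 14],
    [3, 3.5, 1.7, 10, 17, 8],
    [17, 8, 0.05, 10, 0.1, 14],
]
# Same scaling as the NumPy version: 10**-4 times each entry.
P = [
    [10**-4 * v for v in row]
    for row in (
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381],
    )
]


def hartmann6_py(x: list[float]) -> float:
    """Pure-Python Hartmann6 on a list of six coordinates."""
    outer = 0.0
    for i in range(4):
        inner = sum(A[i][j] * (x[j] - P[i][j]) ** 2 for j in range(6))
        outer += ALPHA[i] * math.exp(-inner)
    return -outer


X_STAR = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(hartmann6_py([0.1, 0.45, 0.8, 0.25, 0.552, 1.0]))
print(hartmann6_py(X_STAR))
```

Comparing the known optimum with the best prediction Ax reports in Step 6 gives a rough sense of how close the search got.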

maximum_progressions = 100  # Observe hartmann6_curve over 100 progressions

for _ in range(30):  # Run 30 rounds of trials
    trials = client.get_next_trials(max_trials=3)
    for trial_index, parameters in trials.items():
        for t in range(1, maximum_progressions + 1):
            raw_data = {"hartmann6": hartmann6_curve(t=t, **parameters)}

            # On the final reading call complete_trial and break, else call attach_data
            if t == maximum_progressions:
                client.complete_trial(
                    trial_index=trial_index, raw_data=raw_data, progression=t
                )
                break

            client.attach_data(
                trial_index=trial_index, raw_data=raw_data, progression=t
            )

            # If the trial is underperforming, stop it
            if client.should_stop_trial_early(trial_index=trial_index):
                client.mark_trial_early_stopped(trial_index=trial_index)
                break
Output:
[WARNING 05-01 14:08:07] ax.api.client: 3 trials requested but only 2 could be generated.
[WARNING 05-01 14:08:12] ax.api.client: 3 trials requested but only 1 could be generated.
[INFO 05-01 14:08:24] ax.early_stopping.strategies.percentile: Early stoppinging trial 8: Trial objective value 3.3020524592596057 is worse than 50.0-th percentile (3.067494927255026) across comparable trials..
[INFO 05-01 14:08:50] ax.early_stopping.strategies.percentile: Early stoppinging trial 14: Trial objective value 3.1843578011240523 is worse than 50.0-th percentile (2.9811567055429387) across comparable trials..
[INFO 05-01 14:09:05] ax.early_stopping.strategies.percentile: Early stoppinging trial 17: Trial objective value 2.7898478307501766 is worse than 50.0-th percentile (2.7852538509440192) across comparable trials..
[INFO 05-01 14:11:41] ax.early_stopping.strategies.percentile: Early stoppinging trial 43: Trial objective value 2.340297419565159 is worse than 50.0-th percentile (1.818283529918674) across comparable trials..
[INFO 05-01 14:11:50] ax.early_stopping.strategies.percentile: Early stoppinging trial 45: Trial objective value 2.546339126890599 is worse than 50.0-th percentile (1.818283529918674) across comparable trials..
[INFO 05-01 14:11:51] ax.early_stopping.strategies.percentile: Early stoppinging trial 46: Trial objective value 3.3090945588989564 is worse than 50.0-th percentile (1.9522651938851459) across comparable trials..
[INFO 05-01 14:12:00] ax.early_stopping.strategies.percentile: Early stoppinging trial 48: Trial objective value 3.255543616934257 is worse than 50.0-th percentile (1.9522651938851459) across comparable trials..
[INFO 05-01 14:12:01] ax.early_stopping.strategies.percentile: Early stoppinging trial 49: Trial objective value 3.119780504137988 is worse than 50.0-th percentile (1.9654010050234478) across comparable trials..
[INFO 05-01 14:12:01] ax.early_stopping.strategies.percentile: Early stoppinging trial 50: Trial objective value 2.9242586454272708 is worse than 50.0-th percentile (1.9785368161617498) across comparable trials..
[INFO 05-01 14:12:04] ax.early_stopping.strategies.percentile: Early stoppinging trial 51: Trial objective value 2.2087639013533993 is worse than 50.0-th percentile (2.010904570678737) across comparable trials..
[INFO 05-01 14:12:05] ax.early_stopping.strategies.percentile: Early stoppinging trial 52: Trial objective value 2.9443241446498405 is worse than 50.0-th percentile (2.043272325195724) across comparable trials..
[INFO 05-01 14:12:05] ax.early_stopping.strategies.percentile: Early stoppinging trial 53: Trial objective value 3.3024826387547965 is worse than 50.0-th percentile (2.049400560366128) across comparable trials..
[INFO 05-01 14:12:09] ax.early_stopping.strategies.percentile: Early stoppinging trial 54: Trial objective value 3.227644958975613 is worse than 50.0-th percentile (2.0555287955365316) across comparable trials..
[INFO 05-01 14:12:09] ax.early_stopping.strategies.percentile: Early stoppinging trial 55: Trial objective value 3.2666664647187265 is worse than 50.0-th percentile (2.1321463484449654) across comparable trials..
[INFO 05-01 14:12:21] ax.early_stopping.strategies.percentile: Early stoppinging trial 57: Trial objective value 3.195498712329707 is worse than 50.0-th percentile (2.1321463484449654) across comparable trials..
[INFO 05-01 14:12:21] ax.early_stopping.strategies.percentile: Early stoppinging trial 58: Trial objective value 3.3203448260753263 is worse than 50.0-th percentile (2.2087639013533993) across comparable trials..
[INFO 05-01 14:12:33] ax.early_stopping.strategies.percentile: Early stoppinging trial 60: Trial objective value 3.11623347413749 is worse than 50.0-th percentile (2.2087639013533993) across comparable trials..
[INFO 05-01 14:12:33] ax.early_stopping.strategies.percentile: Early stoppinging trial 61: Trial objective value 3.315505697068359 is worse than 50.0-th percentile (2.2534316993075048) across comparable trials..
[INFO 05-01 14:12:33] ax.early_stopping.strategies.percentile: Early stoppinging trial 62: Trial objective value 3.1960121066144054 is worse than 50.0-th percentile (2.2980994972616102) across comparable trials..
[INFO 05-01 14:12:39] ax.early_stopping.strategies.percentile: Early stoppinging trial 63: Trial objective value 3.214376091429819 is worse than 50.0-th percentile (2.3191984584133847) across comparable trials..
[INFO 05-01 14:12:39] ax.early_stopping.strategies.percentile: Early stoppinging trial 64: Trial objective value 2.903086578879302 is worse than 50.0-th percentile (2.340297419565159) across comparable trials..
[INFO 05-01 14:12:39] ax.early_stopping.strategies.percentile: Early stoppinging trial 65: Trial objective value 3.2915774474582555 is worse than 50.0-th percentile (2.443318273227879) across comparable trials..
[INFO 05-01 14:12:44] ax.early_stopping.strategies.percentile: Early stoppinging trial 66: Trial objective value 3.2273108852226855 is worse than 50.0-th percentile (2.546339126890599) across comparable trials..
[INFO 05-01 14:12:44] ax.early_stopping.strategies.percentile: Early stoppinging trial 67: Trial objective value 3.2604541690083875 is worse than 50.0-th percentile (2.575241263581935) across comparable trials..
[INFO 05-01 14:12:44] ax.early_stopping.strategies.percentile: Early stoppinging trial 68: Trial objective value 3.262490315515125 is worse than 50.0-th percentile (2.6041434002732715) across comparable trials..
[INFO 05-01 14:12:49] ax.early_stopping.strategies.percentile: Early stoppinging trial 69: Trial objective value 2.845446442200189 is worse than 50.0-th percentile (2.6661652590229474) across comparable trials..
[INFO 05-01 14:12:49] ax.early_stopping.strategies.percentile: Early stoppinging trial 70: Trial objective value 3.2687354459695763 is worse than 50.0-th percentile (2.728187117772624) across comparable trials..
[INFO 05-01 14:12:50] ax.early_stopping.strategies.percentile: Early stoppinging trial 71: Trial objective value 3.315054574322863 is worse than 50.0-th percentile (2.7330862699252165) across comparable trials..
[INFO 05-01 14:12:54] ax.early_stopping.strategies.percentile: Early stoppinging trial 72: Trial objective value 3.117131889533674 is worse than 50.0-th percentile (2.7379854220778093) across comparable trials..
[INFO 05-01 14:12:54] ax.early_stopping.strategies.percentile: Early stoppinging trial 73: Trial objective value 2.5343024271185497 is worse than 50.0-th percentile (1.3818329544871384) across comparable trials..
[INFO 05-01 14:12:55] ax.early_stopping.strategies.percentile: Early stoppinging trial 74: Trial objective value 3.307960409878605 is worse than 50.0-th percentile (2.7379854220778093) across comparable trials..
[INFO 05-01 14:12:59] ax.early_stopping.strategies.percentile: Early stoppinging trial 75: Trial objective value 3.2612070218019285 is worse than 50.0-th percentile (2.7569244019586634) across comparable trials..
[INFO 05-01 14:12:59] ax.early_stopping.strategies.percentile: Early stoppinging trial 76: Trial objective value 3.2651697941717623 is worse than 50.0-th percentile (2.7758633818395175) across comparable trials..
[INFO 05-01 14:12:59] ax.early_stopping.strategies.percentile: Early stoppinging trial 77: Trial objective value 3.3176066847464383 is worse than 50.0-th percentile (2.7962382425123233) across comparable trials..
[INFO 05-01 14:13:04] ax.early_stopping.strategies.percentile: Early stoppinging trial 78: Trial objective value 3.135265196533661 is worse than 50.0-th percentile (2.816613103185129) across comparable trials..
[INFO 05-01 14:13:04] ax.early_stopping.strategies.percentile: Early stoppinging trial 79: Trial objective value 3.2937748689336184 is worse than 50.0-th percentile (2.831029772692659) across comparable trials..
[INFO 05-01 14:13:05] ax.early_stopping.strategies.percentile: Early stoppinging trial 80: Trial objective value 3.264127152866202 is worse than 50.0-th percentile (2.845446442200189) across comparable trials..
[INFO 05-01 14:13:10] ax.early_stopping.strategies.percentile: Early stoppinging trial 81: Trial objective value 3.24476601845618 is worse than 50.0-th percentile (2.874266510539745) across comparable trials..
[INFO 05-01 14:13:10] ax.early_stopping.strategies.percentile: Early stoppinging trial 82: Trial objective value 3.235004279487177 is worse than 50.0-th percentile (2.903086578879302) across comparable trials..
[INFO 05-01 14:13:10] ax.early_stopping.strategies.percentile: Early stoppinging trial 83: Trial objective value 3.280956339321412 is worse than 50.0-th percentile (2.9106249868835494) across comparable trials..
[INFO 05-01 14:13:26] ax.early_stopping.strategies.percentile: Early stoppinging trial 85: Trial objective value 3.1655735192086656 is worse than 50.0-th percentile (2.9106249868835494) across comparable trials..
[INFO 05-01 14:13:26] ax.early_stopping.strategies.percentile: Early stoppinging trial 86: Trial objective value 3.075703403605768 is worse than 50.0-th percentile (2.9181633948877965) across comparable trials..

Step 6: Analyze Results

After running trials, you can analyze the results. Most commonly this means extracting the parameterization from the best performing trial you conducted.

best_parameters, prediction, index, name = client.get_best_parameterization()
print("Best Parameters:", best_parameters)
print("Prediction (mean, variance):", prediction)
Output:
Best Parameters: {'x1': 0.2095005222513403, 'x2': 0.162346110149046, 'x3': 0.443268579214745, 'x4': 0.27405431544017445, 'x5': 0.3047073086800985, 'x6': 0.6475183901075783}
Prediction (mean, variance): {'hartmann6': (np.float64(-3.2744844006999516), np.float64(0.0003652795911015538))}
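
prediction maps each metric name to a (mean, variance) pair. Assuming the model's prediction is approximately Gaussian, a quick way to read the variance is to convert it to a standard deviation and an approximate 95% interval. The sketch below hard-codes the numbers from the output above; in your own run you would unpack them from prediction["hartmann6"] instead.

```python
import math

# Values copied from the output above.
mean, variance = -3.2744844006999516, 0.0003652795911015538

std = math.sqrt(variance)
# Approximate 95% interval under a Gaussian assumption (z = 1.96).
lower, upper = mean - 1.96 * std, mean + 1.96 * std
print(f"std = {std:.4f}, 95% interval = ({lower:.4f}, {upper:.4f})")
```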

Step 7: Compute Analyses

Ax can also produce a number of analyses to help interpret the results of the experiment via client.compute_analyses. Users can manually select which analyses to run, or can allow Ax to select which would be most relevant. In this case Ax selects the following:

  • Parallel Coordinates Plot shows which parameterizations were evaluated and what metric values were observed. This is useful for getting a high-level overview of how thoroughly the search space was explored and which regions tend to produce which outcomes
  • Progression Plot shows each partial observation Ax received for each trial as a timeseries
  • Sensitivity Analysis Plot shows which parameters have the largest effect on the objective using Sobol indices
  • Slice Plot shows how the model predicts a single parameter affects the objective, along with a confidence interval
  • Contour Plot shows how the model predicts a pair of parameters affects the objective as a 2D surface
  • Summary lists all trials generated along with their parameterizations, observations, and miscellaneous metadata
  • Cross Validation helps to visualize how well the surrogate model is able to predict out-of-sample points

# display=True instructs Ax to sort then render the resulting analyses
cards = client.compute_analyses(display=True)

Parallel Coordinates for hartmann6

The parallel coordinates plot displays multi-dimensional data by representing each parameter as a parallel axis. This plot helps in assessing how thoroughly the search space has been explored and in identifying patterns or clusterings associated with high-performing (good) or low-performing (bad) arms. By tracing lines across the axes, one can observe correlations and interactions between parameters, gaining insights into the relationships that contribute to the success or failure of different configurations within the experiment.

Sensitivity Analysis for hartmann6

Understand how each parameter affects hartmann6 according to a second-order sensitivity analysis.
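
As a rough illustration of what a first-order Sobol index measures (the share of output variance explained by one input on its own), here is a generic pick-freeze Monte Carlo estimator applied to a toy function $g(x) = x_1 + 2x_2$ with inputs uniform on $(0, 1)$. For this toy function the exact indices are $0.2$ and $0.8$, because the two inputs contribute variances of $1/12$ and $4/12$. This is not Ax's implementation: Ax computes its indices from the surrogate model, and this sketch covers only first-order indices, whereas the analysis above is second-order.

```python
import random


def first_order_sobol(f, dim, n=20_000, seed=0):
    """Pick-freeze Monte Carlo estimate of first-order Sobol indices on [0, 1]^dim."""
    rng = random.Random(seed)
    a = [[rng.random() for _ in range(dim)] for _ in range(n)]
    b = [[rng.random() for _ in range(dim)] for _ in range(n)]
    f_a = [f(row) for row in a]
    f_b = [f(row) for row in b]

    # Total output variance, estimated from both independent samples.
    ys = f_a + f_b
    mean_y = sum(ys) / len(ys)
    var_y = sum((y - mean_y) ** 2 for y in ys) / len(ys)

    indices = []
    for i in range(dim):
        # Rows of a with column i replaced by the matching column of b, so
        # f_ab shares only input i with f_b.
        f_ab = [f(ra[:i] + [rb[i]] + ra[i + 1 :]) for ra, rb in zip(a, b)]
        v_i = sum(yb * (yab - ya) for yb, yab, ya in zip(f_b, f_ab, f_a)) / n
        indices.append(v_i / var_y)
    return indices


def toy(x):
    return x[0] + 2 * x[1]


print(first_order_sobol(toy, dim=2))
```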

x5 vs. hartmann6

The slice plot provides a one-dimensional view of predicted outcomes for hartmann6 as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.
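
The slicing procedure itself is simple: copy a reference point, vary one coordinate over a grid, and evaluate the model at each point. In the sketch below, toy_model is a hypothetical stand-in for the surrogate's predicted mean (it is not Hartmann6), and the center of the search space plays the role of the fixed reference point. The real plot also shows a confidence band, which this sketch omits.

```python
def toy_model(x: dict[str, float]) -> float:
    """Hypothetical surrogate mean: a bowl with its minimum at 0.3 in every coordinate."""
    return sum((v - 0.3) ** 2 for v in x.values())


def slice_values(model, reference, name, grid):
    """Evaluate model along one parameter, holding the others at reference."""
    points = []
    for value in grid:
        x = dict(reference)  # copy so the reference point is not mutated
        x[name] = value
        points.append((value, model(x)))
    return points


reference = {f"x{i}": 0.5 for i in range(1, 7)}
grid = [i / 10 for i in range(11)]
points = slice_values(toy_model, reference, "x5", grid)
best_value, best_prediction = min(points, key=lambda p: p[1])
print(f"lowest predicted value along x5 at x5={best_value}")
```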

x4, x5 vs. hartmann6

The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.
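
Extending the slice to two parameters gives the grid that a contour plot draws. As before, toy_model is a hypothetical stand-in for the surrogate's predicted mean, with its minimum placed at x4 = 0.2, x5 = 0.7 purely for illustration; all other parameters stay at the center of the search space.

```python
def toy_model(x: dict[str, float]) -> float:
    """Hypothetical surrogate mean with its minimum at x4 = 0.2, x5 = 0.7."""
    others = sum((x[k] - 0.5) ** 2 for k in ("x1", "x2", "x3", "x6"))
    return (x["x4"] - 0.2) ** 2 + (x["x5"] - 0.7) ** 2 + others


def contour_grid(model, reference, name_a, name_b, grid):
    """Rows follow name_b, columns follow name_a; other params stay at reference."""
    return [
        [model({**reference, name_a: a, name_b: b}) for a in grid]
        for b in grid
    ]


reference = {f"x{i}": 0.5 for i in range(1, 7)}
grid = [i / 10 for i in range(11)]
z = contour_grid(toy_model, reference, "x4", "x5", grid)

# Locate the grid cell with the lowest predicted value.
row, col = min(
    ((r, c) for r in range(len(grid)) for c in range(len(grid))),
    key=lambda rc: z[rc[0]][rc[1]],
)
print(f"lowest value at x4={grid[col]}, x5={grid[row]}")
```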

x4 vs. hartmann6

The slice plot provides a one-dimensional view of predicted outcomes for hartmann6 as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.

x2, x5 vs. hartmann6

The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.

x1, x5 vs. hartmann6

The contour plot visualizes the predicted outcomes for hartmann6 across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.

Summary for hartmann6_experiment

High-level summary of the Trials in this Experiment

| trial_index | arm_name | trial_status | generation_node | hartmann6 | x1 | x2 | x3 | x4 | x5 | x6 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0_0 | COMPLETED | CenterOfSearchSpace | -0.505315 | 0.500000 | 0.500000 | 0.500000 | 0.500000 | 0.500000 | 0.500000 |
| 1 | 1_0 | COMPLETED | Sobol | -0.546065 | 0.602475 | 0.352720 | 0.666998 | 0.770382 | 0.309815 | 0.914641 |
| 2 | 2_0 | COMPLETED | Sobol | -0.040302 | 0.051439 | 0.727545 | 0.053088 | 0.485182 | 0.864908 | 0.413376 |
| 3 | 3_0 | COMPLETED | Sobol | -0.024442 | 0.355564 | 0.099761 | 0.905102 | 0.641268 | 0.724747 | 0.089266 |
| 4 | 4_0 | COMPLETED | Sobol | -0.053505 | 0.797225 | 0.974645 | 0.253475 | 0.113907 | 0.170710 | 0.582716 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 82 | 82_0 | EARLY_STOPPED | MBM | 3.235004 | 0.000000 | 0.294063 | 0.000000 | 0.635000 | 0.414652 | 0.193215 |
| 83 | 83_0 | EARLY_STOPPED | MBM | 3.280956 | 0.524616 | 0.000000 | 0.000000 | 0.139609 | 0.273913 | 0.000000 |
| 84 | 84_0 | COMPLETED | MBM | -3.238042 | 0.219187 | 0.117858 | 0.423152 | 0.278250 | 0.301877 | 0.617018 |
| 85 | 85_0 | EARLY_STOPPED | MBM | 3.165574 | 1.000000 | 0.000000 | 0.682684 | 0.000000 | 0.027074 | 1.000000 |
| 86 | 86_0 | EARLY_STOPPED | MBM | 3.075703 | 0.771228 | 0.000000 | 0.597101 | 0.000000 | 0.098026 | 0.661026 |

hartmann6 by progression

The progression plot tracks the evolution of each metric over the course of the experiment. This visualization is typically used to monitor the improvement of metrics over Trial iterations, but can also be useful in informing decisions about early stopping for Trials.
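
Conceptually, a progression plot is built by grouping each (trial, progression, value) reading by trial and drawing one curve per trial. The sketch below does that grouping for two hypothetical trials generated with the same formula as hartmann6_curve (the final values -1.0 and -2.0 are made up), and also extracts the readings at a single progression, which is the cross-section a percentile rule compares.

```python
import math
from collections import defaultdict

# Hypothetical final hartmann6 values for two trials (made up for illustration).
final_values = {0: -1.0, 1: -2.0}

# Flat log of (trial_index, progression, value) readings, following
# hartmann6_curve: final - log2(t / 100) == final + log2(100 / t).
readings = [
    (trial, t, final + math.log2(100 / t))
    for trial, final in final_values.items()
    for t in range(1, 101)
]

# One curve per trial, ordered by progression.
curves = defaultdict(list)
for trial, t, value in sorted(readings):
    curves[trial].append((t, value))

# Cross-section: every trial's reading at the same progression.
step = 10
cross_section = {trial: dict(curve)[step] for trial, curve in curves.items()}
print(cross_section)
```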

Cross Validation for hartmann6

The cross-validation plot displays the model fit for each metric in the experiment. It employs a leave-one-out approach, where the model is trained on all data except one sample, which is used for validation. The plot shows the predicted outcome for the validation set on the y-axis against its actual value on the x-axis. Points that align closely with the dotted diagonal line indicate a strong model fit, signifying accurate predictions. Additionally, the plot includes 95% confidence intervals that provide insight into the noise in observations and the uncertainty in model predictions. A horizontal, flat line of predictions indicates that the model has not picked up on sufficient signal in the data, and instead is just predicting the mean.
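
The leave-one-out procedure can be sketched independently of Ax. Here a straight-line least-squares fit stands in for the surrogate model (an assumption for brevity; Ax uses its own surrogate model), and each point is predicted by a model that never saw it. The resulting (actual, predicted) pairs correspond to the x and y axes of the plot.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit y = slope * x + intercept."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def leave_one_out(xs, ys):
    """Return (actual, predicted) pairs, predicting each point from the others."""
    pairs = []
    for i in range(len(xs)):
        train_x = xs[:i] + xs[i + 1 :]
        train_y = ys[:i] + ys[i + 1 :]
        slope, intercept = fit_line(train_x, train_y)
        pairs.append((ys[i], slope * xs[i] + intercept))
    return pairs


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.2, 4.9, 7.1, 9.0]  # made-up, nearly linear data
for actual, predicted in leave_one_out(xs, ys):
    print(f"actual={actual:.2f} predicted={predicted:.2f}")
```

Points whose predictions sit close to their actual values would fall near the diagonal of the plot; a model that had learned nothing would predict roughly the same value for every held-out point, producing the flat line described above.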

Conclusion

This tutorial demonstrates Ax's early stopping capabilities, which use timeseries-like data to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations. This can be used in a number of applications and is especially useful in machine learning contexts.