Version: 1.0.0

Ax for AutoML with scikit-learn

Automated machine learning (AutoML) encompasses a large class of problems related to automating time-consuming and labor-intensive aspects of developing ML models. Adaptive experimentation is a natural fit for solving many AutoML tasks, which are often iterative in nature and can involve many expensive trial evaluations.

In this tutorial we will use Ax for hyperparameter optimization (HPO), a common AutoML task in which a model's hyperparameters are adjusted to improve model performance. Hyperparameters refer to the parameters which are set prior to model training or fitting, rather than parameters learned from data. Traditionally, ML engineers use a combination of domain knowledge, intuition, and manual experimentation comparing many models with different hyperparameter configurations to determine good hyperparameters. As the number of hyperparameters grows and as models become more expensive to train and evaluate, sample-efficient approaches to experimentation like Bayesian optimization become increasingly valuable.

In this tutorial we will train an SGDClassifier from the popular scikit-learn library to recognize handwritten digits and tune the model's hyperparameters to improve its performance. You can read more about the SGDClassifier model in scikit-learn's example, on which this tutorial is largely based. This tutorial will incorporate many advanced features in Ax to demonstrate how they can be applied to complex engineering challenges in a real-world setting.

Learning Objectives

  • Understand how Ax can be used for HPO tasks
  • Use complex optimization configurations like multiple objectives and outcome constraints to achieve nuanced real-world goals
  • Use early stopping to save experimentation resources
  • Analyze the results of the optimization

Prerequisites

Step 1: Import Necessary Modules

First, ensure you have all the necessary imports:

import time

import matplotlib.pyplot as plt

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection

from ax.api.client import Client
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig

from pyre_extensions import assert_is_instance

Step 1.1: Understanding the baseline performance of SGDClassifier

Before we begin HPO, let's understand the task and the performance of SGDClassifier with its default hyperparameters. The following code is largely adapted from the example on scikit-learn's website.

# Load the digits dataset and display the first 4 images to demonstrate
digits = sklearn.datasets.load_digits()
classes = list(set(digits.target))

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Training: %i" % label)

# Instantiate a SGDClassifier with default hyperparameters
clf = sklearn.linear_model.SGDClassifier()

# Split the data into a training set and a validation set
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
    digits.data, digits.target, test_size=0.20, random_state=0
)

# Train the classifier on the training set using 10 batches
# Also time the training.
batch_size = len(train_x) // 10
start_time = time.time()
for i in range(10):
    start_idx = i * batch_size
    end_idx = (i + 1) * batch_size

    # Use partial fit to update the model on the current batch
    clf.partial_fit(
        train_x[start_idx:end_idx], train_y[start_idx:end_idx], classes=classes
    )

training_time = time.time() - start_time

# Evaluate the classifier on the validation set
score = clf.score(valid_x, valid_y)
score, training_time
Output:
(0.8638888888888889, 0.03519725799560547)

The model performs well, but let's see if we can improve performance by tuning the hyperparameters.

Step 2: Initialize the Client

As always, the first step in running our adaptive experiment with Ax is to create an instance of the Client to manage the state of your experiment.

client = Client()

Step 3: Configure the Experiment

The Client expects a series of Configs which define how the experiment will be run. We'll set this up the same way as we did in our previous tutorial.

Our current task is to tune the hyperparameters of scikit-learn's SGDClassifier. These parameters control aspects of the model's training process, and configuring them can have dramatic effects on the model's ability to correctly classify inputs. A full list of this model's hyperparameters and appropriate values is available in the library's documentation. In this tutorial we will tune the following hyperparameters:

  • loss: The loss function to be used
  • penalty: The penalty (aka regularization term) to be used
  • learning_rate: The learning rate schedule
  • alpha: Constant that multiplies the regularization term. The higher the value, the stronger the regularization
  • eta0: The initial learning rate, used by the constant, invscaling, and adaptive schedules
  • batch_size: A training parameter which controls how many examples are shown during a single epoch. We will use all samples in the dataset for each model training, so a smaller batch size will translate to more epochs and vice versa.

You will notice some hyperparameters are continuous ranges, some are discrete ranges, and some are categorical choices; Ax is able to handle all of these types of parameters via its RangeParameterConfig and ChoiceParameterConfig classes.
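
To make the batch_size parameter above concrete, here is a tiny illustrative sketch (the numbers are examples only and are not part of the experiment) showing how, when every training sample is used exactly once per trial, the batch size determines the number of epochs:

# Illustrative sketch only: with a fixed training-set size, a smaller batch size
# means more partial_fit epochs per trial, and a larger batch size means fewer.
train_set_size = 1437  # roughly 80% of the 1,797 digit images

for example_batch_size in (5, 143, 500):
    num_epochs = train_set_size // example_batch_size
    print(f"batch_size={example_batch_size:>3} -> {num_epochs} epochs")
# batch_size=  5 -> 287 epochs
# batch_size=143 -> 10 epochs
# batch_size=500 -> 2 epochs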

# Configure an experiment with the desired parameters
client.configure_experiment(
    parameters=[
        ChoiceParameterConfig(
            name="loss",
            parameter_type="str",
            values=[
                "hinge",
                "log_loss",
                "squared_hinge",
                "modified_huber",
                "perceptron",
            ],
            is_ordered=False,
        ),
        ChoiceParameterConfig(
            name="penalty",
            parameter_type="str",
            values=["l1", "l2", "elasticnet"],
            is_ordered=False,
        ),
        ChoiceParameterConfig(
            name="learning_rate",
            parameter_type="str",
            values=["constant", "optimal", "invscaling", "adaptive"],
            is_ordered=False,
        ),
        RangeParameterConfig(
            name="alpha",
            bounds=(1e-8, 100),
            parameter_type="float",
            scaling="log",  # Sample this parameter in log-transformed space
        ),
        RangeParameterConfig(
            name="eta0",
            bounds=(1e-8, 1),
            parameter_type="float",
            scaling="log",
        ),
        RangeParameterConfig(
            name="batch_size",
            bounds=(5, 500),
            parameter_type="int",
        ),
    ]
)
Output:
/home/runner/work/Ax/Ax/ax/api/utils/instantiation/from_config.py:75: AxParameterWarning: sort_values is not specified for ChoiceParameter "loss". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
return ChoiceParameter(
/home/runner/work/Ax/Ax/ax/api/utils/instantiation/from_config.py:75: AxParameterWarning: sort_values is not specified for ChoiceParameter "penalty". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
return ChoiceParameter(
/home/runner/work/Ax/Ax/ax/api/utils/instantiation/from_config.py:75: AxParameterWarning: sort_values is not specified for ChoiceParameter "learning_rate". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
return ChoiceParameter(

Step 4: Configure Optimization

Now, we must set up the optimization objective in Client, where objective is a string that specifies which metric we would like to optimize and the direction (higher or lower) that is considered optimal.

In our example we want to consider both the performance and the computational cost of each hyperparameter configuration. scikit-learn models report mean accuracy via their score function, and in our optimization we will seek to maximize this value. Model training can also be a very expensive process, especially for large models, so training time represents a significant cost that we would like to keep low.

Let's configure Ax to maximize score while minimizing training time. We call this a multi-objective optimization, and rather than returning a single best parameterization we return a Pareto frontier of points which represent optimal tradeoffs between all metrics present. Multi-objective optimization is useful for competing metrics where a gain in one metric may represent a regression in the other.

In these settings we can also specify outcome constraints, which indicate that if a metric result falls outside of the specified threshold we are not interested in the result, regardless of the wins observed in any other metric. For a concrete example, imagine Ax finding a parameterization that trains in no time at all but has a score no better than random guessing.

For this toy example let's configure Ax to maximize score and minimize training time, but avoid any hyperparameter configurations that result in a mean accuracy score of less than 85% or a training time greater than 1 second.

client.configure_optimization(
    objective="score, -training_time",
    outcome_constraints=["score >= 0.85", "training_time <= 1"],
)

Step 5: Run Trials with early stopping

Before we begin our Bayesian optimization loop, we can attach the data we collected from training SGDClassifier with default hyperparameters. This will give our experiment a head start by providing a datapoint to our surrogate model. Because these are the default settings provided by scikit-learn, it's likely they will be pretty good and will provide the optimization with a promising start. It is always advantageous to attach any existing data to an experiment to improve performance.

trial_index = client.attach_baseline(
    parameters={
        "loss": clf.loss,
        "penalty": clf.penalty,
        "alpha": clf.alpha,
        "learning_rate": clf.learning_rate,
        "eta0": clf.eta0
        + 1e-8,  # Default eta is 0.0, so add a small value to avoid division by zero
        "batch_size": batch_size,
    }
)

client.complete_trial(
    trial_index=trial_index,
    raw_data={"score": score, "training_time": training_time},
)
Output:
<enum 'TrialStatus'>.COMPLETED

After attaching the initial trial, we will begin the experimentation loop by writing a for loop to execute our full experimentation budget of 20 trials. In each iteration we will ask Ax for the next trials (in this case just one), then instantiate an SGDClassifier with the suggested hyperparameters. We will then split the data into training and validation sets. Next we will define an inner loop to perform minibatch training, in which we divide the training set into a number of smaller batches and train one epoch of stochastic gradient descent at a time. After each epoch we will report the score and the elapsed training time.

Because training machine learning models is expensive, we will utilize Ax's early stopping functionality to kill trials unlikely to produce optimal results before they have been completed. After data has been attached we will ask the Client whether or not we should stop the trial, and if it advises us to do so we will report it early stopped and exit out of the training loop. By early stopping, we proactively save compute without regressing optimization performance.

for _ in range(20):  # Run 20 rounds of 1 trial each
    trials = client.get_next_trials(max_trials=1)
    for trial_index, parameters in trials.items():
        clf = sklearn.linear_model.SGDClassifier(
            loss=parameters["loss"],
            penalty=parameters["penalty"],
            alpha=parameters["alpha"],
            learning_rate=parameters["learning_rate"],
            eta0=parameters["eta0"],
        )

        train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
            digits.data,
            digits.target,
            test_size=0.20,
        )

        batch_size = assert_is_instance(parameters["batch_size"], int)
        num_epochs = len(train_x) // batch_size

        start_time = time.time()
        for i in range(0, num_epochs):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size

            # Use partial fit to update the model on the current batch
            clf.partial_fit(
                train_x[start_idx:end_idx], train_y[start_idx:end_idx], classes=classes
            )

            raw_data = {
                "score": clf.score(valid_x, valid_y),
                "training_time": time.time() - start_time,
            }

            # On the final epoch call complete_trial and break, else call attach_data
            if i == num_epochs - 1:
                client.complete_trial(
                    trial_index=trial_index,
                    raw_data=raw_data,
                    progression=end_idx,  # Use the index of the last example in the batch as the progression value
                )
                break

            client.attach_data(
                trial_index=trial_index,
                raw_data=raw_data,
                progression=end_idx,
            )

            # If the trial is underperforming, stop it
            if client.should_stop_trial_early(trial_index=trial_index):
                client.mark_trial_early_stopped(trial_index=trial_index)
                break

Output:
[INFO 05-08 22:43:34] ax.early_stopping.strategies.percentile: Early stoppinging trial 8: Trial objective value 0.8 is worse than 50.0-th percentile (0.8159498207885305) across comparable trials..
[INFO 05-08 22:44:01] ax.early_stopping.strategies.percentile: Early stoppinging trial 9: Trial objective value 0.8083333333333333 is worse than 50.0-th percentile (0.8112903225806452) across comparable trials..
[INFO 05-08 22:44:48] ax.early_stopping.strategies.percentile: Early stoppinging trial 11: Trial objective value 0.3527777777777778 is worse than 50.0-th percentile (0.8087606837606838) across comparable trials..
[INFO 05-08 22:45:12] ax.early_stopping.strategies.percentile: Early stoppinging trial 12: Trial objective value 0.29444444444444445 is worse than 50.0-th percentile (0.7960470085470086) across comparable trials..
[INFO 05-08 22:45:30] ax.early_stopping.strategies.percentile: Early stoppinging trial 13: Trial objective value 0.14166666666666666 is worse than 50.0-th percentile (0.7833333333333334) across comparable trials..
[INFO 05-08 22:45:50] ax.early_stopping.strategies.percentile: Early stoppinging trial 14: Trial objective value 0.6916666666666667 is worse than 50.0-th percentile (0.7375) across comparable trials..
[INFO 05-08 22:46:12] ax.early_stopping.strategies.percentile: Early stoppinging trial 15: Trial objective value 0.4166666666666667 is worse than 50.0-th percentile (0.6916666666666667) across comparable trials..
[INFO 05-08 22:46:32] ax.early_stopping.strategies.percentile: Early stoppinging trial 16: Trial objective value 0.29444444444444445 is worse than 50.0-th percentile (0.5541666666666667) across comparable trials..
[INFO 05-08 22:47:20] ax.early_stopping.strategies.percentile: Early stoppinging trial 18: Trial objective value 0.8388888888888889 is worse than 50.0-th percentile (0.8428333025716747) across comparable trials..
[INFO 05-08 22:48:15] ax.early_stopping.strategies.percentile: Early stoppinging trial 20: Trial objective value 0.4861111111111111 is worse than 50.0-th percentile (0.6916666666666667) across comparable trials..

Step 6: Analyze Results

After running trials, you can analyze the results. Most commonly this means extracting the parameterization from the best performing trial you conducted.

Since we are optimizing multiple objectives, rather than a single best point we want to get the Pareto frontier -- the set of points that presents optimal tradeoffs between maximizing score and minimizing training time.

frontier = client.get_pareto_frontier()

# Frontier is a list of tuples, where each tuple contains the parameters, the metric readings, the trial index, and the arm name for a point on the Pareto frontier
for parameters, metrics, trial_index, arm_name in frontier:
    print(f"Trial {trial_index} with {parameters=} and {metrics=}\n")
Output:
Trial 19 with parameters={'alpha': 1e-08, 'eta0': 2.0785865726007548e-08, 'batch_size': 500, 'loss': 'perceptron', 'penalty': 'l2', 'learning_rate': 'invscaling'} and metrics={'score': (np.float64(0.927439676073422), 0.00029892551520375397), 'training_time': (np.float64(0.04615980053989792), 6.503865663735869e-05)}
Trial 10 with parameters={'alpha': 0.0142668351160523, 'eta0': 1.0, 'batch_size': 500, 'loss': 'hinge', 'penalty': 'l1', 'learning_rate': 'invscaling'} and metrics={'score': (np.float64(0.8871002119084647), 0.00029458280454515656), 'training_time': (np.float64(0.039824911378048514), 6.537514807487683e-05)}
Trial 0 with parameters={'loss': 'hinge', 'penalty': 'l2', 'alpha': 0.0001, 'learning_rate': 'optimal', 'eta0': 1e-08, 'batch_size': 143} and metrics={'score': (np.float64(0.8653842011548905), 0.0002737580740356988), 'training_time': (np.float64(0.03590925079383221), 6.536118543762476e-05)}
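
As a follow-up (this snippet is not part of the original tutorial, just one possible next step), a parameterization chosen from the frontier can be plugged back into scikit-learn to train a final model, reusing the same minibatch training pattern as in Step 5. Here we arbitrarily take the first frontier point:

# A minimal sketch, assuming we simply pick the first point on the Pareto frontier
best_parameters, best_metrics, best_trial_index, best_arm_name = frontier[0]

final_clf = sklearn.linear_model.SGDClassifier(
    loss=best_parameters["loss"],
    penalty=best_parameters["penalty"],
    alpha=best_parameters["alpha"],
    learning_rate=best_parameters["learning_rate"],
    eta0=best_parameters["eta0"],
)

# Reuse the minibatch training loop with the chosen batch size
final_batch_size = int(best_parameters["batch_size"])
for i in range(len(train_x) // final_batch_size):
    start_idx = i * final_batch_size
    end_idx = (i + 1) * final_batch_size
    final_clf.partial_fit(
        train_x[start_idx:end_idx], train_y[start_idx:end_idx], classes=classes
    )

print(f"Final model from trial {best_trial_index}: score={final_clf.score(valid_x, valid_y):.3f}")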

Step 7: Compute Analyses

Ax can also produce a number of analyses to help interpret the results of the experiment via client.compute_analyses. Users can manually select which analyses to run, or can allow Ax to select which would be most relevant. In this case Ax selects the following:

  • Parallel Coordinates Plot shows which parameterizations were evaluated and what metric values were observed -- this is useful for getting a high level overview of how thoroughly the search space was explored and which regions tend to produce which outcomes
  • Sensitivity Analysis Plot shows which parameters have the largest effect on the objective using Sobol indices
  • Slice Plot shows how the model predicts a single parameter affects the objective, along with a confidence interval
  • Contour Plot shows how the model predicts a pair of parameters affects the objective as a 2D surface
  • Summary lists all trials generated along with their parameterizations, observations, and miscellaneous metadata
  • Cross Validation helps to visualize how well the surrogate model is able to predict out of sample points
# display=True instructs Ax to sort then render the resulting analyses
cards = client.compute_analyses(display=True)

Modeled score vs. training_time

This plot displays the effects of each arm on the two selected metrics. It is useful for understanding the trade-off between the two metrics and for visualizing the Pareto frontier.


Summary for Experiment

High-level summary of the Trials in this Experiment

| trial_index | arm_name | trial_status | generation_node | score | training_time | loss | penalty | alpha | learning_rate | eta0 | batch_size |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | baseline | COMPLETED | nan | 0.863889 | 0.035197 | hinge | l2 | 0.0001 | optimal | 1e-08 | 143 |
| 1 | 1_0 | COMPLETED | CenterOfSearchSpace | 0.913889 | 0.079596 | squared_hinge | l2 | 0.001 | invscaling | 0.0001 | 252 |
| 2 | 2_0 | COMPLETED | Sobol | 0.844444 | 0.280332 | squared_hinge | l1 | 0.0290907 | constant | 0.0139472 | 93 |
| 3 | 3_0 | COMPLETED | Sobol | 0.483333 | 0.082041 | modified_huber | l2 | 2.20269e-05 | invscaling | 2.81959e-07 | 264 |
| 4 | 4_0 | COMPLETED | Sobol | 0.072222 | 0.041192 | log_loss | elasticnet | 39.4408 | adaptive | 0.000274625 | 441 |
| 5 | 5_0 | COMPLETED | MBM | 0.863889 | 0.267351 | squared_hinge | l2 | 0.000713196 | optimal | 1.10298e-05 | 156 |
| 6 | 6_0 | COMPLETED | MBM | 0.9 | 0.134938 | hinge | l2 | 0.000159468 | constant | 0.000336628 | 258 |
| 7 | 7_0 | COMPLETED | MBM | 0.913889 | 0.105615 | squared_hinge | l1 | 2.18643e-08 | invscaling | 0.000129681 | 314 |
| 8 | 8_0 | EARLY_STOPPED | MBM | 0.8 | 0.004833 | perceptron | l2 | 1e-08 | invscaling | 0.000829665 | 423 |
| 9 | 9_0 | EARLY_STOPPED | MBM | 0.808333 | 0.004853 | perceptron | l2 | 2.52671e-08 | invscaling | 0.000114444 | 429 |
| 10 | 10_0 | COMPLETED | MBM | 0.886111 | 0.039145 | hinge | l1 | 0.0142668 | invscaling | 1 | 500 |
| 11 | 11_0 | EARLY_STOPPED | MBM | 0.352778 | 0.005144 | squared_hinge | l2 | 1e-08 | invscaling | 8.40595e-07 | 500 |
| 12 | 12_0 | EARLY_STOPPED | MBM | 0.294444 | 0.005046 | hinge | l2 | 1e-08 | invscaling | 1e-08 | 500 |
| 13 | 13_0 | EARLY_STOPPED | MBM | 0.141667 | 0.00513 | squared_hinge | l2 | 1e-08 | invscaling | 2.07693e-07 | 500 |
| 14 | 14_0 | EARLY_STOPPED | MBM | 0.691667 | 0.005079 | squared_hinge | l2 | 1e-08 | invscaling | 1.83289e-07 | 500 |
| 15 | 15_0 | EARLY_STOPPED | MBM | 0.416667 | 0.005164 | squared_hinge | l2 | 1e-08 | invscaling | 6.438e-08 | 500 |
| 16 | 16_0 | EARLY_STOPPED | MBM | 0.294444 | 0.005217 | squared_hinge | l2 | 1e-08 | invscaling | 8.98904e-07 | 500 |
| 17 | 17_0 | COMPLETED | MBM | 0.925 | 0.272629 | hinge | l2 | 1e-08 | invscaling | 1 | 191 |
| 18 | 18_0 | EARLY_STOPPED | MBM | 0.838889 | 0.004684 | perceptron | l2 | 1e-08 | invscaling | 1.41204e-05 | 381 |
| 19 | 19_0 | COMPLETED | MBM | 0.927778 | 0.04554 | perceptron | l2 | 1e-08 | invscaling | 2.07859e-08 | 500 |
| 20 | 20_0 | EARLY_STOPPED | MBM | 0.486111 | 0.005084 | squared_hinge | l2 | 1e-08 | invscaling | 1.53641e-07 | 500 |

score by progression

The progression plot tracks the evolution of each metric over the course of the experiment. This visualization is typically used to monitor the improvement of metrics over Trial iterations, but can also be useful in informing decisions about early stopping for Trials.


training_time by progression

The progression plot tracks the evolution of each metric over the course of the experiment. This visualization is typically used to monitor the improvement of metrics over Trial iterations, but can also be useful in informing decisions about early stopping for Trials.


Sensitivity Analysis for score

Understand how each parameter affects score according to a second-order sensitivity analysis.


Sensitivity Analysis for training_time

Understand how each parameter affects training_time according to a second-order sensitivity analysis.


alpha vs. score

The slice plot provides a one-dimensional view of predicted outcomes for score as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.


alpha, batch_size vs. score

The contour plot visualizes the predicted outcomes for score across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


alpha, eta0 vs. score

The contour plot visualizes the predicted outcomes for score across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


batch_size vs. score

The slice plot provides a one-dimensional view of predicted outcomes for score as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.


eta0, batch_size vs. score

The contour plot visualizes the predicted outcomes for score across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


batch_size vs. training_time

The slice plot provides a one-dimensional view of predicted outcomes for training_time as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.


alpha, eta0 vs. training_time

The contour plot visualizes the predicted outcomes for training_time across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


alpha, batch_size vs. training_time

The contour plot visualizes the predicted outcomes for training_time across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


alpha vs. training_time

The slice plot provides a one-dimensional view of predicted outcomes for training_time as a function of a single parameter, while keeping all other parameters fixed at their status_quo value (or mean value if status_quo is unavailable). This visualization helps in understanding the sensitivity and impact of changes in the selected parameter on the predicted metric outcomes.


eta0, batch_size vs. training_time

The contour plot visualizes the predicted outcomes for training_time across a two-dimensional parameter space, with other parameters held fixed at their status_quo value (or mean value if status_quo is unavailable). This plot helps in identifying regions of optimal performance and understanding how changes in the selected parameters influence the predicted outcomes. Contour lines represent levels of constant predicted values, providing insights into the gradient and potential optima within the parameter space.


Cross Validation for score

The cross-validation plot displays the model fit for each metric in the experiment. It employs a leave-one-out approach, where the model is trained on all data except one sample, which is used for validation. The plot shows the predicted outcome for the validation set on the y-axis against its actual value on the x-axis. Points that align closely with the dotted diagonal line indicate a strong model fit, signifying accurate predictions. Additionally, the plot includes 95% confidence intervals that provide insight into the noise in observations and the uncertainty in model predictions. A horizontal, flat line of predictions indicates that the model has not picked up on sufficient signal in the data, and instead is just predicting the mean.


Cross Validation for training_time

The cross-validation plot displays the model fit for each metric in the experiment. It employs a leave-one-out approach, where the model is trained on all data except one sample, which is used for validation. The plot shows the predicted outcome for the validation set on the y-axis against its actual value on the x-axis. Points that align closely with the dotted diagonal line indicate a strong model fit, signifying accurate predictions. Additionally, the plot includes 95% confidence intervals that provide insight into the noise in observations and the uncertainty in model predictions. A horizontal, flat line of predictions indicates that the model has not picked up on sufficient signal in the data, and instead is just predicting the mean.


Conclusion

This tutorial demonstrates Ax's ability to solve AutoML tasks in a resource-efficient manner. We configured a complex optimization which captures the nuanced goals of the experiment and utilized early stopping to save resources by killing training runs unlikely to produce optimal results.

While this tutorial shows how to use Ax for HPO on an SGDClassifier, the same techniques can be used for many different AutoML tasks such as feature selection, architecture search, and more.