Ax for AutoML
Automated machine learning (AutoML) encompasses a large class of problems related to automating time-consuming and labor-intensive aspects of developing ML models. Adaptive experimentation is a natural fit for solving many AutoML tasks, which are often iterative in nature and can involve many expensive trial evaluations.
In this tutorial we will use Ax for hyperparameter optimization (HPO), a common AutoML task in which a model's hyperparameters are adjusted to improve model performance. Hyperparameters are parameters that are set before a model is trained or fitted, rather than learned from data. Traditionally, ML engineers find good hyperparameters through a combination of domain knowledge, intuition, and manual experimentation that compares many models with different hyperparameter configurations. As the number of hyperparameters grows and models become more expensive to train and evaluate, sample-efficient approaches to experimentation like Bayesian optimization become increasingly valuable.
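To make the shape of an HPO loop concrete, here is a minimal random-search sketch over a made-up objective. The function toy_validation_score is a hypothetical stand-in for "train a model and return its validation accuracy", and its peak location is arbitrary. Random search is used only as a simple baseline; it is not the Bayesian optimization that Ax performs.

```python
import math
import random


def toy_validation_score(alpha: float, eta0: float) -> float:
    """Hypothetical stand-in for training a model and scoring it.

    Peaks at alpha=1e-4, eta0=1e-2 (values chosen arbitrarily for illustration).
    """
    return 1.0 / (1.0 + (math.log10(alpha) + 4) ** 2 + (math.log10(eta0) + 2) ** 2)


def random_search(num_trials: int, seed: int = 0) -> tuple[dict[str, float], float]:
    rng = random.Random(seed)
    best_params: dict[str, float] = {}
    best_score = -math.inf
    for _ in range(num_trials):
        # Sample each hyperparameter log-uniformly within its bounds.
        params = {
            "alpha": 10 ** rng.uniform(-8, 2),
            "eta0": 10 ** rng.uniform(-8, 0),
        }
        score = toy_validation_score(**params)  # the expensive step in real HPO
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


best_params, best_score = random_search(num_trials=30)
print(best_params, best_score)
```

Every trial here is chosen blindly; a sample-efficient method instead uses the results of earlier trials to decide where to evaluate next.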
In this tutorial we will train an SGDClassifier from the popular scikit-learn library to recognize handwritten digits and tune the model's hyperparameters to improve its performance. You can read more about the SGDClassifier model in the scikit-learn documentation example on which this tutorial is largely based. This tutorial incorporates many advanced features in Ax to demonstrate how they can be applied to complex engineering challenges in a real-world setting.
Learning Objectives
- Understand how Ax can be used for HPO tasks
- Use complex optimization configurations like multiple objectives and outcome constraints to achieve nuanced real-world goals
- Use early stopping to save experimentation resources
- Analyze the results of the optimization
Prerequisites
- Familiarity with scikit-learn and basic machine learning concepts
- Understanding of adaptive experimentation and Bayesian optimization
- Familiarity with the Ask-tell Optimization of Python Functions with early stopping tutorial
Step 1: Import Necessary Modules
First, ensure you have all the necessary imports:
import time
import matplotlib.pyplot as plt
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
from ax.preview.api.client import Client
from ax.preview.api.configs import (
    ChoiceParameterConfig,
    ExperimentConfig,
    GenerationMethod,
    GenerationStrategyConfig,
    ParameterScaling,
    ParameterType,
    RangeParameterConfig,
)
from pyre_extensions import assert_is_instance
Step 1.1: Understanding the baseline performance of SGDClassifier
Before we begin HPO, let's understand the task and the performance of SGDClassifier with its default hyperparameters. The following code is largely adapted from the example on scikit-learn's website.
# Load the digits dataset and display the first 4 images to illustrate the task
digits = sklearn.datasets.load_digits()
classes = list(set(digits.target))
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Training: %i" % label)
# Instantiate a SGDClassifier with default hyperparameters
clf = sklearn.linear_model.SGDClassifier()
# Split the data into a training set and a validation set
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
    digits.data, digits.target, test_size=0.20, random_state=0
)
# Train the classifier on the training set in 10 batches, timing the training
batch_size = len(train_x) // 10
start_time = time.time()
for i in range(10):
    start_idx = i * batch_size
    end_idx = (i + 1) * batch_size
    # Use partial_fit to update the model on the current batch
    clf.partial_fit(
        train_x[start_idx:end_idx], train_y[start_idx:end_idx], classes=classes
    )
training_time = time.time() - start_time
# Evaluate the classifier on the validation set
score = clf.score(valid_x, valid_y)
score, training_time
The model performs well, but let's see if we can improve performance by tuning the hyperparameters.
Step 2: Initialize the Client
As always, the first step in running our adaptive experiment with Ax is to create an instance of the Client to manage the state of the experiment.
client = Client()
Step 3: Configure the Experiment
The Client expects a series of Configs which define how the experiment will be run. We'll set this up the same way as we did in our previous tutorial.
Our current task is to tune the hyperparameters of scikit-learn's SGDClassifier. These parameters control aspects of the model's training process, and configuring them can have dramatic effects on the model's ability to correctly classify inputs. A full list of this model's hyperparameters and their valid values is available in the library's documentation. In this tutorial we will tune the following hyperparameters:
- loss: The loss function to be used
- penalty: The penalty (aka regularization term) to be used
- learning_rate: The learning rate schedule
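- alpha: Constant that multiplies the regularization term. The higher the value, the stronger the regularization.
- eta0: The initial learning rate used by the "constant", "invscaling", and "adaptive" schedules. It is ignored by the "optimal" schedule, which is why scikit-learn's default value is 0.0.
- batch_size: Not an SGDClassifier argument, but a parameter of our own training loop that controls how many examples are passed to each partial_fit call (which this tutorial refers to as an epoch). Each trial makes a single pass over the training set, so a smaller batch size translates into more epochs and vice versa; examples left over after the last full batch are skipped.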
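Because the training loop used later in this tutorial computes the number of partial_fit calls as len(train_x) // batch_size, the batch size also determines how many training examples a trial actually sees. The helper below, training_schedule (a name introduced here for illustration), reproduces that arithmetic. With the 80/20 split of the 1,797-image digits dataset, the training set has 1,437 examples.

```python
def training_schedule(num_train: int, batch_size: int) -> tuple[int, int]:
    """Return (number of partial_fit calls, number of examples actually used).

    Mirrors the tutorial's loop: num_epochs = num_train // batch_size, and any
    examples remaining after the last full batch are never seen.
    """
    num_calls = num_train // batch_size
    return num_calls, num_calls * batch_size


NUM_TRAIN = 1437  # 80% of the 1,797 digits images
for bs in (5, 143, 500):
    calls, used = training_schedule(NUM_TRAIN, bs)
    print(f"batch_size={bs}: {calls} partial_fit calls, {used}/{NUM_TRAIN} examples used")
```

At the upper bound of 500, only 1,000 of the 1,437 training examples are used, so large batch sizes also shrink the effective training set.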
You will notice that some hyperparameters are continuous ranges, some are discrete ranges, and some are categorical choices; Ax can handle all of these parameter types via its RangeParameterConfig and ChoiceParameterConfig classes.
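The alpha and eta0 parameters below use ParameterScaling.LOG because their plausible values span many orders of magnitude. The sketch below, a hypothetical helper rather than Ax's internal transform, shows what log scaling means: a point in the unit interval is mapped uniformly in log10 space, so each decade between the bounds receives an equal share of the search space. It also shows why a lower bound of exactly 0 cannot be used.

```python
import math


def from_unit_log_scale(u: float, lower: float, upper: float) -> float:
    """Map u in [0, 1] to [lower, upper] uniformly in log10 space."""
    if not 0.0 <= u <= 1.0:
        raise ValueError("u must lie in [0, 1]")
    if lower <= 0:
        raise ValueError("log scaling requires a strictly positive lower bound")
    log_lo, log_hi = math.log10(lower), math.log10(upper)
    return 10 ** (log_lo + u * (log_hi - log_lo))


# alpha has bounds (1e-8, 100): ten decades, so each 0.1 step in u is one decade.
for u in (0.0, 0.5, 0.8, 1.0):
    print(u, from_unit_log_scale(u, 1e-8, 100))
```

On a linear scale, 99% of the alpha range would lie above 1, leaving almost no room to explore the small values where good regularization strengths often sit.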
# Create an experiment configuration
experiment_config = ExperimentConfig(
    name="SGDClassifier_hpo",
    parameters=[
        ChoiceParameterConfig(
            name="loss",
            parameter_type=ParameterType.STRING,
            values=[
                "hinge",
                "log_loss",
                "squared_hinge",
                "modified_huber",
                "perceptron",
            ],
            is_ordered=False,
        ),
        ChoiceParameterConfig(
            name="penalty",
            parameter_type=ParameterType.STRING,
            values=["l1", "l2", "elasticnet"],
            is_ordered=False,
        ),
        ChoiceParameterConfig(
            name="learning_rate",
            parameter_type=ParameterType.STRING,
            values=["constant", "optimal", "invscaling", "adaptive"],
            is_ordered=False,
        ),
        RangeParameterConfig(
            name="alpha",
            bounds=(1e-8, 100),
            parameter_type=ParameterType.FLOAT,
            scaling=ParameterScaling.LOG,
        ),
        RangeParameterConfig(
            name="eta0",
            bounds=(1e-8, 1),
            parameter_type=ParameterType.FLOAT,
            scaling=ParameterScaling.LOG,
        ),
        RangeParameterConfig(
            name="batch_size",
            bounds=(5, 500),
            parameter_type=ParameterType.INT,
        ),
    ],
    # The following arguments are optional
    description="Optimization of SGDClassifier for digits dataset",
    owner="developer",
)
# Apply the experiment configuration to the client
client.configure_experiment(experiment_config=experiment_config)
client.configure_generation_strategy(
    GenerationStrategyConfig(method=GenerationMethod.FAST)
)
Step 4: Configure Optimization
Now we must set up the optimization objective in the Client, where objective is a string that specifies which metric we would like to optimize and the direction (higher or lower) that is considered optimal. Multiple metrics are separated by commas, and a leading - marks a metric to be minimized.
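The objective string used in the configuration below, "score, -training_time", follows this convention. The hypothetical parser below illustrates the convention only; it is not how Ax parses objectives internally, and it makes no attempt to cover every form Ax may accept.

```python
def parse_objective(objective: str) -> list[tuple[str, bool]]:
    """Split an objective string into (metric_name, minimize) pairs."""
    parsed = []
    for term in objective.split(","):
        term = term.strip()
        if not term:
            raise ValueError(f"empty objective term in {objective!r}")
        minimize = term.startswith("-")
        name = term[1:].strip() if minimize else term
        parsed.append((name, minimize))
    return parsed


print(parse_objective("score, -training_time"))
# [('score', False), ('training_time', True)]
```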
In our example we want to consider both the performance and the computational cost implications of hyperparameter modifications. scikit-learn classifiers provide a method called score that reports the mean accuracy of the model on given data, and in our optimization we should seek to maximize this value. We also want to keep training time low: model training can be very expensive, especially for large models, so the time spent training represents a significant cost.
Let's configure Ax to maximize score while minimizing training time. We call this a multi-objective optimization, and rather than returning a single best parameterization we return a Pareto frontier of points which represent optimal tradeoffs between all metrics present. Multi-objective optimization is useful for competing metrics where a gain in one metric may represent a regression in the other.
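To make the Pareto frontier concrete, the sketch below computes it by brute force for our two objectives. A point is dominated if another point has at least as high a score and at most as long a training time, and is strictly better on at least one of them. The sample values are copied from a few rows of the trial summary shown at the end of this tutorial. Ax's get_pareto_frontier may rely on model predictions rather than raw observations, so this illustrates the definition and is not a reimplementation of Ax.

```python
def pareto_frontier(points: dict[str, tuple[float, float]]) -> list[str]:
    """Return names of non-dominated points for (maximize score, minimize time)."""

    def dominates(a: tuple[float, float], b: tuple[float, float]) -> bool:
        score_a, time_a = a
        score_b, time_b = b
        no_worse = score_a >= score_b and time_a <= time_b
        strictly_better = score_a > score_b or time_a < time_b
        return no_worse and strictly_better

    frontier = [
        name
        for name, p in points.items()
        if not any(dominates(q, p) for other, q in points.items() if other != name)
    ]
    # Order from fastest to slowest training time.
    return sorted(frontier, key=lambda name: points[name][1])


# (score, training_time in seconds) for a few trials from the summary table
observations = {
    "baseline": (0.880556, 0.03495),
    "1_0": (0.65, 0.053883),
    "2_0": (0.95, 0.274918),
    "3_0": (0.894444, 0.132055),
    "4_0": (0.905556, 0.07959),
    "8_0": (0.908333, 0.441074),
    "10_0": (0.930556, 0.276684),
}
print(pareto_frontier(observations))
# ['baseline', '4_0', '2_0']
```

Each remaining point buys extra accuracy with extra training time; none of them beats another on both objectives at once.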
In these settings we can also specify outcome constraints, which indicate that if a metric falls outside a specified threshold we are not interested in the result, regardless of the wins observed in any other metric. For a concrete example, imagine Ax finding a parameterization that trains in no time at all but has a score no better than if the model were guessing at random.
For this toy example let's configure Ax to maximize score and minimize training time, but avoid any hyperparameter configurations that result in a mean accuracy score of less than 85% or a training time greater than 1 second.
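Outcome constraints act as a feasibility filter on top of the objectives. The hypothetical helper below checks observed metrics against constraint strings of the form "metric >= bound" or "metric <= bound", matching the two constraints configured below. It is an illustration, not Ax's parser: it handles only these two operators and expects spaces around them.

```python
import operator

_OPS = {">=": operator.ge, "<=": operator.le}


def satisfies(metrics: dict[str, float], constraints: list[str]) -> bool:
    """Return True if every 'metric OP bound' constraint holds for metrics."""
    for constraint in constraints:
        name, op, bound = constraint.split()
        if op not in _OPS:
            raise ValueError(f"unsupported operator {op!r} in {constraint!r}")
        if not _OPS[op](metrics[name], float(bound)):
            return False
    return True


constraints = ["score >= 0.85", "training_time <= 1"]
print(satisfies({"score": 0.880556, "training_time": 0.03495}, constraints))  # True
print(satisfies({"score": 0.65, "training_time": 0.053883}, constraints))  # False
```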
client.configure_optimization(
    objective="score, -training_time",
    outcome_constraints=["score >= 0.85", "training_time <= 1"],
)
Step 5: Run Trials with early stopping
Before we begin our Bayesian optimization loop, we can attach the data we collected from training SGDClassifier with default hyperparameters. This gives our experiment a head start by providing a data point to our surrogate model. Because these are the default settings provided by scikit-learn, they are likely to be reasonably good and should give the optimization a promising start. In general, attaching any existing data to an experiment is advantageous, because it gives the surrogate model more information to learn from.
trial_index = client.attach_baseline(
    parameters={
        "loss": clf.loss,
        "penalty": clf.penalty,
        "alpha": clf.alpha,
        "learning_rate": clf.learning_rate,
        # The default eta0 is 0.0, which lies outside the log-scaled bounds
        # (1e-8, 1), so add a small value to land on the lower bound
        "eta0": clf.eta0 + 1e-8,
        "batch_size": batch_size,
    }
)
client.complete_trial(
    trial_index=trial_index,
    raw_data={"score": score, "training_time": training_time},
)
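The + 1e-8 adjustment above exists because scikit-learn's default eta0 of 0.0 falls outside the configured bounds (1e-8, 1) and has no position on a log scale. More generally, it can help to check a baseline against the search space before attaching it. The sketch below uses a hypothetical BOUNDS dictionary copied from the range parameters in the experiment configuration; the out_of_bounds helper is illustrative and not part of Ax's API.

```python
BOUNDS = {
    # Copied from the RangeParameterConfigs in the experiment configuration.
    "alpha": (1e-8, 100),
    "eta0": (1e-8, 1),
    "batch_size": (5, 500),
}


def out_of_bounds(parameters: dict[str, float]) -> list[str]:
    """Return the names of range parameters missing or outside BOUNDS."""
    problems = []
    for name, (lower, upper) in BOUNDS.items():
        value = parameters.get(name)
        if value is None or not lower <= value <= upper:
            problems.append(name)
    return problems


default_baseline = {"alpha": 0.0001, "eta0": 0.0, "batch_size": 143}
print(out_of_bounds(default_baseline))  # ['eta0']
print(out_of_bounds({**default_baseline, "eta0": 0.0 + 1e-8}))  # []
```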
After attaching the initial trial, we begin the experimentation loop by writing a for loop that executes our full experimentation budget of 30 trials. In each iteration we ask Ax for the next trials (in this case just one), then instantiate an SGDClassifier with the suggested hyperparameters. We then split the data into training and validation sets. Next, an inner loop performs minibatch training: we divide the training set into batches of the suggested size and call partial_fit on one batch at a time, which this tutorial refers to as an epoch. After each epoch we report the validation score and the elapsed training time.
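The inner loop reports intermediate results after every epoch, using the number of training examples consumed so far (end_idx) as the progression. The sketch below, a hypothetical generator that mirrors the slicing arithmetic of that loop, lists the progression values a trial would report and which report is the final one, sent with complete_trial rather than attach_data.

```python
from collections.abc import Iterator


def progression_schedule(
    num_train: int, batch_size: int
) -> Iterator[tuple[int, int, int, bool]]:
    """Yield (epoch, start_idx, end_idx, is_final) for each partial_fit call."""
    num_epochs = num_train // batch_size
    for i in range(num_epochs):
        start_idx = i * batch_size
        end_idx = (i + 1) * batch_size
        yield i, start_idx, end_idx, i == num_epochs - 1


for epoch, start, end, is_final in progression_schedule(num_train=1437, batch_size=500):
    call = "complete_trial" if is_final else "attach_data"
    print(f"epoch {epoch}: train on [{start}:{end}], report progression={end} via {call}")
```

Because progression counts examples rather than epochs, trials with different batch sizes report on a common axis, which is what makes their intermediate results comparable.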
Because training machine learning models is expensive, we will use Ax's early stopping functionality to stop trials that are unlikely to produce optimal results before they complete. After attaching each epoch's data, we ask the Client whether we should stop the trial; if it advises us to do so, we mark the trial as early stopped and exit the training loop. Early stopping lets us save the compute that would otherwise be spent finishing unpromising trials.
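Ax's early stopping rules typically work by comparing a trial's intermediate results against those of other trials at the same progression. As a rough intuition, the sketch below implements a simplified percentile-based check, similar in spirit to Ax's PercentileEarlyStoppingStrategy: stop a trial whose score falls below a chosen percentile of its peers' scores. The function names, the 50th-percentile default, and the minimum number of peers are assumptions for illustration and do not reflect the Client's actual configuration.

```python
import math


def percentile(values: list[float], pct: float) -> float:
    """Linear-interpolation percentile of values (pct in [0, 100])."""
    ordered = sorted(values)
    k = (len(ordered) - 1) * pct / 100
    lo, hi = math.floor(k), math.ceil(k)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (k - lo)


def should_stop(
    current: float,
    peers_at_same_progression: list[float],
    pct: float = 50.0,
    min_peers: int = 3,
    minimize: bool = False,
) -> bool:
    """Stop if current is worse than the pct-th percentile of peer values."""
    if len(peers_at_same_progression) < min_peers:
        return False  # not enough evidence to compare against yet
    threshold = percentile(peers_at_same_progression, pct)
    return current > threshold if minimize else current < threshold


# Scores of earlier trials at the same progression (hypothetical numbers).
peers = [0.82, 0.88, 0.91, 0.86]
print(should_stop(0.62, peers))  # True: well below the peers' median of 0.87
print(should_stop(0.90, peers))  # False
```

The min_peers guard matters early in an experiment: with few completed trials, a percentile threshold is too noisy to justify stopping anything.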
for _ in range(30):
    trials = client.get_next_trials(maximum_trials=1)
    for trial_index, parameters in trials.items():
        clf = sklearn.linear_model.SGDClassifier(
            loss=parameters["loss"],
            penalty=parameters["penalty"],
            alpha=parameters["alpha"],
            learning_rate=parameters["learning_rate"],
            eta0=parameters["eta0"],
        )
        train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
            digits.data,
            digits.target,
            test_size=0.20,
        )
        batch_size = assert_is_instance(parameters["batch_size"], int)
        num_epochs = len(train_x) // batch_size
        start_time = time.time()
        for i in range(num_epochs):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            # Use partial_fit to update the model on the current batch
            clf.partial_fit(
                train_x[start_idx:end_idx], train_y[start_idx:end_idx], classes=classes
            )
            raw_data = {
                "score": clf.score(valid_x, valid_y),
                "training_time": time.time() - start_time,
            }
            # On the final epoch call complete_trial and break, else call attach_data
            if i == num_epochs - 1:
                client.complete_trial(
                    trial_index=trial_index,
                    raw_data=raw_data,
                    # Progression = number of training examples seen so far
                    progression=end_idx,
                )
                break
            client.attach_data(
                trial_index=trial_index,
                raw_data=raw_data,
                progression=end_idx,
            )
            # If the trial is underperforming, stop it
            if client.should_stop_trial_early(trial_index=trial_index):
                client.mark_trial_early_stopped(trial_index=trial_index)
                break
Step 6: Analyze Results
After running trials, you can analyze the results. Most commonly this means extracting the parameterization from the best-performing trial you conducted.
Since we are optimizing multiple objectives, rather than a single best point we want the Pareto frontier: the set of points that represent optimal tradeoffs between maximizing score and minimizing training time.
frontier = client.get_pareto_frontier()
# Frontier is a list of tuples, where each tuple contains the parameters, the metric readings, the trial index, and the arm name for a point on the Pareto frontier
for parameters, metrics, trial_index, arm_name in frontier:
    print(f"Trial {trial_index} with {parameters=} and {metrics=}\n")
Step 7: Compute Analyses
Ax can also produce a number of analyses to help interpret the results of the experiment via client.compute_analyses. Users can manually select which analyses to run, or can allow Ax to select which would be most relevant. In this case Ax selects the following:
- Scatter Plot shows a plane with each objective on its own axis and a point for each observation. In multi-objective optimizations like ours it also draws a line through the Pareto frontier, indicating which points represent optimal tradeoffs between our objectives.
- Interaction Analysis Plot shows which parameters have the largest effect on the function and plots the most important parameters as 1- or 2-dimensional surfaces
- Summary lists all trials generated along with their parameterizations, observations, and miscellaneous metadata
client.compute_analyses(display=True) # By default Ax will display the AnalysisCards produced by compute_analyses
Plot: score by progression (observe how the metric changes as each trial progresses)
Plot: training_time by progression (observe how the metric changes as each trial progresses)
Summary for SGDClassifier_hpo: high-level summary of the Trials in this Experiment
| | trial_index | arm_name | trial_status | generation_method | generation_node | score | training_time | loss | penalty | alpha | learning_rate | eta0 | batch_size |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | baseline | COMPLETED | nan | nan | 0.880556 | 0.03495 | hinge | l2 | 0.0001 | optimal | 1e-08 | 143 |
1 | 1 | 1_0 | COMPLETED | Sobol | Sobol | 0.65 | 0.053883 | modified_huber | l1 | 1.77034 | invscaling | 3.346e-07 | 452 |
2 | 2 | 2_0 | COMPLETED | Sobol | Sobol | 0.95 | 0.274918 | squared_hinge | elasticnet | 0.000123815 | constant | 0.000109691 | 124 |
3 | 3 | 3_0 | COMPLETED | Sobol | Sobol | 0.894444 | 0.132055 | perceptron | l2 | 0.204403 | optimal | 3.00607e-05 | 207 |
4 | 4 | 4_0 | COMPLETED | Sobol | Sobol | 0.905556 | 0.07959 | log_loss | l1 | 1.07207e-08 | optimal | 0.0916614 | 353 |
5 | 5 | 5_0 | EARLY_STOPPED | BoTorch | MBM | 0.783333 | 0.283206 | perceptron | elasticnet | 4.27665e-05 | constant | 0.000202005 | 124 |
6 | 6 | 6_0 | EARLY_STOPPED | BoTorch | MBM | 0.758333 | 0.1206 | perceptron | elasticnet | 4.63926e-05 | constant | 0.000208576 | 122 |
7 | 7 | 7_0 | EARLY_STOPPED | BoTorch | MBM | 0.752778 | 0.038344 | perceptron | elasticnet | 5.69491e-05 | constant | 0.000168132 | 130 |
8 | 8 | 8_0 | COMPLETED | BoTorch | MBM | 0.908333 | 0.441074 | perceptron | elasticnet | 5.09063e-05 | constant | 0.000184407 | 125 |
9 | 9 | 9_0 | COMPLETED | BoTorch | MBM | 0.869444 | 0.32119 | squared_hinge | elasticnet | 9.71316e-06 | optimal | 0.00212544 | 179 |
10 | 10 | 10_0 | COMPLETED | BoTorch | MBM | 0.930556 | 0.276684 | squared_hinge | l2 | 0.000117479 | constant | 9.26604e-05 | 188 |
11 | 11 | 11_0 | EARLY_STOPPED | BoTorch | MBM | 0.619444 | 0.004961 | log_loss | l1 | 0.205074 | constant | 3.99409e-06 | 205 |
12 | 12 | 12_0 | EARLY_STOPPED | BoTorch | MBM | 0.211111 | 0.004955 | log_loss | l1 | 0.548273 | constant | 3.82095e-06 | 198 |
13 | 13 | 13_0 | EARLY_STOPPED | BoTorch | MBM | 0.444444 | 0.005328 | log_loss | l1 | 0.557321 | constant | 3.08567e-06 | 192 |
14 | 14 | 14_0 | EARLY_STOPPED | BoTorch | MBM | 0.241667 | 0.004966 | log_loss | l1 | 0.317142 | constant | 3.57574e-06 | 197 |
15 | 15 | 15_0 | EARLY_STOPPED | BoTorch | MBM | 0.505556 | 0.004962 | log_loss | l1 | 0.285497 | constant | 4.03447e-06 | 203 |
16 | 16 | 16_0 | EARLY_STOPPED | BoTorch | MBM | 0.275 | 0.004982 | log_loss | l1 | 0.0813153 | constant | 4.5994e-06 | 216 |
17 | 17 | 17_0 | EARLY_STOPPED | BoTorch | MBM | 0.486111 | 0.004973 | log_loss | l1 | 0.121932 | constant | 4.6834e-06 | 206 |
18 | 18 | 18_0 | EARLY_STOPPED | BoTorch | MBM | 0.172222 | 0.004872 | log_loss | l1 | 0.485148 | constant | 3.11631e-06 | 196 |
19 | 19 | 19_0 | EARLY_STOPPED | BoTorch | MBM | 0.188889 | 0.004965 | log_loss | l1 | 0.434221 | constant | 3.4096e-06 | 203 |
20 | 20 | 20_0 | EARLY_STOPPED | BoTorch | MBM | 0.430556 | 0.004902 | log_loss | l1 | 0.385857 | constant | 3.3245e-06 | 197 |
21 | 21 | 21_0 | EARLY_STOPPED | BoTorch | MBM | 0.238889 | 0.005004 | log_loss | l1 | 0.170539 | constant | 5.39843e-06 | 211 |
22 | 22 | 22_0 | EARLY_STOPPED | BoTorch | MBM | 0.297222 | 0.00491 | log_loss | l1 | 0.372314 | constant | 3.07541e-06 | 194 |
23 | 23 | 23_0 | EARLY_STOPPED | BoTorch | MBM | 0.397222 | 0.004968 | log_loss | l1 | 0.366029 | constant | 2.9322e-06 | 199 |
24 | 24 | 24_0 | EARLY_STOPPED | BoTorch | MBM | 0.433333 | 0.005012 | log_loss | l1 | 0.12292 | constant | 4.37456e-06 | 208 |
25 | 25 | 25_0 | EARLY_STOPPED | BoTorch | MBM | 0.411111 | 0.005052 | log_loss | l1 | 0.238168 | constant | 4.52233e-06 | 206 |
26 | 26 | 26_0 | EARLY_STOPPED | BoTorch | MBM | 0.152778 | 0.005043 | log_loss | l1 | 0.165718 | constant | 4.00178e-06 | 203 |
27 | 27 | 27_0 | EARLY_STOPPED | BoTorch | MBM | 0.397222 | 0.005028 | log_loss | l1 | 0.32417 | constant | 3.32965e-06 | 204 |
28 | 28 | 28_0 | EARLY_STOPPED | BoTorch | MBM | 0.552778 | 0.004958 | log_loss | l1 | 0.105632 | constant | 5.01087e-06 | 207 |
29 | 29 | 29_0 | EARLY_STOPPED | BoTorch | MBM | 0.313889 | 0.005044 | log_loss | l1 | 0.205153 | constant | 3.97733e-06 | 202 |
30 | 30 | 30_0 | EARLY_STOPPED | BoTorch | MBM | 0.413889 | 0.0049 | log_loss | l1 | 0.386347 | constant | 3.40869e-06 | 196 |
Plot: Cross Validation for score (out-of-sample predictions using leave-one-out CV)
Plot: Cross Validation for training_time (out-of-sample predictions using leave-one-out CV)
Conclusion
This tutorial demonstrates Ax's ability to solve AutoML tasks in a resource-efficient manner. We configured a complex optimization that captures the nuanced goals of the experiment and used early stopping to save resources by stopping training runs unlikely to produce optimal results.
While this tutorial shows how to use Ax for HPO on an SGDClassifier, the same techniques can be used for many different AutoML tasks such as feature selection, architecture search, and more.