This tutorial walks through using Ax to tune two hyperparameters, learning rate and momentum, for a PyTorch CNN trained on the MNIST dataset using SGD with momentum.
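To make the roles of the two hyperparameters concrete, here is a minimal sketch of the SGD-with-momentum update on a toy one-dimensional problem. The function name `sgd_momentum_minimize` and the toy objective are illustrative assumptions, not part of Ax or the tutorial's CNN utilities; the update follows the convention PyTorch's SGD uses (velocity accumulates gradients, and the learning rate scales the step).

```python
def sgd_momentum_minimize(grad_fn, x0, lr, momentum, steps):
    """Minimize a 1-D function with SGD plus momentum (hypothetical helper)."""
    x, velocity = x0, 0.0
    for _ in range(steps):
        grad = grad_fn(x)
        # Momentum controls how much of the previous velocity is carried over.
        velocity = momentum * velocity + grad
        # The learning rate scales the step taken along the velocity.
        x = x - lr * velocity
    return x


# Toy objective f(x) = x**2, whose gradient is 2*x and whose minimum is at x = 0.
final_x = sgd_momentum_minimize(
    grad_fn=lambda x: 2.0 * x, x0=5.0, lr=0.1, momentum=0.9, steps=300
)
print(final_x)  # a value very close to 0
```

A learning rate that is too small makes progress slow, while a momentum close to 1 can cause long oscillations, which is why both are worth tuning together.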
import torch
import numpy as np
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN
init_notebook_plotting()
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
First, we need to load the MNIST data and partition it into training, validation, and test sets.
Note: this will download the dataset if necessary.
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)
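In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The evaluation function takes in a parameterization (a set of parameter values), trains the network with it, and computes the classification accuracy on the validation set. In general, an Ax evaluation function can return a dictionary mapping each metric name (here 'accuracy') to a tuple of mean and standard error; the function below returns the accuracy alone, which is used as the value of the objective.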
def train_evaluate(parameterization):
    # Train a fresh CNN with the proposed hyperparameters.
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    # Score the trained network on the validation set.
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )
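For reference, the dictionary form of an evaluation result, a metric name mapped to a (mean, standard error) tuple, can be built from repeated measurements as in the sketch below. The helper name `summarize_accuracy` and the accuracy values are hypothetical and serve only to illustrate the format.

```python
import math
import statistics


def summarize_accuracy(accuracies):
    """Return {'accuracy': (mean, standard error)} for repeated runs (hypothetical helper)."""
    mean = statistics.mean(accuracies)
    # Standard error of the mean: sample standard deviation divided by sqrt(n).
    sem = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return {"accuracy": (mean, sem)}


# Made-up validation accuracies from four training runs with the same parameters.
result = summarize_accuracy([0.91, 0.93, 0.92, 0.94])
print(result)
```

Reporting a standard error lets the optimizer distinguish a genuinely better configuration from run-to-run noise.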
Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale.
best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)
[INFO 09-15 02:39:25] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter lr. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-15 02:39:25] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter momentum. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-15 02:39:25] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 09-15 02:39:25] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
[INFO 09-15 02:39:25] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 09-15 02:39:25] ax.service.managed_loop: Running optimization trial 1...
/opt/hostedtoolcache/Python/3.7.11/x64/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
[INFO 09-15 02:39:32] ax.service.managed_loop: Running optimization trial 2...
[INFO 09-15 02:39:39] ax.service.managed_loop: Running optimization trial 3...
[INFO 09-15 02:39:45] ax.service.managed_loop: Running optimization trial 4...
[INFO 09-15 02:39:52] ax.service.managed_loop: Running optimization trial 5...
[INFO 09-15 02:39:58] ax.service.managed_loop: Running optimization trial 6...
[INFO 09-15 02:40:05] ax.service.managed_loop: Running optimization trial 7...
[INFO 09-15 02:40:12] ax.service.managed_loop: Running optimization trial 8...
[INFO 09-15 02:40:19] ax.service.managed_loop: Running optimization trial 9...
[INFO 09-15 02:40:26] ax.service.managed_loop: Running optimization trial 10...
[INFO 09-15 02:40:34] ax.service.managed_loop: Running optimization trial 11...
[INFO 09-15 02:40:41] ax.service.managed_loop: Running optimization trial 12...
[INFO 09-15 02:40:48] ax.service.managed_loop: Running optimization trial 13...
[INFO 09-15 02:40:55] ax.service.managed_loop: Running optimization trial 14...
[INFO 09-15 02:41:03] ax.service.managed_loop: Running optimization trial 15...
[INFO 09-15 02:41:11] ax.service.managed_loop: Running optimization trial 16...
[INFO 09-15 02:41:18] ax.service.managed_loop: Running optimization trial 17...
[INFO 09-15 02:41:26] ax.service.managed_loop: Running optimization trial 18...
[INFO 09-15 02:41:34] ax.service.managed_loop: Running optimization trial 19...
[INFO 09-15 02:41:42] ax.service.managed_loop: Running optimization trial 20...
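To see why the learning rate is searched on a log scale, the sketch below compares sampling uniformly in log space with sampling uniformly in linear space over the same bounds, [1e-6, 0.4]. This is only an illustration of the idea; the helper `sample_log_uniform` is hypothetical and is not how Ax generates candidates internally.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose base-10 logarithm is uniform between log10(low) and log10(high)."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


rng = random.Random(0)
log_samples = [sample_log_uniform(1e-6, 0.4, rng) for _ in range(10_000)]
linear_samples = [rng.uniform(1e-6, 0.4) for _ in range(10_000)]

# Fraction of samples that fall below 1e-3 under each scheme.
frac_small_log = sum(s < 1e-3 for s in log_samples) / len(log_samples)
frac_small_linear = sum(s < 1e-3 for s in linear_samples) / len(linear_samples)
print(f"log scale: {frac_small_log:.3f}, linear scale: {frac_small_linear:.4f}")
```

On a linear scale almost no samples land among the small learning rates, whereas on a log scale each order of magnitude receives roughly equal attention.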
We can inspect the optimal parameters and their outcomes:
best_parameters
{'lr': 0.0005084473167264391, 'momentum': 0.16242224757301063}
means, covariances = values
means, covariances
({'accuracy': 0.9311611763190136}, {'accuracy': {'accuracy': 0.00012250083678216547}})
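The second element of `values` holds a covariance, so the diagonal entry for 'accuracy' is a variance. As a rough sketch, assuming that entry can be read as the variance of the accuracy estimate, its square root gives a standard error and a normal-approximation interval follows from it. The numbers are copied from the output above.

```python
import math

# Values copied from the output shown above.
means = {"accuracy": 0.9311611763190136}
covariances = {"accuracy": {"accuracy": 0.00012250083678216547}}

mean = means["accuracy"]
# Assumption: the diagonal covariance entry is the variance of the estimate.
std_error = math.sqrt(covariances["accuracy"]["accuracy"])

# Approximate 95% interval under a normal approximation.
lower, upper = mean - 1.96 * std_error, mean + 1.96 * std_error
print(f"accuracy = {mean:.4f} +/- {std_error:.4f} (95% interval {lower:.4f} to {upper:.4f})")
```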
The contour plot shows classification accuracy as a function of the two hyperparameters.
The black squares mark the points that were actually evaluated. Notice how they cluster in the optimal region.
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))
Next, we plot how the best accuracy found so far improves as the optimization identifies better hyperparameters.
# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means from multiple
# optimization runs, so we wrap our best objectives array in another array.
best_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
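The trace is built from a running maximum: at each trial it keeps the best accuracy seen up to that point, which is what `np.maximum.accumulate` computes along the trial axis. The same idea in plain Python, using made-up per-trial accuracies:

```python
from itertools import accumulate

# Hypothetical per-trial accuracies in percent, in the order the trials ran.
trial_accuracies = [91.2, 88.5, 93.1, 92.0, 93.4]

# Running maximum: the best accuracy observed up to and including each trial.
best_so_far = list(accumulate(trial_accuracies, max))
print(best_so_far)  # [91.2, 91.2, 93.1, 93.1, 93.4]
```

Because the running maximum never decreases, the plotted curve is flat wherever a trial fails to beat the incumbent.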
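Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the validation set during optimization. Below, we look up the best arm, retrain the network with its parameters on the combined training and validation sets, and evaluate the result on the test set.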
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm
Arm(name='8_0', parameters={'lr': 0.0005084473167264391, 'momentum': 0.16242224757301063})
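The lookup above selects the rows whose mean equals the maximum and takes the first one, so ties are broken by row order. A plain-Python sketch of the same selection logic, using hypothetical rows in place of the experiment's data frame:

```python
# Hypothetical (arm name, mean accuracy) rows standing in for `data.df`.
rows = [
    {"arm_name": "0_0", "mean": 0.89},
    {"arm_name": "1_0", "mean": 0.93},
    {"arm_name": "2_0", "mean": 0.93},
    {"arm_name": "3_0", "mean": 0.91},
]

best_mean = max(row["mean"] for row in rows)
# Take the first row that attains the maximum, mirroring `.values[0]`.
best_arm_name = next(row["arm_name"] for row in rows if row["mean"] == best_mean)
print(best_arm_name)  # 1_0
```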
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset.dataset,
    valid_loader.dataset.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader,
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")
Classification Accuracy (test set): 97.7%
Total runtime of script: 2 minutes, 56.91 seconds.