This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) of a PyTorch CNN trained on the MNIST dataset with SGD.
import torch
import numpy as np
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN
init_notebook_plotting()
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
First, we need to load the MNIST data and partition it into training, validation, and test sets.
Note: this will download the dataset if necessary.
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)
The dataset is downloaded and extracted to ./data/MNIST/raw.
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The evaluation function takes a parameterization (a set of parameter values), trains the network, computes the classification accuracy, and returns a dictionary mapping the metric name ('accuracy') to a tuple of (mean, standard error).
def train_evaluate(parameterization):
    # Train a fresh network with the given hyperparameters.
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    # Report classification accuracy on the held-out validation set.
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )
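Before launching the full optimization, it can be worth a quick smoke test of the evaluation function on a hand-picked parameterization (the values below are purely illustrative):

# Hypothetical sanity check: one training run with fixed hyperparameters.
train_evaluate({"lr": 1e-3, "momentum": 0.9})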
Here we run the optimization loop. We set bounds on the learning rate and momentum, and place the learning rate on a log scale since its range spans several orders of magnitude.
best_parameters, values, experiment, model = optimize(
parameters=[
{"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
{"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
],
evaluation_function=train_evaluate,
objective_name='accuracy',
)
Ax infers a float value type for both parameters and uses a Sobol+GPEI generation strategy: the first 5 trials are quasi-random Sobol points, and the remaining trials use Bayesian optimization (GPEI), for 20 trials in total. Trials after the fifth take longer to generate because a Gaussian process model is refit before each one.
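The managed loop runs 20 trials by default. If you want a different budget, `optimize` accepts a `total_trials` argument (a minimal sketch; the value below is illustrative):

best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
    total_trials=30,  # defaults to 20
)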
We can introspect the optimal parameters and their outcomes:
best_parameters
{'lr': 0.00032609607345790854, 'momentum': 0.5589304304482358}
means, covariances = values
means, covariances
({'accuracy': 0.9391661306353767}, {'accuracy': {'accuracy': 9.507152920732499e-09}})
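The covariance structure nests the variance of the accuracy estimate under the metric name twice; its square root recovers the standard error. A quick sketch:

import math

# Standard error of the mean accuracy = sqrt of the reported variance.
sem = math.sqrt(covariances['accuracy']['accuracy'])
print(f"accuracy: {means['accuracy']:.4f} ± {sem:.1e} (SEM)")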
The contour plot below shows classification accuracy as a function of the two hyperparameters. The black squares mark the points that were actually run; notice how they cluster in the optimal region.
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))
Next, we plot the best accuracy observed so far, which improves as the optimization identifies better hyperparameters.
# `optimization_trace_single_method` expects a 2-d array of means because it can average
# over multiple optimization runs, so we wrap our best-objectives array in another array.
best_objectives = np.array([[trial.objective_mean * 100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
y=np.maximum.accumulate(best_objectives, axis=1),
title="Model performance vs. # of iterations",
ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
Finally, we retrain the CNN on the combined training and validation sets using the best hyperparameters, and evaluate it on the held-out test set. We first look up the best-performing arm in the experiment data. Note that the resulting accuracy on the test set might not exactly match the maximum accuracy achieved on the validation set during optimization.
data = experiment.fetch_data()
df = data.df
# Select the arm with the highest observed mean accuracy.
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm
Arm(name='16_0', parameters={'lr': 0.00032609607345790854, 'momentum': 0.5589304304482358})
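In this run the best observed arm matches the `best_parameters` returned by `optimize`, but the two are not guaranteed to agree: `optimize` returns the model's best predicted parameterization, while the arm above is the best raw observation. A quick, purely illustrative check:

# True here, but not guaranteed in general.
print(best_arm.parameters == best_parameters)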
# Concatenate the train and validation subsets directly; going through
# `.dataset.dataset` would concatenate the full underlying training set
# with itself, duplicating every example.
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset,
    valid_loader.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
combined_train_valid_set,
batch_size=BATCH_SIZE,
shuffle=True,
)
net = train(
net=CNN(),
train_loader=combined_train_valid_loader,
parameters=best_arm.parameters,
dtype=dtype,
device=device,
)
test_accuracy = evaluate(
net=net,
data_loader=test_loader,
dtype=dtype,
device=device,
)
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")
Classification Accuracy (test set): 97.86%
Total runtime of script: 4 minutes, 32.83 seconds.