This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset trained using SGD with momentum.
import torch
import numpy as np
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN
init_notebook_plotting()
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
First, we need to load the MNIST data and partition it into training, validation, and test sets.
Note: this will download the dataset if necessary.
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)
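As a quick sanity check, you can print the sizes of the three splits (a sketch; the exact sizes depend on how `load_mnist` partitions the 60,000 MNIST training images):

print(len(train_loader.dataset), len(valid_loader.dataset), len(test_loader.dataset))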
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The evaluation function takes in a parameterization (a set of parameter values), computes the classification accuracy, and reports it back to Ax; Ax accepts either a bare value or a dictionary mapping each metric name (here, 'accuracy') to a tuple of its mean and standard error.
def train_evaluate(parameterization):
    # Train a fresh CNN with the given hyperparameters...
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    # ...and return its classification accuracy on the held-out validation set.
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )
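Here `evaluate` returns the accuracy as a bare value, which Ax accepts directly. If you would rather report an explicit uncertainty, the evaluation function can instead return a dictionary mapping metric names to (mean, SEM) tuples; a minimal sketch, where the SEM of 0.0 is an illustrative assumption:

def train_evaluate_with_sem(parameterization):
    net = train(net=CNN(), train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    accuracy = evaluate(net=net, data_loader=valid_loader, dtype=dtype, device=device)
    # An SEM of 0.0 tells Ax to treat the observation as noiseless; pass None
    # instead to let Ax infer the noise level from the data.
    return {"accuracy": (accuracy, 0.0)}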
Here we define the search space, setting bounds on the learning rate and momentum and placing the learning rate on a log scale, and then run the full optimization loop.
best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)
[INFO 08-09 10:32:05] ax.service.utils.dispatch: Using Bayesian Optimization generation strategy. Iterations after 5 will take longer to generate due to model-fitting.
[INFO 08-09 10:32:05] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 08-09 10:32:05] ax.service.managed_loop: Running optimization trial 1...
[INFO 08-09 10:32:40] ax.service.managed_loop: Running optimization trial 2...
[INFO 08-09 10:33:03] ax.service.managed_loop: Running optimization trial 3...
[INFO 08-09 10:33:26] ax.service.managed_loop: Running optimization trial 4...
[INFO 08-09 10:33:50] ax.service.managed_loop: Running optimization trial 5...
[INFO 08-09 10:34:13] ax.service.managed_loop: Running optimization trial 6...
[INFO 08-09 10:34:41] ax.service.managed_loop: Running optimization trial 7...
[INFO 08-09 10:35:11] ax.service.managed_loop: Running optimization trial 8...
[INFO 08-09 10:35:48] ax.service.managed_loop: Running optimization trial 9...
[INFO 08-09 10:36:26] ax.service.managed_loop: Running optimization trial 10...
[INFO 08-09 10:37:04] ax.service.managed_loop: Running optimization trial 11...
[INFO 08-09 10:37:40] ax.service.managed_loop: Running optimization trial 12...
[INFO 08-09 10:38:25] ax.service.managed_loop: Running optimization trial 13...
[INFO 08-09 10:39:11] ax.service.managed_loop: Running optimization trial 14...
[INFO 08-09 10:39:54] ax.service.managed_loop: Running optimization trial 15...
[INFO 08-09 10:40:38] ax.service.managed_loop: Running optimization trial 16...
[INFO 08-09 10:41:35] ax.service.managed_loop: Running optimization trial 17...
[INFO 08-09 10:42:21] ax.service.managed_loop: Running optimization trial 18...
[INFO 08-09 10:43:07] ax.service.managed_loop: Running optimization trial 19...
[INFO 08-09 10:43:50] ax.service.managed_loop: Running optimization trial 20...
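By default the managed loop runs 20 trials, the first 5 of which are quasi-random initialization points and the rest of which are generated by Bayesian optimization. If you want a different budget, `optimize` also accepts a `total_trials` argument; for example:

best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
    total_trials=30,  # run 30 trials instead of the default 20
)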
We can introspect the optimal parameters and their outcomes:
best_parameters
{'lr': 0.0006307873441197164, 'momentum': 0.39064412336820153}
means, covariances = values
means, covariances
({'accuracy': 0.9366666468920226}, {'accuracy': {'accuracy': 6.466514231853397e-09}})
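The second element holds the model's predicted covariance for the metric; its square root gives an uncertainty (standard error) on the reported mean:

import math
sem = math.sqrt(covariances['accuracy']['accuracy'])
print(f"accuracy: {means['accuracy']:.4f} ± {sem:.2e}")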
The contour plot below shows classification accuracy as a function of the two hyperparameters. The black squares mark the points we actually evaluated; notice how they cluster in the optimal region.
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))
Next, we show how the best accuracy found so far improves as the optimization identifies better hyperparameters.
# `optimization_trace_single_method` expects a 2-D array of means, because it is designed to
# average means from multiple optimization runs, so we wrap our best-objectives array in another array.
best_objectives = np.array([[trial.objective_mean * 100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
Finally, we train a CNN with the best hyperparameters on the combined training and validation sets, and evaluate it on the held-out test set. Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the validation set during optimization.
data = experiment.fetch_data()
df = data.df
# Select the arm whose observed mean accuracy was highest.
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm
Arm(name='18_0', parameters={'lr': 0.0006307873441197164, 'momentum': 0.39064412336820153})
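The arm with the highest observed mean will usually match the `best_parameters` recommendation returned by `optimize`, as it does here, but the two can be derived differently (Ax's recommendation may be based on model predictions rather than raw observations), so it is worth comparing them:

# Compare Ax's recommendation against the arm with the highest raw observed mean.
print("recommended:", best_parameters)
print("best observed:", best_arm.parameters)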
# Combine the training and validation sets for the final training run.
# (`train_loader.dataset` is the `Subset` used for training; its `.dataset`
# attribute is the underlying MNIST dataset.)
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset.dataset,
    valid_loader.dataset.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader,
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")
Classification Accuracy (test set): 97.8%