# Tune a CNN on MNIST

This tutorial walks through using Ax to tune two hyperparameters, the learning rate and the momentum, of a PyTorch CNN trained on the MNIST dataset with stochastic gradient descent (SGD) with momentum.
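
To make the two hyperparameters concrete, here is a minimal sketch of the SGD-with-momentum update on a one-dimensional quadratic. It assumes the update form described in the PyTorch `torch.optim.SGD` documentation with no dampening, Nesterov momentum, or weight decay, and it assumes the tutorial's `train` helper uses that optimizer. The function `sgd_momentum` is a hypothetical illustration, not part of Ax.

```python
def sgd_momentum(grad_fn, x0, lr, momentum, steps):
    """Hypothetical helper: minimize a 1-D function with momentum SGD.

    Uses v <- momentum * v + grad, x <- x - lr * v, assumed here to match
    torch.optim.SGD without dampening, Nesterov momentum, or weight decay.
    """
    x, velocity = x0, 0.0
    for _ in range(steps):
        g = grad_fn(x)
        velocity = momentum * velocity + g  # decaying running sum of gradients
        x = x - lr * velocity               # step against the accumulated direction
    return x


def grad(x):
    """Gradient of f(x) = x**2, whose minimum is at x = 0."""
    return 2.0 * x


well_tuned = sgd_momentum(grad, x0=5.0, lr=0.1, momentum=0.9, steps=200)
too_small = sgd_momentum(grad, x0=5.0, lr=1e-6, momentum=0.0, steps=200)
too_large = sgd_momentum(grad, x0=5.0, lr=1.1, momentum=0.0, steps=200)

print(well_tuned)  # very close to 0
print(too_small)   # barely moved from 5.0
print(too_large)   # diverged to a huge magnitude
```

A learning rate that is too small barely moves, one that is too large diverges, and momentum changes how a given learning rate behaves; this interplay is why the two are tuned jointly.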


```python
import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN

init_notebook_plotting()
```


```python
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

## 1. Load MNIST data
First, we need to load the MNIST data and partition it into training, validation, and test sets.

Note: this will download the dataset if necessary.

```python
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)
```
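
`load_mnist` handles the partitioning internally. As a rough sketch of what a random train/validation split involves, one can shuffle the indices of MNIST's 60,000 training images and cut them into two disjoint lists, while a test set is held out separately. The helper name `split_indices` and the 80/20 ratio are hypothetical choices for illustration, not Ax's code or defaults.

```python
import random


def split_indices(n, train_pct, seed=0):
    """Hypothetical helper: shuffle range(n) and cut it into train/validation lists."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    idx = list(range(n))
    rng.shuffle(idx)
    n_train = round(train_pct * n)
    return idx[:n_train], idx[n_train:]


train_idx, valid_idx = split_indices(60_000, train_pct=0.8)
print(len(train_idx), len(valid_idx))  # 48000 12000
```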

## 2. Define function to optimize
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (a dictionary of parameter values), trains a fresh CNN with those values, and returns its classification accuracy on the validation set as a single float. Ax also accepts evaluation functions that return a dictionary mapping each metric name (here, 'accuracy') to a tuple of mean and standard error.

```python
def train_evaluate(parameterization):
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )
```
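
When experimenting with the optimization loop, it can help to swap the expensive CNN training for a cheap stand-in. The sketch below is a hypothetical `toy_train_evaluate` with an invented accuracy surface. It uses the dictionary return format, mapping the metric name to a (mean, standard error) tuple; a standard error of 0.0 here says the toy measurement has no noise.

```python
import math


def toy_train_evaluate(parameterization):
    """Hypothetical stand-in for train_evaluate with a made-up accuracy surface."""
    lr = parameterization["lr"]
    momentum = parameterization["momentum"]
    # Invented surface: peaks at 1.0 when lr = 1e-3 and momentum = 0.5,
    # and falls off with distance in log10(lr) and in momentum.
    score = math.exp(-((math.log10(lr) + 3.0) ** 2) - (momentum - 0.5) ** 2)
    # Metric name -> (mean, standard error); 0.0 means a noiseless measurement.
    return {"accuracy": (score, 0.0)}


print(toy_train_evaluate({"lr": 1e-3, "momentum": 0.5}))  # {'accuracy': (1.0, 0.0)}
```

Passing `evaluation_function=toy_train_evaluate` to the `optimize` call in the next step should run the same loop in seconds rather than minutes.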

## 3. Run the optimization loop
Here we set bounds on the learning rate (from 1e-6 to 0.4) and the momentum (from 0 to 1), and tell Ax to search the learning rate on a log scale, since its range spans several orders of magnitude.
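
Roughly speaking, `"log_scale": True` asks Ax to treat the learning rate's range in log space rather than linear space. The pure-Python sketch below uses plain random sampling, not Ax's actual candidate-generation machinery, to compare log-uniform and uniform draws over the same bounds. With a linear scale almost no draws land below 1e-3, even though that region covers more than half of the orders of magnitude in the range.

```python
import math
import random

LR_LOW, LR_HIGH = 1e-6, 0.4  # learning-rate bounds used in the optimize call
N = 10_000


def sample_log_uniform(rng, low, high):
    """Draw uniformly in log10 space, so every decade is equally likely."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


rng = random.Random(0)
log_draws = [sample_log_uniform(rng, LR_LOW, LR_HIGH) for _ in range(N)]
lin_draws = [rng.uniform(LR_LOW, LR_HIGH) for _ in range(N)]

# Share of draws below 1e-3, i.e. in the lowest 3 of the ~5.6 decades.
frac_log = sum(x < 1e-3 for x in log_draws) / N
frac_lin = sum(x < 1e-3 for x in lin_draws) / N
print(frac_log)  # expected value about 3 / 5.6, roughly 0.54
print(frac_lin)  # expected value about 0.0025
```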

```python
best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)
```

```
[INFO 08-09 10:32:05] ax.service.utils.dispatch: Using Bayesian Optimization generation strategy. Iterations after 5 will take longer to generate due to model-fitting.
[INFO 08-09 10:32:05] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 08-09 10:32:05] ax.service.managed_loop: Running optimization trial 1...
[INFO 08-09 10:32:40] ax.service.managed_loop: Running optimization trial 2...
[INFO 08-09 10:33:03] ax.service.managed_loop: Running optimization trial 3...
[INFO 08-09 10:33:26] ax.service.managed_loop: Running optimization trial 4...
[INFO 08-09 10:33:50] ax.service.managed_loop: Running optimization trial 5...
[INFO 08-09 10:34:13] ax.service.managed_loop: Running optimization trial 6...
[INFO 08-09 10:34:41] ax.service.managed_loop: Running optimization trial 7...
[INFO 08-09 10:35:11] ax.service.managed_loop: Running optimization trial 8...
[INFO 08-09 10:35:48] ax.service.managed_loop: Running optimization trial 9...
[INFO 08-09 10:36:26] ax.service.
```

We can inspect the best parameters and their outcomes:

```python
best_parameters
```

```
{'lr': 0.0006307873441197164, 'momentum': 0.39064412336820153}
```

```python
means, covariances = values
means, covariances
```

```
({'accuracy': 0.9366666468920226},
 {'accuracy': {'accuracy': 6.466514231853397e-09}})
```
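
The second dictionary holds covariances between metrics; with a single metric, its one entry `covariances['accuracy']['accuracy']` is a variance. Assuming it describes the uncertainty of the reported mean, its square root gives a standard error, which this sketch computes from the numbers printed above.

```python
import math

# Values copied from the output above.
means = {"accuracy": 0.9366666468920226}
covariances = {"accuracy": {"accuracy": 6.466514231853397e-09}}

# The diagonal entry is a variance; its square root is a standard error.
sem = math.sqrt(covariances["accuracy"]["accuracy"])
print(f"accuracy ≈ {means['accuracy']:.4f} ± {sem:.1e}")  # accuracy ≈ 0.9367 ± 8.0e-05
```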

## 4. Plot response surface

The contour plot shows classification accuracy as a function of the two hyperparameters.

The black squares mark the points we actually evaluated; notice how they cluster in the high-accuracy region.

```python
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))
```

## 5. Plot best objective as function of the iteration

This plot shows the best accuracy found so far improving as the optimization identifies better hyperparameters.

```python
# `optimization_trace_single_method` expects a 2-D array of means because it is designed to
# average traces from multiple optimization runs, so we wrap our single run's per-trial
# objectives in an outer list.
objectives = np.array([[trial.objective_mean * 100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(objectives, axis=1),  # best accuracy seen so far, per iteration
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
```
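
`np.maximum.accumulate` turns per-trial accuracies into a best-so-far trace, which is why the plotted curve never decreases. A tiny example with made-up numbers:

```python
import numpy as np

# Made-up per-trial accuracies (%) for a single run, wrapped as a 1 x 4 array.
per_trial = np.array([[90.0, 85.0, 93.0, 92.0]])

# Running maximum along axis 1 (the trial axis): best accuracy seen so far.
best_so_far = np.maximum.accumulate(per_trial, axis=1)
print(best_so_far)  # [[90. 90. 93. 93.]]
```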

## 6. Train CNN with best hyperparameters and evaluate on test set
Note that the resulting accuracy on the test set might not exactly match the maximum accuracy achieved on the validation set during optimization.

```python
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm
```

```
Arm(name='18_0', parameters={'lr': 0.0006307873441197164, 'momentum': 0.39064412336820153})
```
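
The selection above keeps the arm with the highest observed mean and, if several arms tie, the first one in the data frame. The same rule in plain Python, on invented rows that mimic only the `arm_name` and `mean` columns of `data.df`:

```python
# Invented rows for illustration, shaped like (arm_name, mean) records.
rows = [
    {"arm_name": "0_0", "mean": 0.91},
    {"arm_name": "1_0", "mean": 0.94},
    {"arm_name": "2_0", "mean": 0.94},
    {"arm_name": "3_0", "mean": 0.88},
]

# max() returns the first row with the largest mean, so ties go to the earlier arm.
best_row = max(rows, key=lambda r: r["mean"])
print(best_row["arm_name"])  # 1_0
```

With pandas, `df.loc[df['mean'].idxmax(), 'arm_name']` expresses the same first-maximum rule in one line.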

```python
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset.dataset,
    valid_loader.dataset.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
```
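
`ConcatDataset` chains its datasets end to end, so its length is the sum of theirs. One subtlety: partitions created with `random_split` are `Subset` objects whose `.dataset` attribute is the shared parent, so concatenating the parents rather than the subsets repeats the parent's examples. Whether and how that applies here depends on how `load_mnist` builds its loaders, which this sketch does not assume; it only uses a toy 10-example `TensorDataset`.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset, random_split

# Toy parent dataset of 10 examples, split 8/2 like a train/validation partition.
parent = TensorDataset(torch.arange(10))
train_part, valid_part = random_split(
    parent, [8, 2], generator=torch.Generator().manual_seed(0)
)

# Concatenating the two partitions covers each example exactly once.
combined = ConcatDataset([train_part, valid_part])
print(len(combined))  # 10

# Both subsets share one parent, so concatenating `.dataset` repeats it.
doubled = ConcatDataset([train_part.dataset, valid_part.dataset])
print(len(doubled))  # 20
```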

```python
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader,
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)
```

```python
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")
```

```
Classification Accuracy (test set): 97.8%
```
