{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tune a CNN on MNIST\n", "\n", "This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset, trained using SGD with momentum.\n" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 04-30 12:37:08] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.\n" ] } ], "source": [ "import torch\n", "import numpy as np\n", "\n", "from ax.plot.contour import plot_contour\n", "from ax.plot.trace import optimization_trace_single_method\n", "from ax.service.managed_loop import optimize\n", "from ax.utils.notebook.plotting import render, init_notebook_plotting\n", "from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate\n", "\n", "init_notebook_plotting()" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "dtype = torch.float\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load MNIST data\n", "First, we load the MNIST data and partition it into training, validation, and test sets.\n", "\n", "Note: this will download the dataset if it is not already present." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0.00/9.91M [00:00, ?B/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/raw/train-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "9.92MB [00:02, 4.45MB/s] \n", "32.8KB [00:00, 197KB/s] \n", "0.00B [00:00, ?B/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/raw/train-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/raw/t10k-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "1.65MB [00:03, 542KB/s] \n", "0.00B [00:00, ?B/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/raw/t10k-labels-idx1-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "8.19KB [00:00, 17.1KB/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Processing...\n", "Done!\n" ] } ], "source": [ "train_loader, valid_loader, test_loader = load_mnist()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Define function to optimize\n", "In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (a set of parameter values), computes the classification accuracy, and returns a dictionary mapping the metric name ('accuracy') to a tuple of the mean and standard error." ] },
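{ "cell_type": "markdown", "metadata": {}, "source": [ "Before wiring up the real training-based evaluation below, here is a minimal sketch of the interface Ax expects. `dummy_evaluate` and its hard-coded numbers are purely illustrative and are not used anywhere else in this tutorial." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Purely illustrative sketch (not used below): an Ax evaluation function\n", "# maps a parameterization dict to {metric_name: (mean, SEM)}.\n", "def dummy_evaluate(parameterization):\n", "    lr = parameterization.get(\"lr\", 1e-3)  # hypothetical parameter lookup\n", "    # ... train and evaluate a model using `lr` here ...\n", "    return {\"accuracy\": (0.95, 0.01)}  # hard-coded stand-in values" ] },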
{ "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def train_evaluate(parameterization):\n", "    net = train(train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)\n", "    return evaluate(\n", "        net=net,\n", "        data_loader=valid_loader,\n", "        dtype=dtype,\n", "        device=device,\n", "    )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Run the optimization loop\n", "Here we set the bounds on the learning rate and momentum, and specify that the learning rate should be searched on a log scale. As the logs below show, the loop runs 20 trials in total: the first 5 are initialization trials, after which candidates are generated by Bayesian optimization and therefore take longer to produce." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 04-30 12:37:40] ax.service.utils.dispatch: Using Bayesian Optimization generation strategy. Iterations after 5 will take longer to generate due to model-fitting.\n", "[INFO 04-30 12:37:40] ax.service.managed_loop: Started full optimization with 20 steps.\n", "[INFO 04-30 12:37:40] ax.service.managed_loop: Running optimization trial 1...\n", "[INFO 04-30 12:38:49] ax.service.managed_loop: Running optimization trial 2...\n", "[INFO 04-30 12:39:57] ax.service.managed_loop: Running optimization trial 3...\n", "[INFO 04-30 12:41:05] ax.service.managed_loop: Running optimization trial 4...\n", "[INFO 04-30 12:42:13] ax.service.managed_loop: Running optimization trial 5...\n", "[INFO 04-30 12:43:20] ax.service.managed_loop: Running optimization trial 6...\n", "[INFO 04-30 12:44:42] ax.service.managed_loop: Running optimization trial 7...\n", "[INFO 04-30 12:46:06] ax.service.managed_loop: Running optimization trial 8...\n", "[INFO 04-30 12:47:34] ax.service.managed_loop: Running optimization trial 9...\n", "[INFO 04-30 12:49:00] ax.service.managed_loop: Running optimization trial 10...\n", "[INFO 04-30 12:50:24] ax.service.managed_loop: Running optimization trial 11...\n", "[INFO 04-30 12:52:01] ax.service.managed_loop: Running optimization trial 12...\n", "[INFO 04-30 12:53:32] ax.service.managed_loop: Running optimization trial 13...\n", "[INFO 04-30 12:55:08] ax.service.managed_loop: Running optimization trial 14...\n", "[INFO 04-30 12:56:45] ax.service.managed_loop: Running optimization trial 15...\n", "[INFO 04-30 12:58:17] ax.service.managed_loop: Running optimization trial 16...\n", "[INFO 04-30 12:59:55] ax.service.managed_loop: Running optimization trial 17...\n", "[INFO 04-30 13:01:43] ax.service.managed_loop: Running optimization trial 18...\n", "[INFO 04-30 13:03:48] ax.service.managed_loop: Running optimization trial 19...\n", "[INFO 04-30 13:06:00] ax.service.managed_loop: Running optimization trial 20...\n" ] } ], "source": [ "best_parameters, values, experiment, model = optimize(\n", "    parameters=[\n", "        {\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-6, 0.4], \"log_scale\": True},\n", "        {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n", "    ],\n", "    evaluation_function=train_evaluate,\n", "    objective_name='accuracy',\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can introspect the optimal parameters and their outcomes:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'lr': 0.0029176399675537317, 'momentum': 3.0347402313065844e-16}" ] }, "execution_count": 6, "metadata": { "bento_obj_id": "139719018969416" }, "output_type": "execute_result" } ], "source": [ "best_parameters" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'accuracy': 0.968833362542745},\n", " {'accuracy': {'accuracy': 1.3653840299223108e-08}})" ] }, "execution_count": 8, "metadata": { "bento_obj_id": "139717543046792" }, "output_type": "execute_result" } ], "source": [ "means, covariances = values\n", "means, covariances" ] },
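{ "cell_type": "markdown", "metadata": {}, "source": [ "The second element holds the model's predictive covariance for the objective at the best point, so a rough 95% interval is mean ± 1.96 * sqrt(variance). The cell below is a minimal sketch, assuming the nested-dict structure shown in the output above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "# Rough 95% interval from the model's predictive variance at the best point.\n", "mean = means['accuracy']\n", "std = math.sqrt(covariances['accuracy']['accuracy'])\n", "print(f\"accuracy: {mean:.4f} +/- {1.96 * std:.6f} (rough 95% interval)\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As an aside, the search space is not limited to range parameters. The commented-out sketch below is illustrative only: it assumes this Ax version supports the 'choice' parameter type and a `total_trials` argument to `optimize`, and `batch_size` is a hypothetical extra hyperparameter that `train` may or may not consume." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch only -- not run in this tutorial.\n", "# best_parameters, values, experiment, model = optimize(\n", "#     parameters=[\n", "#         {\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-6, 0.4], \"log_scale\": True},\n", "#         {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n", "#         {\"name\": \"batch_size\", \"type\": \"choice\", \"values\": [32, 64, 128]},\n", "#     ],\n", "#     evaluation_function=train_evaluate,\n", "#     objective_name='accuracy',\n", "#     total_trials=30,\n", "# )" ] },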
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Plot response surface\n", "\n", "Contour plot showing classification accuracy as a function of the two hyperparameters.\n", "\n", "The black squares show points that we have actually run; notice how they are clustered in the optimal region." ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Plot best objective as function of the iteration\n", "\n", "Show the model accuracy improving as we identify better hyperparameters." ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# `optimization_trace_single_method` expects a 2-d array of means, because it averages means\n", "# from multiple optimization runs, so we wrap our best-objectives array in another array.\n", "best_objectives = np.array([[trial.objective_mean * 100 for trial in experiment.trials.values()]])\n", "best_objective_plot = optimization_trace_single_method(\n", "    y=np.maximum.accumulate(best_objectives, axis=1),\n", "    title=\"Model performance vs. # of iterations\",\n", "    ylabel=\"Classification Accuracy, %\",\n", ")\n", "render(best_objective_plot)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Train CNN with best hyperparameters and evaluate on test set\n", "Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the validation set during optimization." ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Arm(name='17_0', parameters={'lr': 0.0029176399675537317, 'momentum': 3.0347402313065844e-16})" ] }, "execution_count": 11, "metadata": { "bento_obj_id": "139717742884176" }, "output_type": "execute_result" } ], "source": [ "data = experiment.fetch_data()\n", "df = data.df\n", "best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]\n", "best_arm = experiment.arms_by_name[best_arm_name]\n", "best_arm" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "net = train(\n", "    train_loader=train_loader,\n", "    parameters=best_arm.parameters,\n", "    dtype=dtype,\n", "    device=device,\n", ")\n", "test_accuracy = evaluate(\n", "    net=net,\n", "    data_loader=test_loader,\n", "    dtype=dtype,\n", "    device=device,\n", ")" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Classification Accuracy (test set): 97.06%\n" ] } ], "source": [ "print(f\"Classification Accuracy (test set): {round(test_accuracy * 100, 2)}%\")" ] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }