This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) of a PyTorch CNN that is trained on the MNIST dataset using SGD with momentum.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.notebook.plotting import init_notebook_plotting, render
from ax.utils.tutorials.cnn_utils import evaluate, load_mnist, train
from torch._tensor import Tensor
from torch.utils.data import DataLoader
# Set up Ax's notebook plotting support. Ax's own log output for this call
# reports that it injects the Plotly library into the notebook cell, so the
# cell should not be overwritten or deleted.
# NOTE(review): presumably this must run before any `render(...)` call — confirm.
init_notebook_plotting()
[INFO 09-23 20:33:01] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 09-23 20:33:01] ax.utils.notebook.plotting: Please see (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering) if visualizations are not rendering.