This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) of a PyTorch CNN that is trained on the MNIST dataset using SGD with momentum.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.notebook.plotting import init_notebook_plotting, render
from ax.utils.tutorials.cnn_utils import evaluate, load_mnist, train
from torch._tensor import Tensor
from torch.utils.data import DataLoader
[INFO 09-23 20:33:01] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 09-23 20:33:01] ax.utils.notebook.plotting: Please see (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering) if visualizations are not rendering.