Ax integrates easily with different scheduling frameworks and distributed training frameworks. In this example, Ax-driven optimization is executed in a distributed fashion using RayTune.
RayTune is a scalable framework for hyperparameter tuning that provides many state-of-the-art hyperparameter tuning algorithms and seamlessly scales from laptop to distributed cluster with fault tolerance. RayTune leverages Ray's Actor API to provide asynchronous parallel and distributed execution.
Ray 'Actors' are a simple and clean abstraction for replicating your Python classes across multiple workers and nodes. Each hyperparameter evaluation is asynchronously executed on a separate Ray actor and reports intermediate training progress back to RayTune. Upon reporting, RayTune uses this information to perform actions such as early termination, re-prioritization, or checkpointing.
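The report-and-decide loop described above can be sketched with the standard library alone. This is not Ray's or RayTune's API: a thread pool stands in for Ray actors, and `toy_training_curve`, `run_trial`, `stop_if_weak`, and `run_all` are hypothetical names invented for this illustration. The sketch only shows the general idea of evaluating configurations in parallel, reporting an intermediate metric after each step, and terminating weak trials early.

```python
from concurrent.futures import ThreadPoolExecutor


def toy_training_curve(lr, step):
    """Hypothetical stand-in for real training: accuracy after `step` steps."""
    # Best possible accuracy peaks at lr = 0.1 and falls off linearly around it.
    quality = max(0.0, 1.0 - abs(lr - 0.1) * 5.0)
    # Accuracy approaches `quality` as the number of steps grows.
    return quality * (1.0 - 0.5 ** step)


def stop_if_weak(step, accuracy, grace=2, threshold=0.3):
    """Toy early-termination rule: after a grace period, stop low-accuracy trials."""
    return step >= grace and accuracy < threshold


def run_trial(lr, max_steps, should_stop):
    """Evaluate one configuration, reporting its accuracy after every step."""
    history = []
    for step in range(1, max_steps + 1):
        accuracy = toy_training_curve(lr, step)
        history.append(accuracy)
        # The reported value is what lets the scheduler terminate the trial early.
        if should_stop(step, accuracy):
            break
    return {"lr": lr, "steps": len(history), "accuracy": history[-1]}


def run_all(learning_rates, max_steps=5, workers=4):
    """Run every trial on its own worker thread (a stand-in for Ray actors)."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [
            pool.submit(run_trial, lr, max_steps, stop_if_weak)
            for lr in learning_rates
        ]
        # Collect results in submission order.
        return [future.result() for future in futures]


results = run_all([0.01, 0.1, 0.25])
best = max(results, key=lambda r: r["accuracy"])
```

In this toy setup the trial with `lr=0.25` is cut off once its grace period ends, while the other two run for all five steps. In the tutorial itself RayTune handles scheduling and termination, and Ax proposes the configurations to try.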
import logging

from ray import tune
from ray.tune import track
from ray.tune.suggest.ax import AxSearch

logger = logging.getLogger(tune.__name__)
logger.setLevel(level=logging.CRITICAL)  # Reduce the number of Ray warnings that are not relevant here.
import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.ax_client import AxClient
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import CNN, load_mnist, train, evaluate

init_notebook_plotting()
[INFO 06-10 18:33:25] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.