Ax integrates easily with scheduling and distributed training frameworks. In this example, Ax-driven optimization is executed in a distributed fashion using RayTune.
RayTune is a scalable hyperparameter tuning framework that provides many state-of-the-art tuning algorithms and scales seamlessly from a laptop to a distributed cluster, with fault tolerance. RayTune leverages Ray's Actor API to provide asynchronous parallel and distributed execution.
Ray 'Actors' are a simple and clean abstraction for replicating your Python classes across multiple workers and nodes. Each hyperparameter evaluation is executed asynchronously on a separate Ray actor and reports intermediate training progress back to RayTune. Upon reporting, RayTune uses this information to perform actions such as early termination, re-prioritization, or checkpointing.
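To make the Actor model concrete, here is a minimal sketch (the Counter class and this snippet are purely illustrative and not part of this tutorial's code): decorating a class with ray.remote places each instance in its own worker process, and method calls return futures that execute asynchronously.
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Counter:
    # Illustrative actor: its state lives in a dedicated worker process.
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value

counter = Counter.remote()  # instantiate the actor on a worker
futures = [counter.increment.remote() for _ in range(3)]  # asynchronous method calls
print(ray.get(futures))  # blocks until all results arrive: [1, 2, 3]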
import logging
from ray import tune
from ray.tune.suggest.ax import AxSearch
logger = logging.getLogger(tune.__name__)
logger.setLevel(level=logging.CRITICAL) # Reduce the number of Ray warnings that are not relevant here.
import torch
import numpy as np
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.ax_client import AxClient
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import CNN, load_mnist, train, evaluate
init_notebook_plotting()
[INFO 10-01 15:42:23] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
We specify enforce_sequential_optimization as False, because Ray runs many trials in parallel. With sequential optimization enforced, AxClient would expect the first few trials to be completed with data before generating more trials.
When high parallelism is not required, it is best to enforce sequential optimization, as it allows for achieving optimal results in fewer (but sequential) trials. In cases where parallelism is important, such as with distributed training using Ray, we choose to forego minimizing resource utilization and run more trials in parallel.
ax = AxClient(enforce_sequential_optimization=False)
[INFO 10-01 15:42:23] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
Here we set up the search space and specify the objective; refer to the Ax API tutorials for more detail.
ax.create_experiment(
    name="mnist_experiment",
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    objective_name="mean_accuracy",
)
[INFO 10-01 15:42:23] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter lr. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 10-01 15:42:23] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter momentum. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 10-01 15:42:23] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
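As an aside, create_experiment can also encode constraints on the search space via the parameter_constraints argument. A minimal sketch under assumed names (a fresh AxClient and hypothetical parameters x1 and x2; this is illustrative only and not part of the MNIST experiment):
constrained_ax = AxClient(enforce_sequential_optimization=False)
constrained_ax.create_experiment(
    name="constrained_experiment",  # hypothetical example, separate from mnist_experiment
    parameters=[
        {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
    ],
    objective_name="objective",
    # Linear parameter constraints are expressed as strings over parameter names.
    parameter_constraints=["x1 + x2 <= 1.0"],
)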
Since we use the Ax Service API here, we evaluate the parameterizations that Ax suggests using RayTune. The evaluation function follows its usual pattern: it takes in a parameterization and outputs an objective value. For details on evaluation functions, see Trial Evaluation.
def train_evaluate(parameterization):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, valid_loader, test_loader = load_mnist(data_path='~/.data')
    net = train(net=CNN(), train_loader=train_loader, parameters=parameterization, dtype=torch.float, device=device)
    # Report the objective value for this trial back to RayTune.
    tune.report(
        mean_accuracy=evaluate(
            net=net,
            data_loader=valid_loader,
            dtype=torch.float,
            device=device,
        )
    )
Execute the Ax optimization and trial evaluation in RayTune using the AxSearch algorithm:
tune.run(
    train_evaluate,
    num_samples=30,
    search_alg=AxSearch(ax_client=ax),  # Note that the `AxClient` is passed via the `ax_client` keyword argument.
    verbose=0,  # Set this level to 1 to see status updates and to 2 to also see trial results.
    # To use GPU, specify: resources_per_trial={"gpu": 1}.
)
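As the comment above indicates, trials request resources through resources_per_trial. A sketch of the GPU-backed invocation, shown for reference under the assumption that GPUs are available; it is an alternative to the call above, not an additional step:
tune.run(
    train_evaluate,
    num_samples=30,
    search_alg=AxSearch(ax_client=ax),
    verbose=0,
    resources_per_trial={"gpu": 1},  # reserve one GPU per trial, so `torch.cuda.is_available()` is True inside it
)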
best_parameters, values = ax.get_best_parameters()
best_parameters
{'lr': 0.001126015869936531, 'momentum': 0.4754412202164531}
means, covariances = values
means
{'mean_accuracy': 0.9616665982503324}
render(
    plot_contour(
        model=ax.generation_strategy.model,
        param_x='lr',
        param_y='momentum',
        metric_name='mean_accuracy',
    )
)
# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means
# from multiple optimization runs, so we wrap our best objectives array in another array.
best_objectives = np.array([[trial.objective_mean * 100 for trial in ax.experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Accuracy",
)
render(best_objective_plot)