{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Ax Service API with RayTune on PyTorch CNN\n", "\n", "Ax integrates easily with different scheduling frameworks and distributed training frameworks. In this example, Ax-driven optimization is executed in a distributed fashion using [RayTune](https://ray.readthedocs.io/en/latest/tune.html). \n", "\n", "RayTune is a scalable framework for hyperparameter tuning that provides many state-of-the-art hyperparameter tuning algorithms and seamlessly scales from a laptop to a distributed cluster with fault tolerance. RayTune leverages [Ray](https://ray.readthedocs.io/)'s Actor API to provide asynchronous parallel and distributed execution.\n", "\n", "Ray 'Actors' are a simple and clean abstraction for replicating your Python classes across multiple workers and nodes. Each hyperparameter evaluation is asynchronously executed on a separate Ray actor and reports intermediate training progress back to RayTune. RayTune then uses this information to perform actions such as early termination, re-prioritization, or checkpointing." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging\n", "from ray import tune\n", "from ray.tune import track\n", "from ray.tune.suggest.ax import AxSearch\n", "logger = logging.getLogger(tune.__name__) \n", "logger.setLevel(level=logging.CRITICAL) # Reduce the number of Ray warnings that are not relevant here." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 06-28 16:41:16] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.\n" ] } ], "source": [ "import torch\n", "import numpy as np\n", "\n", "from ax.plot.contour import plot_contour\n", "from ax.plot.trace import optimization_trace_single_method\n", "from ax.service.ax_client import AxClient\n", "from ax.utils.notebook.plotting import render, init_notebook_plotting\n", "from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate\n", "\n", "\n", "init_notebook_plotting()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Initialize client\n", "We set `enforce_sequential_optimization` to `False` because Ray runs many trials in parallel. With sequential optimization enforced, `AxClient` would expect the first few trials to be completed with data before generating more trials.\n", "\n", "When high parallelism is not required, it is best to enforce sequential optimization, as it allows optimal results to be reached in fewer (but sequential) trials. In cases where parallelism is important, such as distributed training with Ray, we forgo minimizing resource utilization and run more trials in parallel." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ax = AxClient(enforce_sequential_optimization=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Set up experiment\n", "Here we set up the search space and specify the objective; refer to the Ax API tutorials for more detail." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 06-28 16:41:16] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy. Iterations after 5 will take longer to generate due to model-fitting.\n" ] } ], "source": [ "ax.create_experiment(\n", " name=\"mnist_experiment\",\n", " parameters=[\n", " {\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-6, 0.4], \"log_scale\": True},\n", " {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n", " ],\n", " objective_name=\"mean_accuracy\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Define how to evaluate trials\n", "Since we use the Ax Service API here, we evaluate the parameterizations that Ax suggests, using RayTune. The evaluation function follows its usual pattern, taking in a parameterization and outputting an objective value. For detail on evaluation functions, see [Trial Evaluation](https://ax.dev/docs/runner.html). " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def train_evaluate(parameterization):\n", " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", " train_loader, valid_loader, test_loader = load_mnist(data_path='~/.data')\n", " net = train(train_loader=train_loader, parameters=parameterization, dtype=torch.float, device=device)\n", " track.log(\n", " mean_accuracy=evaluate(\n", " net=net,\n", " data_loader=valid_loader,\n", " dtype=torch.float,\n", " device=device,\n", " )\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Run optimization\n", "Execute the Ax optimization and trial evaluation in RayTune using the [AxSearch algorithm](https://ray.readthedocs.io/en/latest/tune-searchalg.html#ax-search):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2019-06-28 16:41:16,284\tWARNING worker.py:1331 -- WARNING: Not updating worker name since `setproctitle` is not installed. Install this with `pip install setproctitle` (or ray[debug]) to enable monitoring of worker processes.\n", "2019-06-28 16:41:16,286\tINFO node.py:498 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-06-28_16-41-16_286272_48449/logs.\n", "2019-06-28 16:41:16,408\tINFO services.py:409 -- Waiting for redis server at 127.0.0.1:13143 to respond...\n", "2019-06-28 16:41:16,544\tINFO services.py:409 -- Waiting for redis server at 127.0.0.1:15265 to respond...\n", "2019-06-28 16:41:16,548\tINFO services.py:806 -- Starting Redis shard with 3.44 GB max memory.\n", "2019-06-28 16:41:16,612\tINFO node.py:512 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-06-28_16-41-16_286272_48449/logs.\n", "2019-06-28 16:41:16,616\tINFO services.py:1442 -- Starting the Plasma object store with 5.15 GB memory using /tmp.\n" ] }, { "data": { "text/plain": [ "[train_evaluate_1_lr=0.003839,momentum=0.13058,\n", " train_evaluate_2_lr=0.040651,momentum=0.93145,\n", " train_evaluate_3_lr=0.00029751,momentum=0.42034,\n", " train_evaluate_4_lr=5.0684e-05,momentum=0.87119,\n", " train_evaluate_5_lr=0.16561,momentum=0.3552,\n", " train_evaluate_6_lr=0.0030413,momentum=0.58571,\n", " train_evaluate_7_lr=2.1341e-05,momentum=0.062879,\n", " train_evaluate_8_lr=6.4852e-06,momentum=0.9718,\n", " train_evaluate_9_lr=0.00087986,momentum=0.48901,\n", " train_evaluate_10_lr=0.2652,momentum=0.69606,\n", " train_evaluate_11_lr=6.9975e-05,momentum=0.21816,\n", " train_evaluate_12_lr=0.00019538,momentum=0.51276,\n", " 
train_evaluate_13_lr=0.00018449,momentum=1.8421e-17,\n", " train_evaluate_14_lr=0.0010596,momentum=0.0,\n", " train_evaluate_15_lr=0.0035145,momentum=7.3464e-16,\n", " train_evaluate_16_lr=0.00030799,momentum=0.22074,\n", " train_evaluate_17_lr=0.00038031,momentum=0.18887,\n", " train_evaluate_18_lr=1.5967e-05,momentum=0.56366,\n", " train_evaluate_19_lr=0.00064806,momentum=0.080248,\n", " train_evaluate_20_lr=0.0010568,momentum=1.0195e-16,\n", " train_evaluate_21_lr=0.0011429,momentum=7.1418e-17,\n", " train_evaluate_22_lr=0.0015902,momentum=9.7071e-18,\n", " train_evaluate_23_lr=1.7204e-05,momentum=0.63942,\n", " train_evaluate_24_lr=0.00094123,momentum=0.25854,\n", " train_evaluate_25_lr=0.00091246,momentum=0.27806,\n", " train_evaluate_26_lr=0.00089562,momentum=0.29374,\n", " train_evaluate_27_lr=0.00087539,momentum=0.29651,\n", " train_evaluate_28_lr=0.00088501,momentum=0.30274,\n", " train_evaluate_29_lr=0.00075754,momentum=0.35846,\n", " train_evaluate_30_lr=0.00076087,momentum=0.35659]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tune.run(\n", " train_evaluate, \n", " num_samples=30, \n", " search_alg=AxSearch(ax), # Note that the argument here is the `AxClient`.\n", " verbose=0, # Set this level to 1 to see status updates and to 2 to also see trial results.\n", " # To use GPU, specify: resources_per_trial={\"gpu\": 1}.\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Retrieve the optimization results" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'lr': 0.0035144522261298635, 'momentum': 7.346423042648887e-16}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_parameters, values = ax.get_best_parameters()\n", "best_parameters" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mean_accuracy': 0.969666685940424}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "means, covariances = values\n", "means" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Plot the response surface and optimization trace" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "render(\n", " plot_contour(\n", " model=ax.generation_strategy.model, param_x='lr', param_y='momentum', metric_name='mean_accuracy'\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means from multiple \n", "# optimization runs, so we wrap our best objectives array in another array.\n", "best_objectives = np.array([[trial.objective_mean * 100 for trial in ax.experiment.trials.values()]])\n", "best_objective_plot = optimization_trace_single_method(\n", " y=np.maximum.accumulate(best_objectives, axis=1),\n", " title=\"Model performance vs. # of iterations\",\n", " ylabel=\"Accuracy\",\n", ")\n", "render(best_objective_plot)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }