{ "cells": [ { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "08064d6a-453e-44d7-85dc-896d40b6303a", "showInput": true }, "source": [ "# Developer API Example on Hartmann6\n", "\n", "The Developer API is suitable when the user wants maximal customization of the optimization loop. This tutorial demonstrates optimization of a Hartmann6 function using the `Experiment` construct. In this example, trials will be evaluated synchronously." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "code_folding": [], "collapsed": false, "hidden_ranges": [], "originalKey": "e45e47fa-35d2-4a1a-86db-1cf5fa0c8a62", "requestMsgId": "c0fdff2d-6317-4b7e-9056-a064c5d2e650" }, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 09-15 02:35:41] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.\n" ] }, { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ax import (\n", " ComparisonOp,\n", " ParameterType, \n", " RangeParameter,\n", " ChoiceParameter,\n", " FixedParameter,\n", " SearchSpace, \n", " Experiment, \n", " OutcomeConstraint, \n", " OrderConstraint,\n", " SumConstraint,\n", " OptimizationConfig,\n", " Objective,\n", " Metric,\n", ")\n", "from ax.utils.notebook.plotting import render, init_notebook_plotting\n", "\n", "init_notebook_plotting()" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "f522bb04-8372-4647-8c90-cffb8a664be3", "showInput": true }, "source": [ "## 1. Create Search Space\n", "\n", "First, we define a search space, which defines the type and allowed range for the parameters." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "originalKey": "38418604-f280-4cec-ba3f-a8d987680a96", "requestMsgId": "d7ccd747-2096-4254-84f3-a619b7c3b473" }, "outputs": [], "source": [ "hartmann_search_space = SearchSpace(\n", " parameters=[\n", " RangeParameter(\n", " name=f\"x{i}\", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0\n", " )\n", " for i in range(6)\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "customInput": null, "originalKey": "9e0c312c-e290-4e7b-bf9c-45bd5c360c25", "showInput": false }, "source": [ "Note that there are two other parameter classes, FixedParameter and ChoiceParameter. Although we won't use these in this example, you can create them as follows.\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "09fd4723-021f-4569-9b5c-fe5aa5ab6835", "requestMsgId": "45f10ccd-2318-4674-8017-9f0e2b5bfcf1", "showInput": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[WARNING 09-15 02:35:42] ax.core.parameter: `is_ordered` is not specified for `ChoiceParameter` \"choice\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `is_ordered` during `ChoiceParameter` construction.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[WARNING 09-15 02:35:42] ax.core.parameter: `sort_values` is not specified for `ChoiceParameter` \"choice\". Defaulting to `False` for parameters of `ParameterType` STRING. 
To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n" ] } ], "source": [ "choice_param = ChoiceParameter(name=\"choice\", values=[\"foo\", \"bar\"], parameter_type=ParameterType.STRING)\n", "fixed_param = FixedParameter(name=\"fixed\", value=True, parameter_type=ParameterType.BOOL)" ] }, { "cell_type": "markdown", "metadata": { "customInput": null, "originalKey": "75b46af0-9739-46a6-9b95-21c8e2e9e22a", "showInput": false }, "source": [ "Sum constraints require the sum of a set of parameters to be at most (or at least) some bound, and order constraints require one parameter to be no larger than another. We won't use these either, but see two examples below.\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "fe8edcef-6b44-4db8-ba7d-9a3e60ed5896", "requestMsgId": "6175df8e-0376-41db-a96e-e1dd0dd46544", "showInput": true }, "outputs": [], "source": [ "sum_constraint = SumConstraint(\n", "    parameters=[hartmann_search_space.parameters['x0'], hartmann_search_space.parameters['x1']],\n", "    is_upper_bound=True,\n", "    bound=5.0,\n", ")\n", "\n", "order_constraint = OrderConstraint(\n", "    lower_parameter=hartmann_search_space.parameters['x0'],\n", "    upper_parameter=hartmann_search_space.parameters['x1'],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "7bf887e2-2b02-4237-ba5e-6fa8beaa85fb", "showInput": false }, "source": [ "## 2. 
Create Optimization Config\n", "\n", "Second, we define the `optimization_config` with an `objective` and `outcome_constraints`.\n", "\n", "When doing the optimization, we will find points that minimize the objective while obeying the constraints (which in this case means `l2norm <= 1.25`).\n", "\n", "Note: we are using `Hartmann6Metric` and `L2NormMetric` here, which have built-in evaluation functions for testing. For creating your own custom metrics, see [8. Defining custom metrics](#8.-Defining-custom-metrics)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "code_folding": [], "collapsed": false, "hidden_ranges": [], "originalKey": "692b9ecc-9c67-42d8-ab08-c7108293b043", "requestMsgId": "44682533-bb98-40fd-add2-ec0abd643568" }, "outputs": [], "source": [ "from ax.metrics.l2norm import L2NormMetric\n", "from ax.metrics.hartmann6 import Hartmann6Metric\n", "\n", "param_names = [f\"x{i}\" for i in range(6)]\n", "optimization_config = OptimizationConfig(\n", "    objective=Objective(\n", "        metric=Hartmann6Metric(name=\"hartmann6\", param_names=param_names),\n", "        minimize=True,\n", "    ),\n", "    outcome_constraints=[\n", "        OutcomeConstraint(\n", "            metric=L2NormMetric(\n", "                name=\"l2norm\", param_names=param_names, noise_sd=0.2\n", "            ),\n", "            op=ComparisonOp.LEQ,\n", "            bound=1.25,\n", "            relative=False,\n", "        )\n", "    ],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "customInput": null, "hidden_ranges": [], "originalKey": "ed80a5e4-4786-4961-979e-22a295bfa7f0", "showInput": false }, "source": [ "## 3. Define a Runner\n", "Before an experiment can collect data, it must have a `Runner` attached. A runner handles the deployment of trials. A trial must be \"run\" before it can be evaluated.\n", "\n", "Here, we have a dummy runner that does nothing. 
In practice, a runner might be in charge of pushing an experiment to production.\n", "\n", "The only method that needs to be defined for runner subclasses is `run`, which performs any necessary deployment logic and returns a dictionary of resulting metadata. This metadata can later be accessed through the trial's `run_metadata` property." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "2e2aca22-d04f-468a-a3a2-cb99115d8998", "requestMsgId": "03408af3-c7d7-41b1-b220-00d795c9e40f", "showInput": true }, "outputs": [], "source": [ "from ax import Runner\n", "\n", "class MyRunner(Runner):\n", "    def run(self, trial):\n", "        # No deployment logic here; just return metadata identifying the trial.\n", "        trial_metadata = {\"name\": str(trial.index)}\n", "        return trial_metadata" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "customInput": null, "hidden_ranges": [], "originalKey": "131ab2a9-e2c7-4752-99a3-547c7dbe42ec", "showInput": false }, "source": [ "## 4. Create Experiment\n", "Next, we make an `Experiment` with our search space, runner, and optimization config." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "7618dd5f-bba2-4c3c-8d6c-f281ec6aae80", "requestMsgId": "c61e7d66-e9be-439c-a471-b680d81b375d", "showInput": true }, "outputs": [], "source": [ "exp = Experiment(\n", "    name=\"test_hartmann\",\n", "    search_space=hartmann_search_space,\n", "    optimization_config=optimization_config,\n", "    runner=MyRunner(),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "8a04eba9-97f2-45f7-8b10-7216fe9c0101", "showInput": true }, "source": [ "## 5. Perform Optimization\n", "\n", "Run the optimization using the settings defined on the experiment. 
We will create 5 quasi-random Sobol points for exploration, followed by 15 points generated using a GP+EI (GPEI) model.\n", "\n", "Instead of using a member of the `Models` enum directly to produce generator runs, users can leverage a `GenerationStrategy`. See the [Generation Strategy Tutorial](https://ax.dev/tutorials/generation_strategy.html) for more info." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "code_folding": [], "collapsed": false, "hidden_ranges": [], "originalKey": "981eff70-2598-40e8-a850-49d188c82cfa", "requestMsgId": "b7a41ce7-1b33-4b06-add9-83c6516410cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running Sobol initialization trials...\n", "Running GP+EI optimization trial 6/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 7/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 8/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 9/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 10/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 11/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 12/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 13/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 14/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 15/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 16/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 17/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 
18/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 19/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running GP+EI optimization trial 20/20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Done!\n" ] } ], "source": [ "from ax.modelbridge.registry import Models\n", "\n", "NUM_SOBOL_TRIALS = 5\n", "NUM_BOTORCH_TRIALS = 15\n", "\n", "print(f\"Running Sobol initialization trials...\")\n", "sobol = Models.SOBOL(search_space=exp.search_space)\n", " \n", "for i in range(NUM_SOBOL_TRIALS):\n", " # Produce a GeneratorRun from the model, which contains proposed arm(s) and other metadata\n", " generator_run = sobol.gen(n=1)\n", " # Add generator run to a trial to make it part of the experiment and evaluate arm(s) in it\n", " trial = exp.new_trial(generator_run=generator_run)\n", " # Start trial run to evaluate arm(s) in the trial\n", " trial.run()\n", " # Mark trial as completed to record when a trial run is completed \n", " # and enable fetching of data for metrics on the experiment \n", " # (by default, trials must be completed before metrics can fetch their data,\n", " # unless a metric is explicitly configured otherwise)\n", " trial.mark_completed()\n", "\n", "for i in range(NUM_BOTORCH_TRIALS):\n", " print(\n", " f\"Running GP+EI optimization trial {i + NUM_SOBOL_TRIALS + 1}/{NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS}...\"\n", " )\n", " # Reinitialize GP+EI model at each step with updated data.\n", " gpei = Models.BOTORCH(experiment=exp, data=exp.fetch_data())\n", " generator_run = gpei.gen(n=1)\n", " trial = exp.new_trial(generator_run=generator_run)\n", " trial.run()\n", " trial.mark_completed()\n", " \n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "f503e648-e3f2-419f-a60e-5bfcbc6775bd", "showInput": true }, "source": [ "## 6. 
Inspect trials' data\n", "\n", "Now we can inspect the `Experiment`'s data by calling `fetch_data()`, which retrieves evaluation data for all trials of the experiment.\n", "\n", "To fetch a trial's data, we need to run the trial and mark it completed. For most metrics in Ax, data is only available once the status of the trial is `COMPLETED`, since in real-world scenarios, metrics can typically only be fetched after the trial has finished running.\n", "\n", "NOTE: Metric classes may implement the `is_available_while_running` method. When this method returns `True`, data is available when trials are either `RUNNING` or `COMPLETED`. This can be used to obtain intermediate results from A/B test trials and other online experiments, or when metric values are available immediately, as in the case of synthetic problem metrics.\n", "\n", "We can also use the `fetch_trials_data` function to get evaluation data for specific trials in the experiment, like so:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "code_folding": [], "collapsed": false, "hidden_ranges": [], "originalKey": "8887b621-efff-4dec-9109-1480db461b4e", "requestMsgId": "9ce39f1b-7e3f-4388-a729-8fb6c83d7cf7" }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " arm_name metric_name mean sem trial_index n frac_nonnull\n", "0 19_0 l2norm 0.718455 0.2 19 10000 0.718455\n", "1 19_0 hartmann6 -3.304898 0.0 19 10000 -3.304898" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trial_data = exp.fetch_trials_data([NUM_SOBOL_TRIALS + NUM_BOTORCH_TRIALS - 1])\n", "trial_data.df" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "customInput": null, "hidden_ranges": [], "originalKey": "74fa0cc5-075c-4b34-98b9-cbf74ad5bb26", "showInput": true }, "source": [ "The below call to `exp.fetch_data()` also attaches data to the last trial, which because of the way we looped through Botorch trials in [5. Perform Optimization](5.-Perform-Optimization), would otherwise not have data attached. This is necessary to get `objective_means` in [7. Plot results](7.-Plot-results)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "executionStartTime": 1626971528112, "executionStopTime": 1626971529427, "hidden_ranges": [], "originalKey": "753afdcb-4c20-4a9f-aa10-90b85cf0fbcd", "requestMsgId": "4aa093dd-d81d-40bb-9931-517bad31bdcb", "showInput": true }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " arm_name metric_name mean sem trial_index n frac_nonnull\n", "0 0_0 l2norm 1.060687 0.2 0 10000 1.060687\n", "1 1_0 l2norm 1.800910 0.2 1 10000 1.800910\n", "2 2_0 l2norm 1.671024 0.2 2 10000 1.671024\n", "3 3_0 l2norm 1.112437 0.2 3 10000 1.112437\n", "4 4_0 l2norm 1.141452 0.2 4 10000 1.141452\n", "5 5_0 l2norm 0.999561 0.2 5 10000 0.999561\n", "6 6_0 l2norm 1.235798 0.2 6 10000 1.235798\n", "7 7_0 l2norm 1.581371 0.2 7 10000 1.581371\n", "8 8_0 l2norm 1.250344 0.2 8 10000 1.250344\n", "9 9_0 l2norm 0.926985 0.2 9 10000 0.926985\n", "10 10_0 l2norm 0.963019 0.2 10 10000 0.963019\n", "11 11_0 l2norm 0.814333 0.2 11 10000 0.814333\n", "12 12_0 l2norm 0.828867 0.2 12 10000 0.828867\n", "13 13_0 l2norm 1.205505 0.2 13 10000 1.205505\n", "14 14_0 l2norm 0.937494 0.2 14 10000 0.937494\n", "15 15_0 l2norm 0.894357 0.2 15 10000 0.894357\n", "16 16_0 l2norm 0.766293 0.2 16 10000 0.766293\n", "17 17_0 l2norm 0.559932 0.2 17 10000 0.559932\n", "18 18_0 l2norm 1.007102 0.2 18 10000 1.007102\n", "19 19_0 l2norm 0.455238 0.2 19 10000 0.455238\n", "20 0_0 hartmann6 -1.555573 0.0 0 10000 -1.555573\n", "21 1_0 hartmann6 -0.000160 0.0 1 10000 -0.000160\n", "22 2_0 hartmann6 -0.005149 0.0 2 10000 -0.005149\n", "23 3_0 hartmann6 -0.232761 0.0 3 10000 -0.232761\n", "24 4_0 hartmann6 -0.440618 0.0 4 10000 -0.440618\n", "25 5_0 hartmann6 -1.825400 0.0 5 10000 -1.825400\n", "26 6_0 hartmann6 -1.479590 0.0 6 10000 -1.479590\n", "27 7_0 hartmann6 -2.030611 0.0 7 10000 -2.030611\n", "28 8_0 hartmann6 -1.856858 0.0 8 10000 -1.856858\n", "29 9_0 hartmann6 -1.386238 0.0 9 10000 -1.386238\n", "30 10_0 hartmann6 -2.393934 0.0 10 10000 -2.393934\n", "31 11_0 hartmann6 -2.693822 0.0 11 10000 -2.693822\n", "32 12_0 hartmann6 -2.769531 0.0 12 10000 -2.769531\n", "33 13_0 hartmann6 -3.115781 0.0 13 10000 -3.115781\n", "34 14_0 hartmann6 -2.937782 0.0 14 10000 -2.937782\n", "35 15_0 hartmann6 -3.125624 0.0 15 10000 -3.125624\n", "36 16_0 hartmann6 -3.181425 0.0 16 10000 
-3.181425\n", "37 17_0 hartmann6 -3.215719 0.0 17 10000 -3.215719\n", "38 18_0 hartmann6 -3.208667 0.0 18 10000 -3.208667\n", "39 19_0 hartmann6 -3.304898 0.0 19 10000 -3.304898" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.fetch_data().df" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "940865f9-af61-4668-aea0-b19ed5c5497d", "showInput": false }, "source": [ "## 7. Plot results\n", "Now we can plot the results of our optimization:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "code_folding": [], "collapsed": false, "hidden_ranges": [], "originalKey": "6732e438-ba5a-46a4-b1c9-8d5789fa8dd8", "requestMsgId": "2f5e4110-a5f4-418c-9308-43226426816a" }, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "linkText": "Export to plot.ly", "plotlyServerURL": "https://plot.ly", "showLink": false }, "data": [ { "hoverinfo": "none", "legendgroup": "", "line": { "width": 0 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.8253997923574483, -1.8253997923574483, -2.0306114349291735, -2.0306114349291735, -2.0306114349291735, -2.3939341364020175, -2.69382236465406, -2.769530547290562, -3.1157808435133214, -3.1157808435133214, -3.125623542179092, -3.1814248624777473, -3.215718892288139, -3.215718892288139, -3.304897986830241 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "legendgroup": "objective value", "line": { "color": "rgba(128,177,211,1)" }, "mode": "lines", "name": "objective value", "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.8253997923574483, 
-1.8253997923574483, -2.0306114349291735, -2.0306114349291735, -2.0306114349291735, -2.3939341364020175, -2.69382236465406, -2.769530547290562, -3.1157808435133214, -3.1157808435133214, -3.125623542179092, -3.1814248624777473, -3.215718892288139, -3.215718892288139, -3.304897986830241 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "hoverinfo": "none", "legendgroup": "", "line": { "width": 0 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.5555730789526736, -1.8253997923574483, -1.8253997923574483, -2.0306114349291735, -2.0306114349291735, -2.0306114349291735, -2.3939341364020175, -2.69382236465406, -2.769530547290562, -3.1157808435133214, -3.1157808435133214, -3.125623542179092, -3.1814248624777473, -3.215718892288139, -3.215718892288139, -3.304897986830241 ] }, { "line": { "color": "rgba(253,180,98,1)", "dash": "dash" }, "mode": "lines", "name": "Optimum", "type": "scatter", "x": [ 1, 20 ], "y": [ -3.32237, -3.32237 ] } ], "layout": { "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], 
"contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" 
}, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 
0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "sequentialminus": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", 
"ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "" }, "xaxis": { "title": { "text": "Iteration" } }, "yaxis": { "title": { "text": "" } } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "from ax.plot.trace import optimization_trace_single_method\n", "\n", "# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means from multiple\n", "# optimization runs, so we wrap our best objectives array in another array.\n", "objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])\n", "best_objective_plot = optimization_trace_single_method(\n", "    y=np.minimum.accumulate(objective_means, axis=1),\n", "    optimum=-3.32237,  # Known minimum objective for Hartmann6 function.\n", ")\n", "render(best_objective_plot)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "code_folding": [], "customInput": null, "hidden_ranges": [], "originalKey": "934db3fd-1dce-421b-8228-820025f3821a", "showInput": true }, "source": [ "## 8. Defining custom metrics\n", "To perform an optimization, we also need to define an optimization config for the experiment. An optimization config is composed of an objective metric to be minimized or maximized in the experiment, and optionally a set of outcome constraints that place restrictions on how other metrics can be moved by the experiment.\n", "\n", "To define an objective or outcome constraint, we first need to subclass `Metric`. Metrics are used to evaluate trials, which are individual steps of the experiment sequence. 
Each trial contains one or more arms for which we will collect data at the same time.\n", "\n", "Our custom metric(s) will determine how, given a trial, to compute the mean and SEM of each of the trial's arms.\n", "\n", "The only method that needs to be defined for most metric subclasses is `fetch_trial_data`, which defines how a single trial is evaluated and returns a `Data` object wrapping a pandas DataFrame.\n", " \n", "The `is_available_while_running` method is optional and returns a boolean, specifying whether the trial data can be fetched before the trial is complete. See [6. Inspect trials' data](#6.-Inspect-trials'-data) for more details." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "21bbb0ae-f3e3-4a5a-ac62-53934c3f1a4b", "requestMsgId": "78b8c7d4-3915-4cd2-9bd9-dde0ddeaa65f", "showInput": true }, "outputs": [], "source": [ "from ax import Data\n", "import pandas as pd\n", "\n", "\n", "class BoothMetric(Metric):\n", "    def fetch_trial_data(self, trial):\n", "        records = []\n", "        for arm_name, arm in trial.arms_by_name.items():\n", "            params = arm.parameters\n", "            records.append({\n", "                \"arm_name\": arm_name,\n", "                \"metric_name\": self.name,\n", "                \"trial_index\": trial.index,\n", "                # In practice, the mean and SEM will be looked up based on trial metadata,\n", "                # but for this tutorial we compute the Booth function directly.\n", "                \"mean\": (params[\"x1\"] + 2*params[\"x2\"] - 7)**2 + (2*params[\"x1\"] + params[\"x2\"] - 5)**2,\n", "                \"sem\": 0.0,\n", "            })\n", "        return Data(df=pd.DataFrame.from_records(records))\n", "\n", "    def is_available_while_running(self) -> bool:\n", "        return True" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [], "customInput": null, "hidden_ranges": [], "originalKey": "92fcddf9-9d86-45cd-b9fb-a0a7acdb267d", "showInput": false }, "source": [ "## 9. Save to JSON or SQL\n", "At any point, we can also save our experiment to a JSON file. 
To ensure that our custom metrics and runner are saved properly, we first need to register them." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "a2202ebe-b6fa-4bb9-91a5-6b505ce2d3fc", "requestMsgId": "d9f25a3b-ca8b-4283-86de-f7ec537723ab", "showInput": true }, "outputs": [], "source": [ "from ax.storage.json_store.load import load_experiment\n", "from ax.storage.json_store.save import save_experiment\n", "from ax.storage.metric_registry import register_metric\n", "from ax.storage.runner_registry import register_runner\n", "\n", "register_metric(BoothMetric)\n", "register_metric(L2NormMetric)\n", "register_metric(Hartmann6Metric)\n", "register_runner(MyRunner)\n", "\n", "save_experiment(exp, \"experiment.json\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "be389a75-f37f-446b-aeeb-3bd647a457fc", "requestMsgId": "82385488-d4ec-4052-869f-792c4606bcec", "showInput": true }, "outputs": [], "source": [ "loaded_experiment = load_experiment(\"experiment.json\")" ] }, { "cell_type": "markdown", "metadata": { "customInput": null, "originalKey": "dc1f6800-437e-45de-85d3-276ae5f8ca99", "showInput": false }, "source": [ "To save our experiment to SQL, we must first specify a connection to a database and create all necessary tables.\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "b9cb7a42-972b-43b6-832b-8c08137a403e", "requestMsgId": "502a1d87-43a8-4c91-bf04-3b72d10e8fd6", "showInput": true }, "outputs": [], "source": [ "from ax.storage.sqa_store.db import init_engine_and_session_factory, get_engine, create_all_tables\n", "from ax.storage.sqa_store.load import load_experiment\n", "from ax.storage.sqa_store.save import 
save_experiment\n", "\n", "init_engine_and_session_factory(url='sqlite:///foo3.db')\n", "\n", "engine = get_engine()\n", "create_all_tables(engine)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "61282918-fc05-4a33-9877-bfe49e3fc934", "requestMsgId": "c01febf4-1f6d-49a9-b647-0405949e232b", "showInput": true }, "outputs": [], "source": [ "exp.name = \"new\"\n", "save_experiment(exp)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "code_folding": [], "collapsed": false, "customInput": null, "hidden_ranges": [], "originalKey": "e8ea6d47-a91e-47f6-b48a-63f60c11b7bc", "requestMsgId": "2e8c6e5d-b206-497e-8675-0768105e678d", "showInput": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/runner/work/Ax/Ax/ax/storage/sqa_store/load.py:235: SAWarning:\n", "\n", "TypeDecorator JSONEncodedText() will not produce a cache key because the ``cache_ok`` flag is not set to True. Set this flag to True if this type object's state is safe to use in a cache key, or False to disable this warning.\n", "\n" ] }, { "data": { "text/plain": [ "Experiment(new)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "load_experiment(exp.name)" ] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 2 }