{ "cells": [ { "cell_type": "markdown", "id": "bd2597f3", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "bbfd01ea-97cb-4830-ab6d-60236151a3cd", "papermill": { "duration": 0.003275, "end_time": "2024-11-13T05:25:01.405113", "exception": false, "start_time": "2024-11-13T05:25:01.401838", "status": "completed" }, "showInput": false, "tags": [] }, "source": [ "# Multi-task Bayesian Optimization\n", "\n", "This tutorial uses synthetic functions to illustrate Bayesian optimization using a multi-task Gaussian Process in Ax. A typical use case is optimizing an expensive-to-evaluate (online) system with supporting (offline) simulations of that system.\n", "\n", "Bayesian optimization with a multi-task kernel (Multi-task Bayesian optimization) is described by Swersky et al. (2013). Letham and Bakshy (2019) describe using multi-task Bayesian optimization to tune a ranking system with a mix of online and offline (simulator) experiments.\n", "\n", "This tutorial reproduces the results of Online Appendix 2 from [that paper](https://arxiv.org/pdf/1904.01049.pdf).\n", "\n", "The synthetic problem used here is to minimize the Hartmann 6 function, a classic optimization test problem in 6 dimensions. The objective is treated as unknown and is modeled with a GP. The objective is noisy.\n", "\n", "Throughout the optimization we can make noisy observations of the objective directly (an online observation), and we can make noisy observations of a biased version of the objective (offline observations). Bias is simulated by passing the function values through a piecewise linear function. Offline observations are much less time-consuming than online observations, so we wish to use them to improve our ability to optimize the online objective." 
] }, { "cell_type": "code", "execution_count": 1, "id": "19d85cf8", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:01.412328Z", "iopub.status.busy": "2024-11-13T05:25:01.412021Z", "iopub.status.idle": "2024-11-13T05:25:04.008416Z", "shell.execute_reply": "2024-11-13T05:25:04.007460Z" }, "hidden_ranges": [], "originalKey": "3ce827be-d20b-48d3-a6ff-291bd442c748", "papermill": { "duration": 2.620688, "end_time": "2024-11-13T05:25:04.028716", "exception": false, "start_time": "2024-11-13T05:25:01.408028", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 11-13 05:25:03] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 11-13 05:25:03] ax.utils.notebook.plotting: Please see\n", " (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering)\n", " if visualizations are not rendering.\n" ] }, { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import os\n", "import time\n", "\n", "from copy import deepcopy\n", "from typing import Optional\n", "\n", "import numpy as np\n", "\n", "import torch\n", "\n", "from ax.core.data import Data\n", "from ax.core.experiment import Experiment\n", "from ax.core.generator_run import GeneratorRun\n", "from ax.core.multi_type_experiment import MultiTypeExperiment\n", "from ax.core.objective import Objective\n", "from ax.core.observation import ObservationFeatures, observations_from_data\n", "from ax.core.optimization_config import OptimizationConfig\n", "from ax.core.parameter import ParameterType, RangeParameter\n", "from ax.core.search_space import SearchSpace\n", "from ax.metrics.hartmann6 import Hartmann6Metric\n", "from 
ax.modelbridge.factory import get_sobol\n", "from ax.modelbridge.registry import Models, MT_MTGP_trans, ST_MTGP_trans\n", "from ax.modelbridge.torch import TorchModelBridge\n", "from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment\n", "from ax.plot.diagnostic import interact_batch_comparison\n", "from ax.runners.synthetic import SyntheticRunner\n", "from ax.utils.common.typeutils import checked_cast\n", "from ax.utils.notebook.plotting import init_notebook_plotting, render\n", "\n", "init_notebook_plotting()" ] }, { "cell_type": "code", "execution_count": 2, "id": "993058bf", "metadata": { "execution": { "iopub.execute_input": "2024-11-13T05:25:04.110505Z", "iopub.status.busy": "2024-11-13T05:25:04.109789Z", "iopub.status.idle": "2024-11-13T05:25:04.113499Z", "shell.execute_reply": "2024-11-13T05:25:04.112959Z" }, "papermill": { "duration": 0.045082, "end_time": "2024-11-13T05:25:04.114707", "exception": false, "start_time": "2024-11-13T05:25:04.069625", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" ] }, { "cell_type": "markdown", "id": "4dadd91c", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "76100312-e604-46ed-a123-9b0296ced6ff", "papermill": { "duration": 0.039448, "end_time": "2024-11-13T05:25:04.193117", "exception": false, "start_time": "2024-11-13T05:25:04.153669", "status": "completed" }, "showInput": false, "tags": [] }, "source": [ "## 1. Define Metric classes\n", "For this example, the online system is optimizing a Hartmann6 function. The Metric class for it is imported directly above. We create an analogous offline version of this metric, which is identical except that a transform (a piecewise linear function) is applied to its output." 
] }, { "cell_type": "code", "execution_count": 3, "id": "f0ab11b6", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:04.274069Z", "iopub.status.busy": "2024-11-13T05:25:04.273535Z", "iopub.status.idle": "2024-11-13T05:25:04.278298Z", "shell.execute_reply": "2024-11-13T05:25:04.277636Z" }, "hidden_ranges": [], "originalKey": "2315ca64-74e5-4084-829e-e8a482c653e5", "papermill": { "duration": 0.046852, "end_time": "2024-11-13T05:25:04.279611", "exception": false, "start_time": "2024-11-13T05:25:04.232759", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create metric with artificial offline bias, for the objective\n", "# by passing the true values through a piecewise linear function.\n", "\n", "\n", "class OfflineHartmann6Metric(Hartmann6Metric):\n", " def f(self, x: np.ndarray) -> float:\n", " raw_res = super().f(x)\n", " m = -0.35\n", " if raw_res < m:\n", " return (1.5 * (raw_res - m)) + m\n", " else:\n", " return (6.0 * (raw_res - m)) + m" ] }, { "cell_type": "markdown", "id": "d27c5a54", "metadata": { "originalKey": "b0e2089f-a7a3-4a8b-b8b3-ab6d75ca7f09", "papermill": { "duration": 0.039044, "end_time": "2024-11-13T05:25:04.357694", "exception": false, "start_time": "2024-11-13T05:25:04.318650", "status": "completed" }, "showInput": false, "tags": [] }, "source": [ "## 2. Create experiment\n", "\n", "A MultiTypeExperiment is used for managing online and offline trials together. It is constructed in several steps:\n", "\n", "1. Create the search space - This is done in the usual way.\n", "2. Specify optimization config - Also done in the usual way.\n", "3. Initialize Experiment - In addition to the search_space and optimization_config, specify that \"online\" is the default trial_type. This is the main trial type for which we're optimizing. Optimization metrics are defined to be for this type and new trials assume this trial type by default.\n", "4. 
Establish offline trial_type - Register the \"offline\" trial type and specify how to deploy trials of this type.\n", "5. Add offline metrics - Create the offline metrics and add them to the experiment. When adding the metrics, we need to specify the trial type (\"offline\") and online metric name it is associated with so the model can link them.\n", "\n", "Finally, because this is a synthetic benchmark problem where the true function values are known, we will also register metrics with the true (noiseless) function values for plotting below." ] }, { "cell_type": "code", "execution_count": 4, "id": "dfd314a5", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:04.438566Z", "iopub.status.busy": "2024-11-13T05:25:04.437998Z", "iopub.status.idle": "2024-11-13T05:25:04.444192Z", "shell.execute_reply": "2024-11-13T05:25:04.443657Z" }, "hidden_ranges": [], "originalKey": "39504f84-793e-4dae-ae55-068f1b762706", "papermill": { "duration": 0.048451, "end_time": "2024-11-13T05:25:04.445494", "exception": false, "start_time": "2024-11-13T05:25:04.397043", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_experiment(include_true_metric=True):\n", " noise_sd = 0.1 # Observations will have this much Normal noise added to them\n", "\n", " # 1. Create simple search space for [0,1]^d, d=6\n", " param_names = [f\"x{i}\" for i in range(6)]\n", " parameters = [\n", " RangeParameter(\n", " name=param_names[i],\n", " parameter_type=ParameterType.FLOAT,\n", " lower=0.0,\n", " upper=1.0,\n", " )\n", " for i in range(6)\n", " ]\n", " search_space = SearchSpace(parameters=parameters)\n", "\n", " # 2. Specify optimization config\n", " online_objective = Hartmann6Metric(\n", " \"objective\", param_names=param_names, noise_sd=noise_sd\n", " )\n", " opt_config = OptimizationConfig(\n", " objective=Objective(online_objective, minimize=True)\n", " )\n", "\n", " # 3. 
Init experiment\n", " exp = MultiTypeExperiment(\n", " name=\"mt_exp\",\n", " search_space=search_space,\n", " default_trial_type=\"online\",\n", " default_runner=SyntheticRunner(),\n", " optimization_config=opt_config,\n", " )\n", "\n", " # 4. Establish offline trial_type, and how those trials are deployed\n", " exp.add_trial_type(\"offline\", SyntheticRunner())\n", "\n", " # 5. Add offline metrics that provide biased estimates of the online metrics\n", " offline_objective = OfflineHartmann6Metric(\n", " \"offline_objective\", param_names=param_names, noise_sd=noise_sd\n", " )\n", " # Associate each offline metric with corresponding online metric\n", " exp.add_tracking_metric(\n", " metric=offline_objective, trial_type=\"offline\", canonical_name=\"objective\"\n", " )\n", "\n", " return exp" ] }, { "cell_type": "markdown", "id": "f661278e", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "5a00218e-c27d-4d6f-bef0-3e562217533a", "papermill": { "duration": 0.039472, "end_time": "2024-11-13T05:25:04.524204", "exception": false, "start_time": "2024-11-13T05:25:04.484732", "status": "completed" }, "showInput": false, "tags": [] }, "source": [ "## 3. Visualize the simulator bias\n", "\n", "This figure compares the online measurements to the offline measurements on a quasirandom (Sobol) set of points, for the objective metric. You can see the offline measurements are biased but highly correlated. This reproduces Fig. S3 from the paper." 
] }, { "cell_type": "code", "execution_count": 5, "id": "0ef32c8a", "metadata": { "execution": { "iopub.execute_input": "2024-11-13T05:25:04.605115Z", "iopub.status.busy": "2024-11-13T05:25:04.604633Z", "iopub.status.idle": "2024-11-13T05:25:04.889756Z", "shell.execute_reply": "2024-11-13T05:25:04.889050Z" }, "originalKey": "8260b668-91ef-404e-aa8c-4bf43f6a5660", "papermill": { "duration": 0.327416, "end_time": "2024-11-13T05:25:04.891519", "exception": false, "start_time": "2024-11-13T05:25:04.564103", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "linkText": "Export to plot.ly", "plotlyServerURL": "https://plot.ly", "showLink": false }, "data": [ { "hoverinfo": "none", "line": { "color": "black", "dash": "dot", "width": 2 }, "mode": "lines", "showlegend": false, "type": "scatter", "visible": true, "x": [ -2.874659384633504, 2.3732864377721707 ], "y": [ -2.874659384633504, 2.3732864377721707 ] }, { "error_x": { "array": [ 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196 ], "color": "rgba(128,177,211,0.4)", "type": "data" }, "error_y": { "array": [ 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196, 0.196 ], "color": "rgba(128,177,211,0.4)", "type": "data" }, "hoverinfo": "text", "marker": { "color": "rgba(128,177,211,1)" }, "mode": "markers", "name": "In-sample", "showlegend": true, 
"text": [ "Arm 0_0

Batch 1: 1.652 [1.456, 1.848]
Batch 0: -0.046 [-0.242, 0.150]

Parameterization:
x0: 0.0
x1: 0.0
x2: 0.0
x3: 0.0
x4: 0.0
x5: 0.0", "Arm 0_1

Batch 1: -0.688 [-0.884, -0.492]
Batch 0: -0.467 [-0.663, -0.271]

Parameterization:
x0: 0.5
x1: 0.5
x2: 0.5
x3: 0.5
x4: 0.5
x5: 0.5", "Arm 0_2

Batch 1: 1.477 [1.281, 1.673]
Batch 0: 0.02 [-0.176, 0.216]

Parameterization:
x0: 0.75
x1: 0.25
x2: 0.25
x3: 0.25
x4: 0.75
x5: 0.75", "Arm 0_3

Batch 1: -0.993 [-1.189, -0.797]
Batch 0: -0.732 [-0.928, -0.536]

Parameterization:
x0: 0.25
x1: 0.75
x2: 0.75
x3: 0.75
x4: 0.25
x5: 0.25", "Arm 0_4

Batch 1: 0.789 [0.593, 0.985]
Batch 0: -0.18 [-0.376, 0.016]

Parameterization:
x0: 0.375
x1: 0.375
x2: 0.625
x3: 0.875
x4: 0.375
x5: 0.125", "Arm 0_5

Batch 1: 1.78 [1.584, 1.976]
Batch 0: 0.086 [-0.110, 0.282]

Parameterization:
x0: 0.875
x1: 0.875
x2: 0.125
x3: 0.375
x4: 0.875
x5: 0.625", "Arm 0_6

Batch 1: 1.077 [0.881, 1.273]
Batch 0: 0.004 [-0.192, 0.200]

Parameterization:
x0: 0.625
x1: 0.125
x2: 0.875
x3: 0.625
x4: 0.625
x5: 0.875", "Arm 0_7

Batch 1: -0.35 [-0.546, -0.154]
Batch 0: -0.421 [-0.617, -0.225]

Parameterization:
x0: 0.125
x1: 0.625
x2: 0.375
x3: 0.125
x4: 0.125
x5: 0.375", "Arm 0_8

Batch 1: 0.794 [0.598, 0.990]
Batch 0: -0.027 [-0.223, 0.169]

Parameterization:
x0: 0.1875
x1: 0.3125
x2: 0.9375
x3: 0.4375
x4: 0.5625
x5: 0.3125", "Arm 0_9

Batch 1: 1.69 [1.494, 1.886]
Batch 0: 0.068 [-0.128, 0.264]

Parameterization:
x0: 0.6875
x1: 0.8125
x2: 0.4375
x3: 0.9375
x4: 0.0625
x5: 0.8125", "Arm 0_10

Batch 1: -0.37 [-0.566, -0.174]
Batch 0: -0.45 [-0.646, -0.254]

Parameterization:
x0: 0.9375
x1: 0.0625
x2: 0.6875
x3: 0.1875
x4: 0.3125
x5: 0.5625", "Arm 0_11

Batch 1: -1.58 [-1.776, -1.384]
Batch 0: -1.134 [-1.330, -0.938]

Parameterization:
x0: 0.4375
x1: 0.5625
x2: 0.1875
x3: 0.6875
x4: 0.8125
x5: 0.0625", "Arm 0_12

Batch 1: 1.3 [1.104, 1.496]
Batch 0: -0.242 [-0.438, -0.046]

Parameterization:
x0: 0.3125
x1: 0.1875
x2: 0.3125
x3: 0.5625
x4: 0.9375
x5: 0.4375", "Arm 0_13

Batch 1: 0.074 [-0.122, 0.270]
Batch 0: -0.169 [-0.365, 0.027]

Parameterization:
x0: 0.8125
x1: 0.6875
x2: 0.8125
x3: 0.0625
x4: 0.4375
x5: 0.9375", "Arm 0_14

Batch 1: 1.444 [1.248, 1.640]
Batch 0: -0.229 [-0.425, -0.033]

Parameterization:
x0: 0.5625
x1: 0.4375
x2: 0.0625
x3: 0.8125
x4: 0.1875
x5: 0.6875", "Arm 0_15

Batch 1: 0.364 [0.168, 0.560]
Batch 0: -0.254 [-0.450, -0.058]

Parameterization:
x0: 0.0625
x1: 0.9375
x2: 0.5625
x3: 0.3125
x4: 0.6875
x5: 0.1875", "Arm 0_16

Batch 1: -0.358 [-0.554, -0.162]
Batch 0: -0.177 [-0.373, 0.019]

Parameterization:
x0: 0.09375
x1: 0.46875
x2: 0.46875
x3: 0.65625
x4: 0.28125
x5: 0.96875", "Arm 0_17

Batch 1: 1.663 [1.467, 1.859]
Batch 0: -0.056 [-0.252, 0.140]

Parameterization:
x0: 0.59375
x1: 0.96875
x2: 0.96875
x3: 0.15625
x4: 0.78125
x5: 0.46875", "Arm 0_18

Batch 1: 1.661 [1.465, 1.857]
Batch 0: -0.064 [-0.260, 0.132]

Parameterization:
x0: 0.84375
x1: 0.21875
x2: 0.21875
x3: 0.90625
x4: 0.53125
x5: 0.21875", "Arm 0_19

Batch 1: -0.294 [-0.490, -0.098]
Batch 0: -0.248 [-0.444, -0.052]

Parameterization:
x0: 0.34375
x1: 0.71875
x2: 0.71875
x3: 0.40625
x4: 0.03125
x5: 0.71875", "Arm 0_20

Batch 1: -1.557 [-1.753, -1.361]
Batch 0: -1.175 [-1.371, -0.979]

Parameterization:
x0: 0.46875
x1: 0.09375
x2: 0.84375
x3: 0.28125
x4: 0.15625
x5: 0.84375", "Arm 0_21

Batch 1: 1.828 [1.632, 2.024]
Batch 0: -0.01 [-0.206, 0.186]

Parameterization:
x0: 0.96875
x1: 0.59375
x2: 0.34375
x3: 0.78125
x4: 0.65625
x5: 0.34375", "Arm 0_22

Batch 1: 1.585 [1.389, 1.781]
Batch 0: 0.001 [-0.195, 0.197]

Parameterization:
x0: 0.71875
x1: 0.34375
x2: 0.59375
x3: 0.03125
x4: 0.90625
x5: 0.09375", "Arm 0_23

Batch 1: 0.14 [-0.056, 0.336]
Batch 0: -0.206 [-0.402, -0.010]

Parameterization:
x0: 0.21875
x1: 0.84375
x2: 0.09375
x3: 0.53125
x4: 0.40625
x5: 0.59375", "Arm 0_24

Batch 1: 1.279 [1.083, 1.475]
Batch 0: -0.078 [-0.274, 0.118]

Parameterization:
x0: 0.15625
x1: 0.15625
x2: 0.53125
x3: 0.84375
x4: 0.84375
x5: 0.65625", "Arm 0_25

Batch 1: -0.275 [-0.471, -0.079]
Batch 0: -0.489 [-0.685, -0.293]

Parameterization:
x0: 0.65625
x1: 0.65625
x2: 0.03125
x3: 0.34375
x4: 0.34375
x5: 0.15625", "Arm 0_26

Batch 1: 1.221 [1.025, 1.417]
Batch 0: -0.149 [-0.345, 0.047]

Parameterization:
x0: 0.90625
x1: 0.40625
x2: 0.78125
x3: 0.59375
x4: 0.09375
x5: 0.40625", "Arm 0_27

Batch 1: 1.459 [1.263, 1.655]
Batch 0: 0.044 [-0.152, 0.240]

Parameterization:
x0: 0.40625
x1: 0.90625
x2: 0.28125
x3: 0.09375
x4: 0.59375
x5: 0.90625", "Arm 0_28

Batch 1: 0.759 [0.563, 0.955]
Batch 0: -0.055 [-0.251, 0.141]

Parameterization:
x0: 0.28125
x1: 0.28125
x2: 0.15625
x3: 0.21875
x4: 0.71875
x5: 0.53125", "Arm 0_29

Batch 1: 0.466 [0.270, 0.662]
Batch 0: -0.244 [-0.440, -0.048]

Parameterization:
x0: 0.78125
x1: 0.78125
x2: 0.65625
x3: 0.71875
x4: 0.21875
x5: 0.03125", "Arm 0_30

Batch 1: -0.416 [-0.612, -0.220]
Batch 0: -0.298 [-0.494, -0.102]

Parameterization:
x0: 0.53125
x1: 0.03125
x2: 0.40625
x3: 0.46875
x4: 0.46875
x5: 0.28125", "Arm 0_31

Batch 1: 1.715 [1.519, 1.911]
Batch 0: 0.166 [-0.030, 0.362]

Parameterization:
x0: 0.03125
x1: 0.53125
x2: 0.90625
x3: 0.96875
x4: 0.96875
x5: 0.78125", "Arm 0_32

Batch 1: -1.454 [-1.650, -1.258]
Batch 0: -1.242 [-1.438, -1.046]

Parameterization:
x0: 0.046875
x1: 0.265625
x2: 0.703125
x3: 0.546875
x4: 0.140625
x5: 0.921875", "Arm 0_33

Batch 1: 1.395 [1.199, 1.591]
Batch 0: -0.042 [-0.238, 0.154]

Parameterization:
x0: 0.546875
x1: 0.765625
x2: 0.203125
x3: 0.046875
x4: 0.640625
x5: 0.421875", "Arm 0_34

Batch 1: 1.939 [1.743, 2.135]
Batch 0: -0.009 [-0.205, 0.187]

Parameterization:
x0: 0.796875
x1: 0.015625
x2: 0.953125
x3: 0.796875
x4: 0.890625
x5: 0.171875", "Arm 0_35

Batch 1: -2.44 [-2.636, -2.244]
Batch 0: -1.907 [-2.103, -1.711]

Parameterization:
x0: 0.296875
x1: 0.515625
x2: 0.453125
x3: 0.296875
x4: 0.390625
x5: 0.671875", "Arm 0_36

Batch 1: -2.329 [-2.525, -2.133]
Batch 0: -1.692 [-1.888, -1.496]

Parameterization:
x0: 0.421875
x1: 0.140625
x2: 0.078125
x3: 0.421875
x4: 0.265625
x5: 0.796875", "Arm 0_37

Batch 1: 1.712 [1.516, 1.908]
Batch 0: -0.024 [-0.220, 0.172]

Parameterization:
x0: 0.921875
x1: 0.640625
x2: 0.578125
x3: 0.921875
x4: 0.765625
x5: 0.296875", "Arm 0_38

Batch 1: 1.462 [1.266, 1.658]
Batch 0: 0.031 [-0.165, 0.227]

Parameterization:
x0: 0.671875
x1: 0.390625
x2: 0.328125
x3: 0.171875
x4: 0.515625
x5: 0.046875", "Arm 0_39

Batch 1: 1.357 [1.161, 1.553]
Batch 0: -0.108 [-0.304, 0.088]

Parameterization:
x0: 0.171875
x1: 0.890625
x2: 0.828125
x3: 0.671875
x4: 0.015625
x5: 0.546875", "Arm 0_40

Batch 1: 1.81 [1.614, 2.006]
Batch 0: 0.035 [-0.161, 0.231]

Parameterization:
x0: 0.234375
x1: 0.078125
x2: 0.265625
x3: 0.984375
x4: 0.703125
x5: 0.734375", "Arm 0_41

Batch 1: 0.649 [0.453, 0.845]
Batch 0: -0.272 [-0.468, -0.076]

Parameterization:
x0: 0.734375
x1: 0.578125
x2: 0.765625
x3: 0.484375
x4: 0.203125
x5: 0.234375", "Arm 0_42

Batch 1: 1.743 [1.547, 1.939]
Batch 0: -0.027 [-0.223, 0.169]

Parameterization:
x0: 0.984375
x1: 0.328125
x2: 0.015625
x3: 0.734375
x4: 0.453125
x5: 0.484375", "Arm 0_43

Batch 1: 1.614 [1.418, 1.810]
Batch 0: 0.013 [-0.183, 0.209]

Parameterization:
x0: 0.484375
x1: 0.828125
x2: 0.515625
x3: 0.234375
x4: 0.953125
x5: 0.984375", "Arm 0_44

Batch 1: 1.552 [1.356, 1.748]
Batch 0: -0.005 [-0.201, 0.191]

Parameterization:
x0: 0.359375
x1: 0.453125
x2: 0.890625
x3: 0.109375
x4: 0.828125
x5: 0.609375", "Arm 0_45

Batch 1: 1.183 [0.987, 1.379]
Batch 0: -0.146 [-0.342, 0.050]

Parameterization:
x0: 0.859375
x1: 0.953125
x2: 0.390625
x3: 0.609375
x4: 0.328125
x5: 0.109375", "Arm 0_46

Batch 1: -0.199 [-0.395, -0.003]
Batch 0: -0.265 [-0.461, -0.069]

Parameterization:
x0: 0.609375
x1: 0.203125
x2: 0.640625
x3: 0.359375
x4: 0.078125
x5: 0.359375", "Arm 0_47

Batch 1: 1.743 [1.547, 1.939]
Batch 0: 0.059 [-0.137, 0.255]

Parameterization:
x0: 0.109375
x1: 0.703125
x2: 0.140625
x3: 0.859375
x4: 0.578125
x5: 0.859375", "Arm 0_48

Batch 1: 1.001 [0.805, 1.197]
Batch 0: -0.164 [-0.360, 0.032]

Parameterization:
x0: 0.078125
x1: 0.234375
x2: 0.796875
x3: 0.140625
x4: 0.421875
x5: 0.078125", "Arm 0_49

Batch 1: 1.575 [1.379, 1.771]
Batch 0: 0.24 [0.044, 0.436]

Parameterization:
x0: 0.578125
x1: 0.734375
x2: 0.296875
x3: 0.640625
x4: 0.921875
x5: 0.578125" ], "type": "scatter", "visible": true, "x": [ 1.6524320984723868, -0.6878858971601272, 1.4770808077332649, -0.993409054118204, 0.7886336709163931, 1.7804359311655795, 1.0773321202240793, -0.34998560539622653, 0.7942791638626059, 1.6903721919460017, -0.3704934212568616, -1.5795198695207817, 1.29984711448716, 0.0735116650871739, 1.444341713704562, 0.36414515925277435, -0.3583787303727239, 1.663419694531572, 1.661264236472163, -0.29439317829075407, -1.5573174367851048, 1.8282537871268785, 1.5853024059605545, 0.14003400760883242, 1.2788274550142804, -0.2749265704266538, 1.2214512190416733, 1.4587147697483676, 0.7587494893581115, 0.4661684648419135, -0.4160641206184189, 1.714844698001545, -1.4535183563773453, 1.3948105628483884, 1.93874344584464, -2.4401163927059732, -2.3292341962140224, 1.7115634646254114, 1.4623878004667297, 1.356523563842693, 1.8104220084429676, 0.6494142389673133, 1.742634131552257, 1.6140005373938835, 1.5520003052653588, 1.183248548129262, -0.19935609632274945, 1.7425834750126372, 1.0012776682076836, 1.5748424288242397 ], "y": [ -0.046031046481474394, -0.4668133411025592, 0.020434014700034205, -0.7319230199446534, -0.17996314445172007, 0.08600638326298972, 0.0035363151816114047, -0.42147738251028527, -0.02739316458075458, 0.06774884913714001, -0.4495925751066135, -1.1344299619478115, -0.24230445084299795, -0.16884701303563104, -0.22861347231845586, -0.25382818744999885, -0.17704452940620327, -0.056299619649289354, -0.06445093124142948, -0.24815096414558346, -1.1754911472200849, -0.009755970774769036, 0.0006225577022990827, -0.20647271029959413, -0.07822348631118277, -0.48879536384292094, -0.14945140490556008, 0.04371295784007785, -0.05455174017396108, -0.24413018874927012, -0.2982111411191101, 0.1656627244668291, -1.2420075853238597, -0.04178773500131504, -0.009158522971981187, -1.9074384445262103, -1.6921533914852553, -0.024110779003412314, 0.03116786627378903, -0.1077249130523752, 0.03510930351304123, -0.27161522996725057, 
-0.026515953332991823, 0.013081000918839344, -0.00530675560267449, -0.14599695485702402, -0.26526071199387513, 0.059209972256328375, -0.16440108275113494, 0.2403312457307302 ] } ], "layout": { "annotations": [ { "showarrow": false, "text": "Show CI", "x": 1.125, "xanchor": "left", "xref": "paper", "y": 0.9, "yanchor": "middle", "yref": "paper" } ], "height": 500, "hovermode": "closest", "showlegend": false, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 
0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { 
"outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0.0, "#0d0887" ], [ 
0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "sequentialminus": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": 
"#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Repeated arms across trials" }, "updatemenus": [ { "buttons": [ { "args": [ { "visible": [ true, true ] }, { "xaxis.range": [ -2.874659384633504, 2.3732864377721707 ], "yaxis.range": [ -2.874659384633504, 2.3732864377721707 ] } ], "label": "objective", "method": "update" } ], "x": 0, "xanchor": "left", "y": 1.125, "yanchor": "top" }, { "buttons": [ { "args": [ { "error_x.thickness": 2, "error_x.width": 4, "error_y.thickness": 2, "error_y.width": 4 } ], "label": "Yes", "method": "restyle" }, { "args": [ { "error_x.thickness": 0, "error_x.width": 0, "error_y.thickness": 0, "error_y.width": 0 } ], "label": "No", "method": "restyle" } ], "x": 1.125, "xanchor": "left", "y": 0.8, "yanchor": "middle" } ], "width": 530, "xaxis": { "linecolor": "black", "linewidth": 0.5, "mirror": true, "range": [ -2.874659384633504, 2.3732864377721707 ], "title": { "text": "Batch 1" }, "zeroline": false }, "yaxis": { "linecolor": "black", "linewidth": 0.5, "mirror": true, "range": [ -2.874659384633504, 2.3732864377721707 ], "title": { "text": "Batch 0" }, "zeroline": false } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate 50 points from a Sobol sequence\n", "exp = get_experiment(include_true_metric=False)\n", "s = get_sobol(exp.search_space, scramble=False)\n", "gr = s.gen(50)\n", "# Deploy them both online and offline\n", "exp.new_batch_trial(trial_type=\"online\", generator_run=gr).run()\n", "exp.new_batch_trial(trial_type=\"offline\", generator_run=gr).run()\n", "# Fetch data\n", "data = exp.fetch_data()\n", "observations = observations_from_data(exp, data)\n", "# Plot the arms in batch 0 (online) vs. batch 1 (offline)\n", "render(interact_batch_comparison(observations, exp, 1, 0))" ] }, { "cell_type": "markdown", "id": "d4b6ebef", "metadata": { "originalKey": "69cf9e8e-361e-4546-871f-6bb8641d1b97", "papermill": { "duration": 0.044021, "end_time": "2024-11-13T05:25:04.979379", "exception": false, "start_time": "2024-11-13T05:25:04.935358", "status": "completed" }, "tags": [] }, "source": [ "## 4. The Bayesian optimization loop\n", "\n", "Here we construct a Bayesian optimization loop that interleaves online and offline batches. The loop defined here is described in Algorithm 1 of the paper. We compare multi-task Bayesian optimization to regular Bayesian optimization using only online observations.\n", "\n", "Here we measure performance over 3 repetitions of the loop. Each one takes 1-2 hours so the whole benchmark run will take several hours to complete." 
] }, { "cell_type": "code", "execution_count": 6, "id": "b480b51a", "metadata": { "execution": { "iopub.execute_input": "2024-11-13T05:25:05.068483Z", "iopub.status.busy": "2024-11-13T05:25:05.067856Z", "iopub.status.idle": "2024-11-13T05:25:05.071734Z", "shell.execute_reply": "2024-11-13T05:25:05.071206Z" }, "originalKey": "3d124563-8a1f-411e-9822-972568ce1970", "papermill": { "duration": 0.049938, "end_time": "2024-11-13T05:25:05.073060", "exception": false, "start_time": "2024-11-13T05:25:05.023122", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Settings for the optimization benchmark.\n", "\n", "# To reproduce the results from the paper, repeat each loop 50 times, each with\n", "# independent observation noise; this tutorial runs each loop once.\n", "if SMOKE_TEST:\n", " n_batches = 1\n", " n_init_online = 2\n", " n_init_offline = 2\n", " n_opt_online = 2\n", " n_opt_offline = 2\n", "else:\n", " n_batches = 3 # Number of optimized BO batches\n", " n_init_online = 5 # Size of the quasirandom initialization run online\n", " n_init_offline = 20 # Size of the quasirandom initialization run offline\n", " n_opt_online = 5 # Batch size for BO-selected points to be run online\n", " n_opt_offline = 20 # Batch size for BO-selected points to be run offline" ] }, { "cell_type": "markdown", "id": "3da61f22", "metadata": { "originalKey": "5447b3e7-b245-4fab-ad4a-165d7c63e09c", "papermill": { "duration": 0.043264, "end_time": "2024-11-13T05:25:05.159618", "exception": false, "start_time": "2024-11-13T05:25:05.116354", "status": "completed" }, "tags": [] }, "source": [ "#### 4a. Optimization with online observations only\n", "For the online-only case, we run `n_init_online` Sobol points followed by `n_batches` batches of `n_opt_online` points selected by the GP. This is a normal Bayesian optimization loop." 
] }, { "cell_type": "code", "execution_count": 7, "id": "f306b842", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:05.247722Z", "iopub.status.busy": "2024-11-13T05:25:05.247275Z", "iopub.status.idle": "2024-11-13T05:25:05.252569Z", "shell.execute_reply": "2024-11-13T05:25:05.252005Z" }, "hidden_ranges": [], "originalKey": "040354c2-4313-46db-b40d-8adc8da6fafb", "papermill": { "duration": 0.050935, "end_time": "2024-11-13T05:25:05.253839", "exception": false, "start_time": "2024-11-13T05:25:05.202904", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# This function runs a Bayesian optimization loop, making online observations only.\n", "def run_online_only_bo():\n", " t1 = time.time()\n", " ### Do BO with online only\n", " ## Quasi-random initialization\n", " exp_online = get_experiment()\n", " m = get_sobol(exp_online.search_space, scramble=False)\n", " gr = m.gen(n=n_init_online)\n", " exp_online.new_batch_trial(trial_type=\"online\", generator_run=gr).run()\n", " ## Do BO\n", " for b in range(n_batches):\n", " print(\"Online-only batch\", b, time.time() - t1)\n", " # Fit the GP\n", " m = Models.BOTORCH_MODULAR(\n", " experiment=exp_online,\n", " data=exp_online.fetch_data(),\n", " search_space=exp_online.search_space,\n", " )\n", " # Generate the new batch\n", " gr = m.gen(\n", " n=n_opt_online,\n", " search_space=exp_online.search_space,\n", " optimization_config=exp_online.optimization_config,\n", " )\n", " exp_online.new_batch_trial(trial_type=\"online\", generator_run=gr).run()" ] }, { "cell_type": "markdown", "id": "2da8d847", "metadata": { "originalKey": "c1837efe-9f41-4eb8-a415-309392724141", "papermill": { "duration": 0.043628, "end_time": "2024-11-13T05:25:05.340925", "exception": false, "start_time": "2024-11-13T05:25:05.297297", "status": "completed" }, "tags": [] }, "source": [ "#### 4b. 
Multi-task Bayesian optimization\n", "Here we incorporate offline observations to accelerate the optimization, while using the same total number of online observations as in the loop above. The strategy follows Algorithm 1 of the paper.\n", "\n", "1. Initialization - Run `n_init_online` Sobol points online, and `n_init_offline` Sobol points offline.\n", "2. Fit model - Fit a multi-task GP (MTGP) to both online and offline observations.\n", "3. Generate candidates - Generate `n_opt_offline` candidates using noisy expected improvement (NEI).\n", "4. Launch offline batch - Run the `n_opt_offline` candidates offline and observe their offline metrics.\n", "5. Update model - Update the MTGP with the new offline observations.\n", "6. Select points for online batch - Select the best (maximum utility) `n_opt_online` of the NEI candidates, after incorporating their offline observations, and run them online.\n", "7. Update model and repeat - Update the model with the online observations, and repeat from step 3 for the next batch." 
] }, { "cell_type": "code", "execution_count": 8, "id": "907c4dde", "metadata": { "execution": { "iopub.execute_input": "2024-11-13T05:25:05.429853Z", "iopub.status.busy": "2024-11-13T05:25:05.429228Z", "iopub.status.idle": "2024-11-13T05:25:05.436489Z", "shell.execute_reply": "2024-11-13T05:25:05.435826Z" }, "papermill": { "duration": 0.053274, "end_time": "2024-11-13T05:25:05.437771", "exception": false, "start_time": "2024-11-13T05:25:05.384497", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_MTGP(\n", " experiment: Experiment,\n", " data: Data,\n", " search_space: Optional[SearchSpace] = None,\n", " trial_index: Optional[int] = None,\n", " device: torch.device = torch.device(\"cpu\"),\n", " dtype: torch.dtype = torch.double,\n", ") -> TorchModelBridge:\n", " \"\"\"Instantiates a Multi-task Gaussian Process (MTGP) model that generates\n", " points with EI.\n", "\n", " If the input experiment is a MultiTypeExperiment then a\n", " Multi-type Multi-task GP model will be instantiated.\n", " Otherwise, the model will be a Single-type Multi-task GP.\n", " \"\"\"\n", "\n", " if isinstance(experiment, MultiTypeExperiment):\n", " trial_index_to_type = {\n", " t.index: t.trial_type for t in experiment.trials.values()\n", " }\n", " transforms = MT_MTGP_trans\n", " transform_configs = {\n", " \"TrialAsTask\": {\"trial_level_map\": {\"trial_type\": trial_index_to_type}},\n", " \"ConvertMetricNames\": tconfig_from_mt_experiment(experiment),\n", " }\n", " else:\n", " # Set transforms for a Single-type MTGP model.\n", " transforms = ST_MTGP_trans\n", " transform_configs = None\n", "\n", " # Choose the status quo features for the experiment from the selected trial.\n", " # If trial_index is None, we will look for a status quo from the last\n", " # experiment trial to use as a status quo for the experiment.\n", " if trial_index is None:\n", " trial_index = len(experiment.trials) - 1\n", " elif trial_index >= 
len(experiment.trials):\n", " raise ValueError(\"trial_index is bigger than the number of experiment trials\")\n", "\n", " status_quo = experiment.trials[trial_index].status_quo\n", " if status_quo is None:\n", " status_quo_features = None\n", " else:\n", " status_quo_features = ObservationFeatures(\n", " parameters=status_quo.parameters,\n", " trial_index=trial_index, # pyre-ignore[6]\n", " )\n", "\n", " \n", " return checked_cast(\n", " TorchModelBridge,\n", " Models.ST_MTGP(\n", " experiment=experiment,\n", " search_space=search_space or experiment.search_space,\n", " data=data,\n", " transforms=transforms,\n", " transform_configs=transform_configs,\n", " torch_dtype=dtype,\n", " torch_device=device,\n", " status_quo_features=status_quo_features,\n", " ),\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "id": "74ca4511", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:05.526824Z", "iopub.status.busy": "2024-11-13T05:25:05.526350Z", "iopub.status.idle": "2024-11-13T05:25:05.534975Z", "shell.execute_reply": "2024-11-13T05:25:05.534386Z" }, "hidden_ranges": [], "originalKey": "37735b0e-e488-4927-a3da-a7d32d9f1ae0", "papermill": { "duration": 0.054826, "end_time": "2024-11-13T05:25:05.536293", "exception": false, "start_time": "2024-11-13T05:25:05.481467", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Online batches are constructed by selecting the maximum utility points from the offline\n", "# batch, after updating the model with the offline results. 
This function selects the max utility points according\n", "# to the MTGP predictions.\n", "def max_utility_from_GP(n, m, experiment, search_space, gr):\n", " obsf = []\n", " for arm in gr.arms:\n", " params = deepcopy(arm.parameters)\n", " params[\"trial_type\"] = \"online\"\n", " obsf.append(ObservationFeatures(parameters=params))\n", " # Make predictions\n", " f, cov = m.predict(obsf)\n", " # Compute expected utility\n", " u = -np.array(f[\"objective\"])\n", " best_arm_indx = np.flip(np.argsort(u))[:n]\n", " gr_new = GeneratorRun(\n", " arms=[gr.arms[i] for i in best_arm_indx],\n", " weights=[1.0] * n,\n", " )\n", " return gr_new\n", "\n", "\n", "# This function runs a multi-task Bayesian optimization loop, as outlined in Algorithm 1 and above.\n", "def run_mtbo():\n", " t1 = time.time()\n", " online_trials = []\n", " ## 1. Quasi-random initialization, online and offline\n", " exp_multitask = get_experiment()\n", " # Online points\n", " m = get_sobol(exp_multitask.search_space, scramble=False)\n", " gr = m.gen(\n", " n=n_init_online,\n", " )\n", " tr = exp_multitask.new_batch_trial(trial_type=\"online\", generator_run=gr)\n", " tr.run()\n", " online_trials.append(tr.index)\n", " # Offline points\n", " m = get_sobol(exp_multitask.search_space, scramble=False)\n", " gr = m.gen(\n", " n=n_init_offline,\n", " )\n", " exp_multitask.new_batch_trial(trial_type=\"offline\", generator_run=gr).run()\n", " ## Do BO\n", " for b in range(n_batches):\n", " print(\"Multi-task batch\", b, time.time() - t1)\n", " # (2 / 7). Fit the MTGP\n", " m = get_MTGP(\n", " experiment=exp_multitask,\n", " data=exp_multitask.fetch_data(),\n", " search_space=exp_multitask.search_space,\n", " )\n", "\n", " # 3. Finding the best points for the online task\n", " gr = m.gen(\n", " n=n_opt_offline,\n", " optimization_config=exp_multitask.optimization_config,\n", " fixed_features=ObservationFeatures(\n", " parameters={}, trial_index=online_trials[-1]\n", " ),\n", " )\n", "\n", " # 4. 
But launch them offline\n", " exp_multitask.new_batch_trial(trial_type=\"offline\", generator_run=gr).run()\n", "\n", " # 5. Update the model\n", " m = get_MTGP(\n", " experiment=exp_multitask,\n", " data=exp_multitask.fetch_data(),\n", " search_space=exp_multitask.search_space,\n", " )\n", "\n", " # 6. Select max-utility points from the offline batch to generate an online batch\n", " gr = max_utility_from_GP(\n", " n=n_opt_online,\n", " m=m,\n", " experiment=exp_multitask,\n", " search_space=exp_multitask.search_space,\n", " gr=gr,\n", " )\n", " tr = exp_multitask.new_batch_trial(trial_type=\"online\", generator_run=gr)\n", " tr.run()\n", " online_trials.append(tr.index)" ] }, { "cell_type": "markdown", "id": "4c1dbcfc", "metadata": { "originalKey": "6708d9ee-34be-4d85-91cc-ed2af5dd8026", "papermill": { "duration": 0.04386, "end_time": "2024-11-13T05:25:05.623639", "exception": false, "start_time": "2024-11-13T05:25:05.579779", "status": "completed" }, "tags": [] }, "source": [ "#### 4c. Run both loops\n", "Run both Bayesian optimization loops and aggregate results." 
] }, { "cell_type": "code", "execution_count": 10, "id": "152006eb", "metadata": { "code_folding": [], "execution": { "iopub.execute_input": "2024-11-13T05:25:05.713414Z", "iopub.status.busy": "2024-11-13T05:25:05.712930Z", "iopub.status.idle": "2024-11-13T05:32:46.923577Z", "shell.execute_reply": "2024-11-13T05:32:46.922881Z" }, "hidden_ranges": [], "originalKey": "f94a7537-61a6-4200-8e56-01de41aff6c9", "papermill": { "duration": 461.257644, "end_time": "2024-11-13T05:32:46.925285", "exception": false, "start_time": "2024-11-13T05:25:05.667641", "status": "completed" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Online-only batch 0 0.003331899642944336\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Online-only batch 1 7.83771824836731\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Online-only batch 2 23.880744457244873\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Multi-task batch 0 0.007803916931152344\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/linear_operator/utils/interpolation.py:71: UserWarning:\n", "\n", "torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated. Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:651.)\n", "\n", "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/botorch/acquisition/logei.py:338: RuntimeWarning:\n", "\n", "`cache_root` is only supported for GPyTorchModels that are not MultiTask models and don't produce a TransformedPosterior. Got a model of type . 
Setting `cache_root = False`.\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Multi-task batch 1 110.71123433113098\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/botorch/acquisition/logei.py:338: RuntimeWarning:\n", "\n", "`cache_root` is only supported for GPyTorchModels that are not MultiTask models and don't produce a TransformedPosterior. Got a model of type . Setting `cache_root = False`.\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Multi-task batch 2 274.48416352272034\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/botorch/acquisition/logei.py:338: RuntimeWarning:\n", "\n", "`cache_root` is only supported for GPyTorchModels that are not MultiTask models and don't produce a TransformedPosterior. Got a model of type . Setting `cache_root = False`.\n", "\n" ] } ], "source": [ "runners = {\n", " \"GP, online only\": run_online_only_bo,\n", " \"MTGP\": run_mtbo,\n", "}\n", "for k, r in runners.items():\n", " r()" ] }, { "cell_type": "markdown", "id": "5d575fb1", "metadata": { "originalKey": "1de5ae27-c925-4599-9425-332765a03416", "papermill": { "duration": 0.04409, "end_time": "2024-11-13T05:32:47.013921", "exception": false, "start_time": "2024-11-13T05:32:46.969831", "status": "completed" }, "tags": [] }, "source": [ "#### References\n", "Benjamin Letham and Eytan Bakshy. Bayesian optimization for policy search via online-offline experimentation. _arXiv preprint arXiv:1904.01049_, 2019.\n", "\n", "Kevin Swersky, Jasper Snoek, and Ryan P Adams. Multi-task Bayesian optimization. In _Advances in Neural Information Processing Systems_ 26, NIPS, pages 2004–2012, 2013." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" }, "papermill": { "default_parameters": {}, "duration": 469.117544, "end_time": "2024-11-13T05:32:49.716838", "environment_variables": {}, "exception": null, "input_path": "/tmp/tmp.jh7tLjWjTJ/Ax-main/tutorials/multi_task.ipynb", "output_path": "/tmp/tmp.jh7tLjWjTJ/Ax-main/tutorials/multi_task.ipynb", "parameters": {}, "start_time": "2024-11-13T05:25:00.599294", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }