This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset trained using SGD with momentum.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.notebook.plotting import init_notebook_plotting, render
from ax.utils.tutorials.cnn_utils import evaluate, load_mnist, train
from torch import Tensor
from torch.utils.data import DataLoader
init_notebook_plotting()
[INFO 12-09 18:40:50] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 12-09 18:40:50] ax.utils.notebook.plotting: Please see (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering) if visualizations are not rendering.
torch.manual_seed(42)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
First, we need to load the MNIST data and partition it into training, validation, and test sets.
Note: this will download the dataset if necessary.
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)
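The partitioning itself happens inside load_mnist, which this tutorial treats as a black box. As a rough illustration of what a reproducible split looks like, here is a stdlib-only sketch that shuffles example indices with a fixed seed and slices them. The helper name partition_indices and the 80/20 train/validation fraction are hypothetical and are not claimed to match load_mnist's internals.

```python
import random


def partition_indices(n_examples, train_frac=0.8, seed=42):
    """Hypothetical helper: shuffle indices 0..n-1 reproducibly, then split into train/valid."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # a fixed seed gives the same split on every run
    n_train = int(train_frac * n_examples)
    return indices[:n_train], indices[n_train:]


# MNIST ships a separate 10,000-image test set, so only the 60,000 training
# images need to be divided into training and validation subsets.
train_idx, valid_idx = partition_indices(60_000)
print(len(train_idx), len(valid_idx))  # 48000 12000
```

In PyTorch, the resulting index lists could be wrapped with torch.utils.data.Subset to build the corresponding loaders.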
Create a client object to interface with Ax APIs. By default this runs locally without storage.
ax_client = AxClient()
[INFO 12-09 18:40:52] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
An experiment consists of a search space (parameters and parameter constraints) and an optimization configuration (objective name, minimization setting, and outcome constraints).
# Create an experiment with a name, two range parameters, and one objective.
ax_client.create_experiment(
    name="tune_cnn_on_mnist",  # The name of the experiment.
    parameters=[
        {
            "name": "lr",  # The name of the parameter.
            "type": "range",  # The type of the parameter ("range", "choice" or "fixed").
            "bounds": [1e-6, 0.4],  # The bounds for range parameters.
            # "values": The possible values for choice parameters.
            # "value": The fixed value for fixed parameters.
            "value_type": "float",  # Optional, the value type ("int", "float", "bool" or "str"). Defaults to inference from type of "bounds".
            "log_scale": True,  # Optional, whether to use a log scale for range parameters. Defaults to False.
            # "is_ordered": Optional, a flag for choice parameters.
        },
        {
            "name": "momentum",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
    ],
    objectives={"accuracy": ObjectiveProperties(minimize=False)},  # The objective name and minimization setting.
    # parameter_constraints: Optional, a list of strings of form "p1 >= p2" or "p1 + p2 <= some_bound".
    # outcome_constraints: Optional, a list of strings of form "constrained_metric <= some_bound".
)
[INFO 12-09 18:40:52] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter momentum. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-09 18:40:52] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='lr', parameter_type=FLOAT, range=[1e-06, 0.4], log_scale=True), RangeParameter(name='momentum', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 12-09 18:40:52] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
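Setting log_scale to True matters here because the lr bounds span more than five orders of magnitude. The stdlib sketch below shows the general idea behind a log-scale range: a value in the unit interval is interpolated linearly in log space, so each decade of learning rate gets an equal share of the range. The helper name from_unit is hypothetical, and this is not Ax's internal transform code.

```python
import math

LOW, HIGH = 1e-6, 0.4  # the lr bounds passed to create_experiment above


def from_unit(u, low=LOW, high=HIGH, log_scale=True):
    """Hypothetical helper: map u in [0, 1] to a value in [low, high]."""
    if log_scale:
        # Interpolate linearly between log(low) and log(high), then exponentiate.
        return math.exp(math.log(low) + u * (math.log(high) - math.log(low)))
    # Plain linear interpolation between the bounds.
    return low + u * (high - low)


for u in (0.0, 0.25, 0.5, 0.75, 1.0):
    log_lr = from_unit(u)
    linear_lr = from_unit(u, log_scale=False)
    print(f"u={u:.2f}  log-scale lr={log_lr:.2e}  linear lr={linear_lr:.2e}")
```

On a linear scale the midpoint of the range is about 0.2, while on a log scale it is sqrt(1e-6 * 0.4), roughly 6.3e-4. In the run below, the five most accurate trials all have learning rates between roughly 1e-4 and 4e-4, a region that a linear scale would sample only rarely.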
Next, we define a simple CNN class to classify the MNIST images.
class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(8 * 8 * 20, 64)
        self.fc2 = nn.Linear(64, 10)
    def forward(self, x: Tensor) -> Tensor:
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 3, 3)
        x = x.view(-1, 8 * 8 * 20)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
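The hard-coded 8 * 8 * 20 in fc1 and in x.view follows from the layer shapes: a 28x28 MNIST image passes through a 5x5 convolution with stride 1 and no padding, then a 3x3 max pool with stride 3. The sketch below checks this arithmetic with the standard output-size formula floor((n + 2p - k) / s) + 1; the helper name out_size is hypothetical.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28x28
side = out_size(side, kernel=5, stride=1)  # conv1: 28 -> 24
side = out_size(side, kernel=3, stride=3)  # max_pool2d(x, 3, 3): 24 -> 8
flat_features = side * side * 20           # conv1 produces 20 channels
print(side, flat_features)  # 8 1280
```

If you change the kernel size, stride, or number of channels, recomputing flat_features this way keeps the view call and the input size of fc1 consistent.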
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The train_evaluate function takes in a parameterization (a set of parameter values), trains the CNN, computes its classification accuracy on the validation set, and returns that metric.
def train_evaluate(parameterization):
    """
    Train the model and then compute an evaluation metric.
    In this tutorial, the CNN utils package is doing a lot of work
    under the hood:
        - `train` initializes the network, defines the loss function
        and optimizer, performs the training loop, and returns the
        trained model.
        - `evaluate` computes the accuracy of the model on the
        evaluation dataset and returns the metric.
    For your use case, you can define training and evaluation functions
    of your choosing.
    """
    net = CNN()
    net = train(
        net=net,
        train_loader=train_loader,
        parameters=parameterization,
        dtype=dtype,
        device=device,
    )
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )
First, we use attach_trial to attach a custom trial with manually chosen parameters. This step is optional, but we include it here to demonstrate adding manual trials and to serve as a baseline model with decent performance.
# Attach the trial
ax_client.attach_trial(
    parameters={"lr": 0.000026, "momentum": 0.58}
)
# Get the parameters and run the trial
baseline_parameters = ax_client.get_trial_parameters(trial_index=0)
ax_client.complete_trial(trial_index=0, raw_data=train_evaluate(baseline_parameters))
[INFO 12-09 18:40:53] ax.core.experiment: Attached custom parameterizations [{'lr': 2.6e-05, 'momentum': 0.58}] as trial 0.
[INFO 12-09 18:40:59] ax.service.ax_client: Completed trial 0 with data: {'accuracy': (0.841833, None)}.
Now we start the optimization loop.
At each step, the user queries the client for a new trial and then submits the evaluation of that trial back to the client.
Note that Ax auto-selects an appropriate optimization algorithm based on the search space. For more advanced use cases that require a specific optimization algorithm, pass a generation_strategy argument into the AxClient constructor. When Bayesian optimization is used, generating new trials may take a few minutes.
for _ in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to an external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_evaluate(parameters))
[INFO 12-09 18:40:59] ax.service.ax_client: Generated new trial 1 with parameters {'lr': 0.009955, 'momentum': 0.633423}.
[INFO 12-09 18:41:06] ax.service.ax_client: Completed trial 1 with data: {'accuracy': (0.100333, None)}.
[INFO 12-09 18:41:06] ax.service.ax_client: Generated new trial 2 with parameters {'lr': 5e-06, 'momentum': 0.022851}.
[INFO 12-09 18:41:12] ax.service.ax_client: Completed trial 2 with data: {'accuracy': (0.318667, None)}.
[INFO 12-09 18:41:12] ax.service.ax_client: Generated new trial 3 with parameters {'lr': 7e-06, 'momentum': 0.176948}.
[INFO 12-09 18:41:19] ax.service.ax_client: Completed trial 3 with data: {'accuracy': (0.4585, None)}.
[INFO 12-09 18:41:19] ax.service.ax_client: Generated new trial 4 with parameters {'lr': 8.2e-05, 'momentum': 0.90883}.
[INFO 12-09 18:41:26] ax.service.ax_client: Completed trial 4 with data: {'accuracy': (0.926, None)}.
[INFO 12-09 18:41:26] ax.service.ax_client: Generated new trial 5 with parameters {'lr': 0.000302, 'momentum': 0.341904}.
[INFO 12-09 18:41:32] ax.service.ax_client: Completed trial 5 with data: {'accuracy': (0.929, None)}.
[INFO 12-09 18:41:32] ax.service.ax_client: Generated new trial 6 with parameters {'lr': 0.000137, 'momentum': 0.590917}.
[INFO 12-09 18:41:39] ax.service.ax_client: Completed trial 6 with data: {'accuracy': (0.92, None)}.
[INFO 12-09 18:41:39] ax.service.ax_client: Generated new trial 7 with parameters {'lr': 1e-05, 'momentum': 1.0}.
[INFO 12-09 18:41:46] ax.service.ax_client: Completed trial 7 with data: {'accuracy': (0.860167, None)}.
[INFO 12-09 18:41:46] ax.service.ax_client: Generated new trial 8 with parameters {'lr': 0.000246, 'momentum': 0.0}.
[INFO 12-09 18:41:53] ax.service.ax_client: Completed trial 8 with data: {'accuracy': (0.888833, None)}.
[INFO 12-09 18:41:53] ax.service.ax_client: Generated new trial 9 with parameters {'lr': 0.000149, 'momentum': 0.286357}.
[INFO 12-09 18:42:00] ax.service.ax_client: Completed trial 9 with data: {'accuracy': (0.901667, None)}.
[INFO 12-09 18:42:00] ax.service.ax_client: Generated new trial 10 with parameters {'lr': 1e-06, 'momentum': 1.0}.
[INFO 12-09 18:42:07] ax.service.ax_client: Completed trial 10 with data: {'accuracy': (0.560333, None)}.
[INFO 12-09 18:42:08] ax.service.ax_client: Generated new trial 11 with parameters {'lr': 3.7e-05, 'momentum': 1.0}.
[INFO 12-09 18:42:14] ax.service.ax_client: Completed trial 11 with data: {'accuracy': (0.759333, None)}.
[INFO 12-09 18:42:15] ax.service.ax_client: Generated new trial 12 with parameters {'lr': 0.000261, 'momentum': 0.913716}.
[INFO 12-09 18:42:22] ax.service.ax_client: Completed trial 12 with data: {'accuracy': (0.947667, None)}.
[INFO 12-09 18:42:22] ax.service.ax_client: Generated new trial 13 with parameters {'lr': 0.000171, 'momentum': 0.802859}.
[INFO 12-09 18:42:29] ax.service.ax_client: Completed trial 13 with data: {'accuracy': (0.942667, None)}.
[INFO 12-09 18:42:30] ax.service.ax_client: Generated new trial 14 with parameters {'lr': 0.000161, 'momentum': 1.0}.
[INFO 12-09 18:42:37] ax.service.ax_client: Completed trial 14 with data: {'accuracy': (0.843, None)}.
[INFO 12-09 18:42:37] ax.service.ax_client: Generated new trial 15 with parameters {'lr': 0.000312, 'momentum': 0.667594}.
[INFO 12-09 18:42:44] ax.service.ax_client: Completed trial 15 with data: {'accuracy': (0.9475, None)}.
[INFO 12-09 18:42:45] ax.service.ax_client: Generated new trial 16 with parameters {'lr': 0.000834, 'momentum': 0.0}.
[INFO 12-09 18:42:51] ax.service.ax_client: Completed trial 16 with data: {'accuracy': (0.741333, None)}.
[INFO 12-09 18:42:52] ax.service.ax_client: Generated new trial 17 with parameters {'lr': 0.000743, 'momentum': 1.0}.
[INFO 12-09 18:42:59] ax.service.ax_client: Completed trial 17 with data: {'accuracy': (0.102667, None)}.
[INFO 12-09 18:43:00] ax.service.ax_client: Generated new trial 18 with parameters {'lr': 0.000235, 'momentum': 0.66382}.
[INFO 12-09 18:43:06] ax.service.ax_client: Completed trial 18 with data: {'accuracy': (0.936667, None)}.
[INFO 12-09 18:43:07] ax.modelbridge.base: Untransformed parameter 0.40000000000000013 greater than upper bound 0.4, clamping
[INFO 12-09 18:43:07] ax.service.ax_client: Generated new trial 19 with parameters {'lr': 0.4, 'momentum': 0.0}.
[INFO 12-09 18:43:13] ax.service.ax_client: Completed trial 19 with data: {'accuracy': (0.101, None)}.
[INFO 12-09 18:43:14] ax.service.ax_client: Generated new trial 20 with parameters {'lr': 6.4e-05, 'momentum': 0.453937}.
[INFO 12-09 18:43:21] ax.service.ax_client: Completed trial 20 with data: {'accuracy': (0.897, None)}.
[INFO 12-09 18:43:22] ax.service.ax_client: Generated new trial 21 with parameters {'lr': 4.7e-05, 'momentum': 0.0}.
[INFO 12-09 18:43:29] ax.service.ax_client: Completed trial 21 with data: {'accuracy': (0.790333, None)}.
[INFO 12-09 18:43:29] ax.service.ax_client: Generated new trial 22 with parameters {'lr': 4e-06, 'momentum': 1.0}.
[INFO 12-09 18:43:36] ax.service.ax_client: Completed trial 22 with data: {'accuracy': (0.830667, None)}.
[INFO 12-09 18:43:37] ax.service.ax_client: Generated new trial 23 with parameters {'lr': 8.4e-05, 'momentum': 0.692917}.
[INFO 12-09 18:43:43] ax.service.ax_client: Completed trial 23 with data: {'accuracy': (0.920333, None)}.
[INFO 12-09 18:43:45] ax.service.ax_client: Generated new trial 24 with parameters {'lr': 0.000412, 'momentum': 0.0}.
[INFO 12-09 18:43:51] ax.service.ax_client: Completed trial 24 with data: {'accuracy': (0.849, None)}.
[INFO 12-09 18:43:52] ax.service.ax_client: Generated new trial 25 with parameters {'lr': 0.000237, 'momentum': 0.507928}.
[INFO 12-09 18:43:59] ax.service.ax_client: Completed trial 25 with data: {'accuracy': (0.928333, None)}.
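The loop above follows an ask/tell pattern: ask the client for a parameterization, evaluate it wherever you like, then tell the client the result. To make the pattern explicit without Ax or a GPU, here is a self-contained sketch. RandomSearchClient is hypothetical (plain random search, not Bayesian optimization, and not part of Ax); only its method names mirror AxClient's get_next_trial and complete_trial. toy_objective is a made-up stand-in for train_evaluate and does not model real accuracy.

```python
import math
import random


class RandomSearchClient:
    """Hypothetical stand-in for AxClient: random search over lr (log scale) and momentum."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self.trials = {}  # trial_index -> {"parameters": ..., "value": ...}

    def get_next_trial(self):
        # Sample lr log-uniformly in [1e-6, 0.4] and momentum uniformly in [0, 1].
        lr = math.exp(self._rng.uniform(math.log(1e-6), math.log(0.4)))
        parameters = {"lr": lr, "momentum": self._rng.uniform(0.0, 1.0)}
        trial_index = len(self.trials)
        self.trials[trial_index] = {"parameters": parameters, "value": None}
        return parameters, trial_index

    def complete_trial(self, trial_index, raw_data):
        self.trials[trial_index]["value"] = raw_data

    def best_observed(self):
        done = [t for t in self.trials.values() if t["value"] is not None]
        return max(done, key=lambda t: t["value"])


def toy_objective(parameters):
    """Made-up smooth score in (0, 1]; peaks at lr=3e-4, momentum=0.7."""
    lr_term = (math.log10(parameters["lr"]) - math.log10(3e-4)) ** 2
    momentum_term = (parameters["momentum"] - 0.7) ** 2
    return 1.0 / (1.0 + lr_term + momentum_term)


client = RandomSearchClient(seed=0)
for _ in range(25):
    parameters, trial_index = client.get_next_trial()  # ask
    client.complete_trial(trial_index, toy_objective(parameters))  # tell
print(client.best_observed())
```

Swapping the random sampler for a model-based one changes which points are proposed, but not the shape of the loop. That separation is what allows the local evaluation to be replaced with deployment to an external system.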
By default, Ax restricts the number of trials that can run in parallel for some optimization stages. This improves optimization performance and reduces the number of trials the optimization requires. To check the maximum parallelism for each optimization stage:
ax_client.get_max_parallelism()
[(5, 5), (-1, 3)]
The output of this function is a list of tuples of the form (number of trials, max parallelism), where a number of trials of -1 means all remaining trials. The example above therefore means "the max parallelism is 5 for the first 5 trials and 3 for all subsequent trials." This is because the first 5 trials are produced quasi-randomly and can all be evaluated at once, while subsequent trials are produced via Bayesian optimization, which converges on the optimal point in fewer trials when parallelism is limited. A MaxParallelismReachedException indicates that the parallelism limit has been reached; refer to the 'Service API Exceptions Meaning and Handling' section at the end of the tutorial for how to handle it.
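To make the (number of trials, max parallelism) format concrete, the sketch below looks up which limit applies at a given position in a schedule such as [(5, 5), (-1, 3)], treating -1 as "all remaining trials." The function name max_parallelism_at is hypothetical, and this is not Ax's own lookup logic. "Position" here means the order in which the generation strategy produces trials. In this run, trial 0 was attached manually, so the Sobol trials are numbered 1 to 5.

```python
def max_parallelism_at(schedule, trial_position):
    """Return the max parallelism for a 0-based trial position within a schedule.

    schedule is a list of (num_trials, max_parallelism) tuples; num_trials == -1
    means the step covers all remaining trials.
    """
    start = 0
    for num_trials, max_parallelism in schedule:
        if num_trials == -1 or trial_position < start + num_trials:
            return max_parallelism
        start += num_trials
    raise ValueError("schedule does not cover this trial position")


schedule = [(5, 5), (-1, 3)]
print([max_parallelism_at(schedule, i) for i in range(8)])  # [5, 5, 5, 5, 5, 3, 3, 3]
```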
ax_client.get_trials_data_frame()
[WARNING 12-09 18:43:59] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
 | trial_index | arm_name | trial_status | generation_method | accuracy | lr | momentum
---|---|---|---|---|---|---|---
0 | 0 | 0_0 | COMPLETED | Manual | 0.841833 | 0.000026 | 0.580000 |
1 | 1 | 1_0 | COMPLETED | Sobol | 0.100333 | 0.009955 | 0.633423 |
2 | 2 | 2_0 | COMPLETED | Sobol | 0.318667 | 0.000005 | 0.022851 |
3 | 3 | 3_0 | COMPLETED | Sobol | 0.458500 | 0.000007 | 0.176948 |
4 | 4 | 4_0 | COMPLETED | Sobol | 0.926000 | 0.000082 | 0.908830 |
5 | 5 | 5_0 | COMPLETED | Sobol | 0.929000 | 0.000302 | 0.341904 |
6 | 6 | 6_0 | COMPLETED | BoTorch | 0.920000 | 0.000137 | 0.590917 |
7 | 7 | 7_0 | COMPLETED | BoTorch | 0.860167 | 0.000010 | 1.000000 |
8 | 8 | 8_0 | COMPLETED | BoTorch | 0.888833 | 0.000246 | 0.000000 |
9 | 9 | 9_0 | COMPLETED | BoTorch | 0.901667 | 0.000149 | 0.286357 |
10 | 10 | 10_0 | COMPLETED | BoTorch | 0.560333 | 0.000001 | 1.000000 |
11 | 11 | 11_0 | COMPLETED | BoTorch | 0.759333 | 0.000037 | 1.000000 |
12 | 12 | 12_0 | COMPLETED | BoTorch | 0.947667 | 0.000261 | 0.913716 |
13 | 13 | 13_0 | COMPLETED | BoTorch | 0.942667 | 0.000171 | 0.802859 |
14 | 14 | 14_0 | COMPLETED | BoTorch | 0.843000 | 0.000161 | 1.000000 |
15 | 15 | 15_0 | COMPLETED | BoTorch | 0.947500 | 0.000312 | 0.667594 |
16 | 16 | 16_0 | COMPLETED | BoTorch | 0.741333 | 0.000834 | 0.000000 |
17 | 17 | 17_0 | COMPLETED | BoTorch | 0.102667 | 0.000743 | 1.000000 |
18 | 18 | 18_0 | COMPLETED | BoTorch | 0.936667 | 0.000235 | 0.663820 |
19 | 19 | 19_0 | COMPLETED | BoTorch | 0.101000 | 0.400000 | 0.000000 |
20 | 20 | 20_0 | COMPLETED | BoTorch | 0.897000 | 0.000064 | 0.453937 |
21 | 21 | 21_0 | COMPLETED | BoTorch | 0.790333 | 0.000047 | 0.000000 |
22 | 22 | 22_0 | COMPLETED | BoTorch | 0.830667 | 0.000004 | 1.000000 |
23 | 23 | 23_0 | COMPLETED | BoTorch | 0.920333 | 0.000084 | 0.692917 |
24 | 24 | 24_0 | COMPLETED | BoTorch | 0.849000 | 0.000412 | 0.000000 |
25 | 25 | 25_0 | COMPLETED | BoTorch | 0.928333 | 0.000237 | 0.507928 |
Once the optimization is complete, we can access the best parameters found, as well as the corresponding metric values. Note that these parameters are not necessarily the set that yielded the highest observed accuracy, because Ax uses the highest model-predicted accuracy to choose the best parameters (see here for more details). Due to randomness in the data or in the algorithm itself, selecting on observed accuracy may pick an outlier as the best set of parameters. Selecting the model-predicted best instead uses the model to regularize the observations, which reduces the likelihood of picking an outlier.
best_parameters, values = ax_client.get_best_parameters()
best_parameters
{'lr': 0.00023468478584700203, 'momentum': 0.6638197948979379}
mean, covariance = values
mean
{'accuracy': 0.9503630011366103}
Contour plot showing classification accuracy as a function of the two hyperparameters.
The black squares show points that we have actually run; notice how they are clustered in the optimal region.
render(ax_client.get_contour_plot(param_x="lr", param_y="momentum", metric_name="accuracy"))
[INFO 12-09 18:44:00] ax.service.ax_client: Retrieving contour plot with parameter 'lr' on X-axis and 'momentum' on Y-axis, for metric 'accuracy'. Remaining parameters are affixed to the middle of their range.
Here we plot the optimization trace, showing the progression of finding the point with the optimal objective:
render(ax_client.get_optimization_trace())
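Below, we retrain the CNN on the combined training and validation data using the parameters of the trial with the highest observed validation accuracy, and then evaluate it on the held-out test set. Note that the resulting accuracy on the test set generally won't be the same as the maximum accuracy achieved on the validation set during optimization.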
df = ax_client.get_trials_data_frame()
best_arm_idx = df.trial_index[df["accuracy"] == df["accuracy"].max()].values[0]
best_arm = ax_client.get_trial_parameters(best_arm_idx)
best_arm
[WARNING 12-09 18:44:01] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
{'lr': 0.00026103181236379136, 'momentum': 0.9137159656656234}
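The pandas expression above picks the row with the highest observed validation accuracy. The same selection can be written with the built-in max over a few (trial_index, accuracy, lr, momentum) rows copied from the trials table printed earlier. This sketch reproduces only the selection logic and does not query Ax. Like .values[0] on the filtered index, max keeps the first row it encounters among equal accuracies.

```python
# A subset of rows from the trials table above: (trial_index, accuracy, lr, momentum).
rows = [
    (0, 0.841833, 0.000026, 0.580000),
    (12, 0.947667, 0.000261, 0.913716),
    (13, 0.942667, 0.000171, 0.802859),
    (15, 0.947500, 0.000312, 0.667594),
    (18, 0.936667, 0.000235, 0.663820),
]

# Select the row with the largest observed accuracy.
best_index, best_accuracy, best_lr, best_momentum = max(rows, key=lambda row: row[1])
print(best_index, best_accuracy, best_lr, best_momentum)  # 12 0.947667 0.000261 0.913716
```

On this run, the observed best is trial 12, whereas get_best_parameters returned the parameters of trial 18 (lr about 0.000235, momentum about 0.66382). This illustrates the difference between the observed best and the model-predicted best described earlier.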
combined_train_valid_set = torch.utils.data.ConcatDataset(
    [
        train_loader.dataset.dataset,
        valid_loader.dataset.dataset,
    ]
)
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader,
    parameters=best_arm,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")
Classification Accuracy (test set): 98.28%
We can serialize the state of the optimization to JSON and save it to a .json file, or save it to the SQL backend. For the former:
ax_client.save_to_json_file() # For custom filepath, pass `filepath` argument.
[INFO 12-09 18:44:31] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.
restored_ax_client = AxClient.load_from_json_file()  # For custom filepath, pass `filepath` argument.
/tmp/tmp.bwZKabbZig/Ax-main/ax/core/data.py:203: FutureWarning: Passing literal json to 'read_json' is deprecated and will be removed in a future version. To read from a literal string, wrap it in a 'StringIO' object.
[INFO 12-09 18:44:31] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
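Ax produces its JSON snapshot with its own encoders, and that format is considerably richer than anything shown here. The stdlib sketch below illustrates only the general save-then-restore round trip, using a hypothetical and much simpler state (one completed trial, with values taken from the trial 0 log above). It should not be read as Ax's schema.

```python
import json
import os
import tempfile

# Hypothetical, simplified optimization state (not Ax's JSON schema).
state = {
    "experiment_name": "tune_cnn_on_mnist",
    "trials": [
        {"index": 0, "parameters": {"lr": 2.6e-05, "momentum": 0.58}, "accuracy": 0.841833},
    ],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "snapshot.json")
    with open(path, "w") as f:
        json.dump(state, f)  # save the state to disk
    with open(path) as f:
        restored = json.load(f)  # rebuild an equal state from the file

print(restored == state)  # True
```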
To store the state of the optimization in an SQL backend, first follow the setup instructions on the Ax website.
Having set up the SQL backend, pass DBSettings to AxClient on instantiation (note that the SQLAlchemy dependency will have to be installed; for installation, refer to the optional dependencies on the Ax website):
from ax.storage.sqa_store.structs import DBSettings
# URL is of the form "dialect+driver://username:password@host:port/database".
db_settings = DBSettings(url="sqlite:///foo.db")
# Instead of a URL, you can provide a `creator` function; custom encoders/decoders can also be specified if necessary.
new_ax = AxClient(db_settings=db_settings)
[INFO 12-09 18:44:32] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
When valid DBSettings are passed into AxClient, a unique experiment name (name) is a required argument to ax_client.create_experiment. The state of the optimization is auto-saved any time it changes (e.g., when a new trial is added or completed).
To reload an optimization state later, instantiate AxClient with the same DBSettings and use ax_client.load_experiment_from_database(experiment_name="my_experiment").
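The experiment name matters because it is the key under which the saved state is found again. The sqlite3 sketch below illustrates that idea: it writes a trial on every change (mimicking auto-saving) and reloads all trials by experiment name. The table layout and the helpers save_trial and load_trials are hypothetical and for illustration only; Ax's SQL storage uses its own tables and is driven through DBSettings, not through code like this.

```python
import json
import sqlite3

# Hypothetical schema for illustration only; Ax's SQL storage defines its own tables.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE trials (experiment TEXT, trial_index INTEGER, parameters TEXT, "
    "status TEXT, PRIMARY KEY (experiment, trial_index))"
)


def save_trial(experiment, trial_index, parameters, status):
    """Insert or overwrite one trial; calling this on every change mimics auto-saving."""
    with conn:  # commits the transaction on success
        conn.execute(
            "INSERT OR REPLACE INTO trials VALUES (?, ?, ?, ?)",
            (experiment, trial_index, json.dumps(parameters), status),
        )


def load_trials(experiment):
    """Reload every trial stored under the given experiment name."""
    rows = conn.execute(
        "SELECT trial_index, parameters, status FROM trials "
        "WHERE experiment = ? ORDER BY trial_index",
        (experiment,),
    )
    return [(index, json.loads(params), status) for index, params, status in rows]


save_trial("my_experiment", 0, {"lr": 2.6e-05, "momentum": 0.58}, "RUNNING")
save_trial("my_experiment", 0, {"lr": 2.6e-05, "momentum": 0.58}, "COMPLETED")
print(load_trials("my_experiment"))  # [(0, {'lr': 2.6e-05, 'momentum': 0.58}, 'COMPLETED')]
```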
Evaluation failure: should any optimization iterations fail during evaluation, log_trial_failure will ensure that the same trial is not proposed again.
_, trial_index = ax_client.get_next_trial()
ax_client.log_trial_failure(trial_index=trial_index)
[INFO 12-09 18:44:33] ax.service.ax_client: Generated new trial 26 with parameters {'lr': 0.000246, 'momentum': 0.77497}.
[INFO 12-09 18:44:33] ax.service.ax_client: Registered failure of trial 26.
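In practice, log_trial_failure is typically called from an exception handler around the evaluation. The sketch below shows that control flow using a hypothetical FakeClient, which only records which method was called, and a made-up flaky_evaluate function that raises for very large learning rates. With a real AxClient, the same two methods, complete_trial and log_trial_failure, would be called in the corresponding branches.

```python
class FakeClient:
    """Hypothetical stand-in that records outcomes instead of talking to Ax."""

    def __init__(self):
        self.completed, self.failed = {}, []

    def complete_trial(self, trial_index, raw_data):
        self.completed[trial_index] = raw_data

    def log_trial_failure(self, trial_index):
        self.failed.append(trial_index)


def flaky_evaluate(parameters):
    """Made-up evaluation that raises for very large learning rates."""
    if parameters["lr"] >= 0.1:
        raise RuntimeError("training diverged")
    return 0.9


def run_trial(client, trial_index, parameters):
    try:
        accuracy = flaky_evaluate(parameters)
    except Exception:
        # Mark the trial as failed so it does not stay pending forever.
        client.log_trial_failure(trial_index=trial_index)
    else:
        client.complete_trial(trial_index=trial_index, raw_data=accuracy)


client = FakeClient()
run_trial(client, 0, {"lr": 2.6e-05, "momentum": 0.58})
run_trial(client, 1, {"lr": 0.4, "momentum": 0.0})
print(client.completed, client.failed)  # {0: 0.9} [1]
```

In real code, catching a narrower exception type than Exception avoids hiding unrelated bugs as "failed trials."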
Need to run many trials in parallel: for optimal results and optimization efficiency, we strongly recommend sequential optimization (generating a few trials, then waiting for them to be completed with evaluation data). However, if your use case needs to dispatch many trials in parallel before they are updated with data, and you run into the "All trials for current model have been generated, but not enough data has been observed to fit next model" error, instantiate AxClient as AxClient(enforce_sequential_optimization=False).
DataRequiredError: the Ax generation strategy needs to be updated with more data to proceed to the next optimization model. When the optimization moves from the initialization stage to the Bayesian optimization stage, the underlying BayesOpt model needs sufficient data to train. For optimal results and optimization efficiency (finding the optimal point in the fewest trials), we recommend sequential optimization (generating a few trials, then waiting for them to be completed with evaluation data). Therefore, the correct way to handle this exception is to wait until more trial evaluations complete and log their data via ax_client.complete_trial(...).
However, if there is a strong need to generate more trials before more data is available, instantiate AxClient as AxClient(enforce_sequential_optimization=False). With this setting, as many trials will be generated from the initialization stage as requested, and the optimization will move to the BayesOpt stage whenever enough trials are completed.
MaxParallelismReachedException: the generation strategy restricts the number of trials that can be run simultaneously (to encourage sequential optimization), and the parallelism limit has been reached. The correct way to handle this exception is the same as for DataRequiredError: wait until more trial evaluations complete and log their data via ax_client.complete_trial(...).
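The wait-then-retry handling described for both exceptions can be sketched as a scheduler that keeps dispatching trials until the client refuses. At that point, it waits for the oldest dispatched trial to finish and reports its data before asking again. Everything here is hypothetical: LimitedClient and its LimitReached exception only mimic the behavior described above and are not Ax classes. With Ax, the corresponding exceptions surface when you request a new trial from the client.

```python
from collections import deque


class LimitReached(Exception):
    """Hypothetical stand-in for a 'too many pending trials' error."""


class LimitedClient:
    """Hypothetical client that refuses to generate trials beyond max_parallelism."""

    def __init__(self, max_parallelism):
        self.max_parallelism = max_parallelism
        self.pending = set()
        self.completed = {}
        self._next_index = 0

    def get_next_trial(self):
        if len(self.pending) >= self.max_parallelism:
            raise LimitReached(f"{len(self.pending)} trials are already running")
        trial_index = self._next_index
        self._next_index += 1
        self.pending.add(trial_index)
        return {"x": float(trial_index)}, trial_index

    def complete_trial(self, trial_index, raw_data):
        self.pending.remove(trial_index)
        self.completed[trial_index] = raw_data


client = LimitedClient(max_parallelism=3)
in_flight = deque()  # (trial_index, parameters) in dispatch order
while len(client.completed) < 10:
    try:
        parameters, trial_index = client.get_next_trial()
        in_flight.append((trial_index, parameters))  # dispatch, e.g. to a cluster
    except LimitReached:
        # Wait for the oldest dispatched trial to finish, then report its data.
        trial_index, parameters = in_flight.popleft()
        client.complete_trial(trial_index, raw_data=parameters["x"] ** 2)

print(sorted(client.completed), sorted(client.pending))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [10, 11]
```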
In some cases, higher parallelism is important, so the enforce_sequential_optimization=False kwarg to AxClient allows the user to suppress limiting of parallelism. It's also possible to override the default parallelism setting for all stages of the optimization by passing choose_generation_strategy_kwargs to ax_client.create_experiment:
ax_client = AxClient()
ax_client.create_experiment(
    parameters=[
        {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
    ],
    # Sets max parallelism to 10 for all steps of the generation strategy.
    choose_generation_strategy_kwargs={"max_parallelism_override": 10},
)
[INFO 12-09 18:44:34] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-09 18:44:34] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-09 18:44:34] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter y. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-09 18:44:34] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='y', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 12-09 18:44:34] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
ax_client.get_max_parallelism() # Max parallelism is now 10 for all stages of the optimization.
[(5, 10), (-1, 10)]