This notebook serves as a quickstart guide for using the Ax library with the SubmitIt library in an ask-tell loop. SubmitIt is a Python toolbox for submitting jobs to Slurm.
The notebook demonstrates how to use the Ax client in an ask-tell loop where each trial is scheduled to run on a Slurm cluster asynchronously.
To use this notebook, run it on a node with access to the Slurm cluster. You can run it interactively as a notebook, or export it as a Python script and run that as a Slurm job.
Let's start by importing the necessary libraries.
import time

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.report_utils import exp_to_df
from ax.utils.notebook.plotting import render
from IPython.display import display  # Built into notebooks; needed when exported as a script.
from submitit import AutoExecutor, DebugJob, LocalJob
We'll define a simple function to optimize. It takes two parameters, x and y, and returns a single metric, result.
def evaluate(parameters):
    x = parameters["x"]
    y = parameters["y"]
    return {"result": (x - 3)**2 + (y - 4)**2}
Note: if your objective is an external program rather than a Python function, SubmitIt's CommandFunction runs a command on the node and returns its standard output.
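SubmitIt provides this as `submitit.helpers.CommandFunction`. The sketch below shows the same idea using only the standard library, so it runs without SubmitIt: it runs an external program, captures its standard output, and turns the last line into the metric dictionary that `evaluate` returns. The command-line convention (`--x=...`/`--y=...` flags, with a program that prints one number) and the helper name `evaluate_with_command` are hypothetical assumptions. Adapt the parsing to whatever your program actually prints.

```python
import subprocess
import sys


def evaluate_with_command(parameters, command):
    """Run an external objective program and turn its stdout into Ax raw data.

    Hypothetical convention: the program accepts --x=... and --y=... flags and
    prints the objective value as the last line of its standard output.
    """
    full_command = command + [f"--x={parameters['x']}", f"--y={parameters['y']}"]
    completed = subprocess.run(full_command, capture_output=True, text=True, check=True)
    last_line = completed.stdout.strip().splitlines()[-1]
    return {"result": float(last_line)}


# Stand-in objective program, inlined with `python -c` so the sketch is self-contained.
fake_program = [
    sys.executable,
    "-c",
    "import sys; "
    "args = dict(a[2:].split('=') for a in sys.argv[1:]); "
    "print((float(args['x']) - 3) ** 2 + (float(args['y']) - 4) ** 2)",
]

print(evaluate_with_command({"x": 0.5, "y": 1.5}, fake_program))  # {'result': 12.5}
```

On the cluster, a function like this can be submitted in the same way as `evaluate`, for example `executor.submit(evaluate_with_command, parameters, my_command)`. Here `my_command` is a placeholder for your program's command line.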
We'll use Ax's Service API for this example. We start by initializing an AxClient and creating an experiment.
ax_client = AxClient()
ax_client.create_experiment(
    name="my_experiment",
    parameters=[
        {"name": "x", "type": "range", "bounds": [-10.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [-10.0, 10.0]},
    ],
    objectives={"result": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x + y <= 2.0"],  # Optional.
)
[INFO 07-23 19:43:49] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 07-23 19:43:49] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-23 19:43:49] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter y. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-23 19:43:49] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[-10.0, 10.0]), RangeParameter(name='y', parameter_type=FLOAT, range=[-10.0, 10.0])], parameter_constraints=[ParameterConstraint(1.0*x + 1.0*y <= 2.0)]).
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 07-23 19:43:49] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
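It helps to know what a good answer looks like before running the loop. The unconstrained minimum of `evaluate` is at (3, 4). However, 3 + 4 = 7 violates `x + y <= 2.0`, so the constrained minimum lies on the line x + y = 2. The objective is the squared Euclidean distance to (3, 4), so that minimum is the orthogonal projection of (3, 4) onto the line: (0.5, 1.5), with value 12.5. The sketch below computes the projection and cross-checks it with a coarse grid search. The helpers `objective` and `project_onto_halfspace` are our own and are not part of Ax.

```python
import itertools


def objective(x, y):
    # Same formula as `evaluate` above.
    return (x - 3) ** 2 + (y - 4) ** 2


def project_onto_halfspace(point, normal, bound):
    """Euclidean projection of `point` onto {p : normal . p <= bound}."""
    dot = sum(n * p for n, p in zip(normal, point))
    if dot <= bound:
        return tuple(point)  # Already feasible, nothing to do.
    norm_sq = sum(n * n for n in normal)
    step = (dot - bound) / norm_sq
    return tuple(p - step * n for p, n in zip(point, normal))


x_star, y_star = project_onto_halfspace((3.0, 4.0), normal=(1.0, 1.0), bound=2.0)
print(x_star, y_star, objective(x_star, y_star))  # 0.5 1.5 12.5

# Cross-check: brute-force grid over [-10, 10]^2 in steps of 0.1,
# keeping only points that satisfy x + y <= 2.
grid = [i / 10 for i in range(-100, 101)]
best = min(
    (objective(x, y), x, y)
    for x, y in itertools.product(grid, grid)
    if x + y <= 2.0
)
print(best)  # (12.5, 0.5, 1.5)
```

This explains the later BoTorch trials near (0.81, 1.19), which score about 12.69. They sit on the constraint boundary, close to this optimum.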
Other commonly used parameter types include choice parameters and fixed parameters.
Tip 1: you can specify additional information for parameters, such as log_scale if a parameter operates on a log scale, and is_ordered for choice parameters whose values have a meaningful ordering.
Tip 2: Ax is an excellent choice for multi-objective optimization problems, where there are multiple competing objectives and the goal is to identify the Pareto frontier of trade-offs between them.
Tip 3: you can define constraints on both the parameters (like x + y <= 2.0 above) and the outcomes (the metrics returned by the evaluation function).
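As a sketch of how these options look in the Service API's dictionary format, consider the example below. The parameter names (`learning_rate`, `optimizer`, `num_layers`, `batch_size`) and the `l2norm` metric are hypothetical placeholders that are not part of this example. The list would be passed as `parameters=...` to `ax_client.create_experiment`, and an outcome constraint only works if the evaluation function also returns the metric it references.

```python
# Hypothetical search space illustrating the parameter types and options above.
parameters = [
    # Range parameter searched on a log scale.
    {"name": "learning_rate", "type": "range", "bounds": [1e-5, 1e-1], "log_scale": True},
    # Choice parameter whose values have no meaningful order.
    {"name": "optimizer", "type": "choice", "values": ["adam", "sgd"], "is_ordered": False},
    # Choice parameter whose values do have a meaningful order.
    {"name": "num_layers", "type": "choice", "values": [1, 2, 4, 8], "is_ordered": True},
    # Fixed parameter: always takes this value.
    {"name": "batch_size", "type": "fixed", "value": 32},
]

# Outcome constraints are strings that bound a metric returned by the evaluation
# function ('l2norm' is a hypothetical metric name).
outcome_constraints = ["l2norm <= 1.25"]
```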
We'll use SubmitIt's AutoExecutor for this example. We start by initializing an AutoExecutor and setting a few commonly used parameters.
# Log folder and cluster. Specify cluster='local' or cluster='debug' to run the jobs locally during development.
# When we're ready for deployment, switch to cluster='slurm'.
executor = AutoExecutor(folder="/tmp/submitit_runs", cluster='debug')
executor.update_parameters(timeout_min=60)  # Job timeout in minutes, not counting Slurm queueing delay.
executor.update_parameters(cpus_per_task=2)
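Other commonly used Slurm parameters include partition, ntasks_per_node, cpus_per_task, cpus_per_gpu, gpus_per_node, gpus_per_task, qos, mem, mem_per_gpu, mem_per_cpu, and account. When set through AutoExecutor, the Slurm-specific ones take a slurm_ prefix (for example slurm_partition). Generic parameters such as cpus_per_task and gpus_per_node are passed without a prefix.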
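A minimal sketch of collecting these settings in one place. Generic `AutoExecutor` options (`timeout_min`, `cpus_per_task`, `gpus_per_node`, `mem_gb`) are passed by name. Slurm-specific ones carry the `slurm_` prefix, which, to our understanding, lets the same settings be used with `cluster='local'` or `cluster='debug'`, where those options should be ignored. The partition, account, and QOS values are placeholders for your cluster's own names.

```python
# Placeholder values: replace with your cluster's partition, account and QOS names.
slurm_parameters = {
    # Generic AutoExecutor parameters (no prefix).
    "timeout_min": 60,
    "cpus_per_task": 2,
    "gpus_per_node": 1,
    "mem_gb": 16,
    # Slurm-specific parameters take the `slurm_` prefix when passed through AutoExecutor.
    "slurm_partition": "my_partition",
    "slurm_account": "my_account",
    "slurm_qos": "normal",
}

# Apply them in one call on the executor created above:
# executor.update_parameters(**slurm_parameters)
```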
Now, we're ready to run the optimization loop. We'll use an ask-tell loop, where we ask Ax for a suggestion, evaluate it using our function, and then tell Ax the result.
The example loop schedules new jobs as soon as slots become available. For tasks that take a similar amount of time regardless of the parameters, it may make more sense to wait for the whole batch to finish before scheduling the next one, so that Ax can make better-informed parameter choices.
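Note that get_next_trials may return fewer trials than there are free slots (num_parallel_jobs minus the number of running jobs). For example, it appears not to generate past the end of the current generation step in a single call. In the run below, only trials 3 and 4, the last two of the five Sobol trials, are generated in the second round even though three slots are free.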
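To see the scheduling logic in isolation, here is a sketch that runs without Ax or a cluster. A thread pool stands in for Slurm, uniform random search stands in for `ax_client.get_next_trials`, and `toy_job`, `suggest` and `run_loop` are hypothetical names. With `wait_for_batch=False`, the loop refills free slots as soon as any job finishes, like the loop below. With `wait_for_batch=True`, it waits for the whole batch first, which is the alternative described above.

```python
import random
import time
from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, ThreadPoolExecutor, wait


def toy_job(parameters):
    # Same objective as `evaluate`; the sleep mimics jobs of different lengths.
    time.sleep(random.uniform(0.01, 0.05))
    x, y = parameters["x"], parameters["y"]
    return {"result": (x - 3) ** 2 + (y - 4) ** 2}


def suggest(rng):
    # Hypothetical stand-in for ax_client.get_next_trials: uniform random search,
    # rejecting points that violate x + y <= 2.
    while True:
        x, y = rng.uniform(-10, 10), rng.uniform(-10, 10)
        if x + y <= 2.0:
            return {"x": x, "y": y}


def run_loop(total_budget=10, num_parallel_jobs=3, wait_for_batch=False, seed=0):
    rng = random.Random(seed)
    completed = {}  # trial_index -> result dict (what we "tell" the optimizer)
    running = {}  # future -> trial_index
    submitted = 0
    peak_in_flight = 0
    with ThreadPoolExecutor(max_workers=num_parallel_jobs) as pool:
        while submitted < total_budget or running:
            # "Ask": fill the free slots without exceeding the budget.
            free_slots = min(num_parallel_jobs - len(running), total_budget - submitted)
            for _ in range(free_slots):
                running[pool.submit(toy_job, suggest(rng))] = submitted
                submitted += 1
            peak_in_flight = max(peak_in_flight, len(running))
            # Block until any job finishes (asynchronous) or all of them do (batch).
            mode = ALL_COMPLETED if wait_for_batch else FIRST_COMPLETED
            done, _ = wait(running, return_when=mode)
            for future in done:
                completed[running.pop(future)] = future.result()  # "Tell"
    return completed, peak_in_flight


results, peak = run_loop(wait_for_batch=False)
print(len(results), peak)  # 10 3
```

The batch variant leaves slots idle while the slowest job in each batch finishes. In exchange, every suggestion in a batch is made with all earlier results available.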
total_budget = 10
num_parallel_jobs = 3

jobs = []
submitted_jobs = 0
# Run until all the jobs have finished and our budget is used up.
while submitted_jobs < total_budget or jobs:
    for job, trial_index in jobs[:]:
        # Poll whether any jobs have completed.
        # Local and debug jobs are collected right away: .result() blocks until they finish
        # (debug jobs only start running when .result() is called).
        if job.done() or isinstance(job, (LocalJob, DebugJob)):
            result = job.result()
            ax_client.complete_trial(trial_index=trial_index, raw_data=result)
            jobs.remove((job, trial_index))

    # Schedule new jobs if there is availability.
    num_to_request = min(num_parallel_jobs - len(jobs), total_budget - submitted_jobs)
    if num_to_request > 0:
        trial_index_to_param, _ = ax_client.get_next_trials(max_trials=num_to_request)
        for trial_index, parameters in trial_index_to_param.items():
            job = executor.submit(evaluate, parameters)
            submitted_jobs += 1
            jobs.append((job, trial_index))
            time.sleep(1)

    # Display the current trials.
    display(exp_to_df(ax_client.experiment))

    # Sleep for a bit before checking the jobs again to avoid overloading the cluster.
    # If you have a large number of jobs, consider adding a sleep statement in the job polling loop as well.
    time.sleep(30)
[INFO 07-23 19:43:50] ax.service.ax_client: Generated new trial 0 with parameters {'x': -6.786203, 'y': 4.281123} using model Sobol.
[INFO 07-23 19:43:50] ax.service.ax_client: Generated new trial 1 with parameters {'x': -6.420702, 'y': -1.444372} using model Sobol.
[INFO 07-23 19:43:50] ax.service.ax_client: Generated new trial 2 with parameters {'x': -0.634777, 'y': -0.716038} using model Sobol.
[WARNING 07-23 19:43:53] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
[INFO 07-23 19:43:53] ax.service.utils.report_utils: No results present for the specified metrics `[Metric('result')]`. Returning arm parameters and metadata only.
| | trial_index | arm_name | trial_status | generation_method | x | y |
|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | RUNNING | Sobol | -6.786203 | 4.281123 |
| 1 | 1 | 1_0 | RUNNING | Sobol | -6.420702 | -1.444372 |
| 2 | 2 | 2_0 | RUNNING | Sobol | -0.634777 | -0.716038 |
[INFO 07-23 19:44:23] ax.service.ax_client: Completed trial 0 with data: {'result': (95.848795, None)}.
[INFO 07-23 19:44:23] ax.service.ax_client: Completed trial 1 with data: {'result': (118.390815, None)}.
[INFO 07-23 19:44:23] ax.service.ax_client: Completed trial 2 with data: {'result': (35.452622, None)}.
[INFO 07-23 19:44:23] ax.service.ax_client: Generated new trial 3 with parameters {'x': -0.819522, 'y': 0.937298} using model Sobol.
[INFO 07-23 19:44:23] ax.service.ax_client: Generated new trial 4 with parameters {'x': 2.055257, 'y': -0.98436} using model Sobol.
[WARNING 07-23 19:44:25] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 95.848795 | -6.786203 | 4.281123 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 118.390815 | -6.420702 | -1.444372 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 35.452622 | -0.634777 | -0.716038 |
| 3 | 3 | 3_0 | RUNNING | Sobol | NaN | -0.819522 | 0.937298 |
| 4 | 4 | 4_0 | RUNNING | Sobol | NaN | 2.055257 | -0.984360 |
[INFO 07-23 19:44:55] ax.service.ax_client: Completed trial 3 with data: {'result': (23.968892, None)}.
[INFO 07-23 19:44:55] ax.service.ax_client: Completed trial 4 with data: {'result': (25.736387, None)}.
[INFO 07-23 19:45:00] ax.service.ax_client: Generated new trial 5 with parameters {'x': 10.0, 'y': -8.0} using model BoTorch.
[INFO 07-23 19:45:04] ax.service.ax_client: Generated new trial 6 with parameters {'x': 10.0, 'y': -8.0} using model BoTorch.
[INFO 07-23 19:45:08] ax.service.ax_client: Generated new trial 7 with parameters {'x': 10.0, 'y': -8.0} using model BoTorch.
[WARNING 07-23 19:45:12] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 95.848795 | -6.786203 | 4.281123 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 118.390815 | -6.420702 | -1.444372 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 35.452622 | -0.634777 | -0.716038 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 23.968892 | -0.819522 | 0.937298 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 25.736387 | 2.055257 | -0.984360 |
| 5 | 5 | 5_0 | RUNNING | BoTorch | NaN | 10.000000 | -8.000000 |
| 6 | 6 | 6_0 | RUNNING | BoTorch | NaN | 10.000000 | -8.000000 |
| 7 | 7 | 7_0 | RUNNING | BoTorch | NaN | 10.000000 | -8.000000 |
[INFO 07-23 19:45:42] ax.service.ax_client: Completed trial 5 with data: {'result': (193.0, None)}.
[INFO 07-23 19:45:42] ax.service.ax_client: Completed trial 6 with data: {'result': (193.0, None)}.
[INFO 07-23 19:45:42] ax.service.ax_client: Completed trial 7 with data: {'result': (193.0, None)}.
[INFO 07-23 19:45:46] ax.service.ax_client: Generated new trial 8 with parameters {'x': 0.807427, 'y': 1.192573} using model BoTorch.
[INFO 07-23 19:45:51] ax.service.ax_client: Generated new trial 9 with parameters {'x': 0.807236, 'y': 1.192764} using model BoTorch.
[WARNING 07-23 19:45:53] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 95.848795 | -6.786203 | 4.281123 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 118.390815 | -6.420702 | -1.444372 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 35.452622 | -0.634777 | -0.716038 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 23.968892 | -0.819522 | 0.937298 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 25.736387 | 2.055257 | -0.984360 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 8 | 8 | 8_0 | RUNNING | BoTorch | NaN | 0.807427 | 1.192573 |
| 9 | 9 | 9_0 | RUNNING | BoTorch | NaN | 0.807236 | 1.192764 |
[INFO 07-23 19:46:23] ax.service.ax_client: Completed trial 8 with data: {'result': (12.689022, None)}.
[INFO 07-23 19:46:23] ax.service.ax_client: Completed trial 9 with data: {'result': (12.688788, None)}.
[WARNING 07-23 19:46:23] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 95.848795 | -6.786203 | 4.281123 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 118.390815 | -6.420702 | -1.444372 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 35.452622 | -0.634777 | -0.716038 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 23.968892 | -0.819522 | 0.937298 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 25.736387 | 2.055257 | -0.984360 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 193.000000 | 10.000000 | -8.000000 |
| 8 | 8 | 8_0 | COMPLETED | BoTorch | 12.689022 | 0.807427 | 1.192573 |
| 9 | 9 | 9_0 | COMPLETED | BoTorch | 12.688788 | 0.807236 | 1.192764 |
We can retrieve the best parameters and render the response surface.
best_parameters, (means, covariances) = ax_client.get_best_parameters()
print(f'Best set of parameters: {best_parameters}')
print(f'Mean objective value: {means}')
# With a single objective, covariances only holds the predictive variance of 'result'; cross-metric covariances appear only with multiple metrics.
render(ax_client.get_contour_plot())
[INFO 07-23 19:46:53] ax.service.ax_client: Retrieving contour plot with parameter 'x' on X-axis and 'y' on Y-axis, for metric 'result'. Remaining parameters are affixed to the middle of their range.
Best set of parameters: {'x': -0.8195218630135059, 'y': 0.937297884374857}
Mean objective value: {'result': 24.02469800438149}
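When a model is available, `get_best_parameters` reports model-predicted values. This is why the mean above (24.02) differs slightly from trial 3's observed 23.968892. In this run, it also returned trial 3 even though trials 8 and 9 observed lower values (about 12.69), so it can be worth cross-checking against the raw observations. The sketch below does that using the observed results copied from the final table above. In a live session, the same numbers are in the `result` column of `exp_to_df(ax_client.experiment)`. For reference, the analytical constrained minimum is 12.5 at (0.5, 1.5), the projection of (3, 4) onto the line x + y = 2.

```python
# Observed (raw) objective values per trial, copied from the final table above.
observed = {
    0: 95.848795,
    1: 118.390815,
    2: 35.452622,
    3: 23.968892,
    4: 25.736387,
    5: 193.0,
    6: 193.0,
    7: 193.0,
    8: 12.689022,
    9: 12.688788,
}

# We are minimizing, so the best observed trial has the smallest value.
best_trial = min(observed, key=observed.get)
print(best_trial, observed[best_trial])  # 9 12.688788

# Gap between the best observation and the analytical constrained optimum (12.5).
print(f"{observed[best_trial] - 12.5:.6f}")
```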
Total runtime of script: 3 minutes, 11.17 seconds.