Version: 0.5.0

Hyperparameter Optimization on Slurm via SubmitIt

This notebook serves as a quickstart guide for using the Ax library with the SubmitIt library in an ask-tell loop. SubmitIt is a Python toolbox for submitting jobs to Slurm.

The notebook demonstrates how to use the Ax client in an ask-tell loop where each trial is scheduled to run on a Slurm cluster asynchronously.

To use this notebook, run it on a Slurm node either interactively, or export it as a Python script and submit it as a Slurm job.

Importing Necessary Libraries

Let's start by importing the necessary libraries.

import sys

import plotly.io as pio

if 'google.colab' in sys.modules:
    pio.renderers.default = "colab"
    # Install Ax (and submitit, which this notebook also imports) when running on Colab.
    %pip install ax-platform submitit

import time

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.notebook.plotting import render
from ax.service.utils.report_utils import exp_to_df
from submitit import AutoExecutor, LocalJob, DebugJob

Defining the Function to Optimize

We'll define a simple synthetic function to optimize. It takes two parameters, x and y, and returns a single metric, result.

def evaluate(parameters):
    x = parameters["x"]
    y = parameters["y"]
    # Quadratic bowl with its minimum at x=3, y=4.
    return {"result": (x - 3) ** 2 + (y - 4) ** 2}

Note: SubmitIt's CommandFunction allows you to define shell commands to run on the node; it captures and returns their standard output.
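For illustration, a minimal sketch of such a command-based evaluation is shown below. The train.py script, its flags, and its output format are hypothetical and not part of this tutorial; the point is only that the wrapped command runs on the node and its captured output can be parsed into the metric dictionary Ax expects.

from submitit.helpers import CommandFunction

def evaluate_via_command(parameters):
    # Hypothetical training script and flags, parameterized by the trial's values.
    cmd = CommandFunction(
        ["python", "train.py", "--x", str(parameters["x"]), "--y", str(parameters["y"])]
    )
    stdout = cmd()  # Runs the command and returns its captured standard output.
    # Assumes the script prints the metric value as the last line of its output.
    return {"result": float(stdout.strip().splitlines()[-1])}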

Setting up Ax

We'll use Ax's Service API for this example. We start by initializing an AxClient and creating an experiment.

ax_client = AxClient()
ax_client.create_experiment(
    name="my_experiment",
    parameters=[
        {"name": "x", "type": "range", "bounds": [-10.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [-10.0, 10.0]},
    ],
    objectives={"result": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x + y <= 2.0"],  # Optional.
)
Out:

[INFO 02-03 18:44:14] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the verbose_logging argument to False. Note that float values in the logs are rounded to 6 decimal points.

Out:

[INFO 02-03 18:44:14] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.

Out:

[INFO 02-03 18:44:14] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter y. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.

Out:

[INFO 02-03 18:44:14] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[-10.0, 10.0]), RangeParameter(name='y', parameter_type=FLOAT, range=[-10.0, 10.0])], parameter_constraints=[ParameterConstraint(1.0*x + 1.0*y <= 2.0)]).

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: verbose, disable_progbar, and jit_compile are not yet supported when using choose_generation_strategy with ModularBoTorchModel, dropping these arguments.

Out:

[INFO 02-03 18:44:14] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.

Other commonly used parameter types include choice parameters and fixed parameters.

Tip 1: you can specify additional information for parameters, such as log_scale if a parameter operates on a log scale, and is_ordered for choice parameters that have a meaningful ordering.

Tip 2: Ax is an excellent choice for multi-objective optimization problems, where there are several competing objectives and the goal is to find the set of Pareto-optimal solutions.

Tip 3: One can define constraints on both the parameters and the outcome.
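To illustrate these tips, here is a hedged sketch of a separate experiment definition (not the one used in this notebook) with a log-scaled range parameter, an ordered choice parameter, a fixed parameter, and an outcome constraint; all names, values, and bounds below are made up for illustration.

ax_client_example = AxClient()
ax_client_example.create_experiment(
    name="illustrative_experiment",
    parameters=[
        # Range parameter searched on a log scale.
        {"name": "lr", "type": "range", "bounds": [1e-5, 1e-1], "log_scale": True},
        # Choice parameter with a meaningful ordering.
        {"name": "batch_size", "type": "choice", "values": [16, 32, 64], "is_ordered": True},
        # Fixed parameter held constant across all trials.
        {"name": "optimizer", "type": "fixed", "value": "adam"},
    ],
    objectives={"result": ObjectiveProperties(minimize=True)},
    # Outcome constraint; requires reporting a 'memory_gb' metric for each trial.
    outcome_constraints=["memory_gb <= 16.0"],
)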

Setting up SubmitIt

We'll use SubmitIt's AutoExecutor for this example. We start by initializing an AutoExecutor, and setting a few commonly used parameters.

# Log folder and cluster. Specify cluster='local' or cluster='debug' to run the jobs locally during development.
# When we are ready for deployment, switch to cluster='slurm'.
executor = AutoExecutor(folder="/tmp/submitit_runs", cluster='debug')
executor.update_parameters(timeout_min=60)  # Timeout of the Slurm job, not including scheduling delay.
executor.update_parameters(cpus_per_task=2)

Other commonly used Slurm parameters include partition, ntasks_per_node, cpus_per_task, cpus_per_gpu, gpus_per_node, gpus_per_task, qos, mem, mem_per_gpu, mem_per_cpu, account.
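As a hedged example, such options can be set through the same update_parameters call. With AutoExecutor, Slurm-only options use a slurm_ prefix (e.g. slurm_partition, slurm_qos), while generic options such as gpus_per_node or mem_gb are cluster-agnostic; the partition name below is a placeholder.

executor.update_parameters(
    slurm_partition="my_partition",  # Placeholder partition name for your cluster.
    gpus_per_node=1,                 # Generic, cluster-agnostic parameter.
    mem_gb=16,                       # Memory per node, in GB.
)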

Running the Optimization Loop

Now, we're ready to run the optimization loop. We'll use an ask-tell loop, where we ask Ax for a suggestion, evaluate it using our function, and then tell Ax the result.

The example loop schedules new jobs whenever there is availability. For tasks that take a similar amount of time regardless of the parameters, it may make more sense to wait for the whole batch to finish before scheduling the next one, so that Ax can make better-informed parameter choices; a minimal sketch of this batch-synchronous variant appears below.

Note that get_next_trials may not use all available num_parallel_jobs if it doesn't have good parameter candidates to run.
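As an aside, here is a hedged sketch of the batch-synchronous variant mentioned above. It is not executed in this notebook; it reuses the evaluate function and executor defined earlier and uses its own illustrative budget. The asynchronous loop that this notebook actually runs follows.

# Batch-synchronous sketch: submit a full batch, block until every job in it
# finishes, then ask Ax for the next batch.
n_batches, batch_size = 3, 3  # Illustrative budget, not the one used below.
for _ in range(n_batches):
    trial_index_to_param, _ = ax_client.get_next_trials(max_trials=batch_size)
    batch = [
        (executor.submit(evaluate, parameters), trial_index)
        for trial_index, parameters in trial_index_to_param.items()
    ]
    for job, trial_index in batch:
        # .result() blocks until the job completes (and triggers execution for
        # local/debug jobs).
        ax_client.complete_trial(trial_index=trial_index, raw_data=job.result())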

total_budget = 10
num_parallel_jobs = 3

jobs = []
submitted_jobs = 0
# Run until all the jobs have finished and our budget is used up.
while submitted_jobs < total_budget or jobs:
    for job, trial_index in jobs[:]:
        # Poll if any jobs completed.
        # Local and debug jobs don't run until .result() is called.
        if job.done() or type(job) in [LocalJob, DebugJob]:
            result = job.result()
            ax_client.complete_trial(trial_index=trial_index, raw_data=result)
            jobs.remove((job, trial_index))

    # Schedule new jobs if there is availability.
    trial_index_to_param, _ = ax_client.get_next_trials(
        max_trials=min(num_parallel_jobs - len(jobs), total_budget - submitted_jobs)
    )
    for trial_index, parameters in trial_index_to_param.items():
        job = executor.submit(evaluate, parameters)
        submitted_jobs += 1
        jobs.append((job, trial_index))
        time.sleep(1)

    # Display the current trials.
    display(exp_to_df(ax_client.experiment))

    # Sleep for a bit before checking the jobs again to avoid overloading the cluster.
    # If you have a large number of jobs, consider adding a sleep statement in the job polling loop as well.
    time.sleep(30)
Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 02-03 18:44:14] ax.service.ax_client: Generated new trial 0 with parameters {'x': -1.455271, 'y': -1.835502} using model Sobol.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 02-03 18:44:14] ax.service.ax_client: Generated new trial 1 with parameters {'x': 4.54069, 'y': -8.546926} using model Sobol.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 02-03 18:44:14] ax.service.ax_client: Generated new trial 2 with parameters {'x': -7.889682, 'y': 0.413027} using model Sobol.

Out:

[INFO 02-03 18:44:17] ax.service.utils.report_utils: No results present for the specified metrics [Metric('result')]. Returning arm parameters and metadata only.

   trial_index arm_name trial_status generation_method         x         y
0            0      0_0      RUNNING             Sobol  -1.45527   -1.8355
1            1      1_0      RUNNING             Sobol   4.54069  -8.54693
2            2      2_0      RUNNING             Sobol  -7.88968  0.413027
Out:

[INFO 02-03 18:44:47] ax.service.ax_client: Completed trial 0 with data: {'result': (53.902525, None)}.

Out:

[INFO 02-03 18:44:47] ax.service.ax_client: Completed trial 1 with data: {'result': (159.799074, None)}.

Out:

[INFO 02-03 18:44:47] ax.service.ax_client: Completed trial 2 with data: {'result': (131.451556, None)}.

Out:

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 02-03 18:44:47] ax.service.ax_client: Generated new trial 3 with parameters {'x': -6.582977, 'y': -7.303808} using model Sobol.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 02-03 18:44:47] ax.service.ax_client: Generated new trial 4 with parameters {'x': -3.861047, 'y': -8.787177} using model Sobol.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

   trial_index arm_name trial_status generation_method   result         x         y
0            0      0_0    COMPLETED             Sobol  53.9025  -1.45527   -1.8355
1            1      1_0    COMPLETED             Sobol  159.799   4.54069  -8.54693
2            2      2_0    COMPLETED             Sobol  131.452  -7.88968  0.413027
3            3      3_0      RUNNING             Sobol      nan  -6.58298  -7.30381
4            4      4_0      RUNNING             Sobol      nan  -3.86105  -8.78718
Out:

[INFO 02-03 18:45:19] ax.service.ax_client: Completed trial 3 with data: {'result': (219.609525, None)}.

Out:

[INFO 02-03 18:45:19] ax.service.ax_client: Completed trial 4 with data: {'result': (210.585854, None)}.

Out:

[INFO 02-03 18:45:24] ax.service.ax_client: Generated new trial 5 with parameters {'x': 3.925657, 'y': -1.925657} using model BoTorch.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

Out:

[INFO 02-03 18:45:28] ax.service.ax_client: Generated new trial 6 with parameters {'x': 3.924387, 'y': -1.924387} using model BoTorch.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

Out:

[INFO 02-03 18:45:34] ax.service.ax_client: Generated new trial 7 with parameters {'x': -5.063467, 'y': -2.61011} using model BoTorch.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

   trial_index arm_name trial_status generation_method   result         x         y
0            0      0_0    COMPLETED             Sobol  53.9025  -1.45527   -1.8355
1            1      1_0    COMPLETED             Sobol  159.799   4.54069  -8.54693
2            2      2_0    COMPLETED             Sobol  131.452  -7.88968  0.413027
3            3      3_0    COMPLETED             Sobol   219.61  -6.58298  -7.30381
4            4      4_0    COMPLETED             Sobol  210.586  -3.86105  -8.78718
5            5      5_0      RUNNING           BoTorch      nan   3.92566  -1.92566
6            6      6_0      RUNNING           BoTorch      nan   3.92439  -1.92439
7            7      7_0      RUNNING           BoTorch      nan  -5.06347  -2.61011
Out:

[INFO 02-03 18:46:07] ax.service.ax_client: Completed trial 5 with data: {'result': (35.970254, None)}.

Out:

[INFO 02-03 18:46:07] ax.service.ax_client: Completed trial 6 with data: {'result': (35.952855, None)}.

Out:

[INFO 02-03 18:46:07] ax.service.ax_client: Completed trial 7 with data: {'result': (108.713056, None)}.

Out:

[INFO 02-03 18:46:11] ax.service.ax_client: Generated new trial 8 with parameters {'x': 2.23012, 'y': -0.23012} using model BoTorch.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

Out:

[INFO 02-03 18:46:16] ax.service.ax_client: Generated new trial 9 with parameters {'x': 2.426578, 'y': -1.551917} using model BoTorch.

Out:

/home/runner/work/Ax/Ax/ax/core/data.py:295: FutureWarning:

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

   trial_index arm_name trial_status generation_method   result         x         y
0            0      0_0    COMPLETED             Sobol  53.9025  -1.45527   -1.8355
1            1      1_0    COMPLETED             Sobol  159.799   4.54069  -8.54693
2            2      2_0    COMPLETED             Sobol  131.452  -7.88968  0.413027
3            3      3_0    COMPLETED             Sobol   219.61  -6.58298  -7.30381
4            4      4_0    COMPLETED             Sobol  210.586  -3.86105  -8.78718
5            5      5_0    COMPLETED           BoTorch  35.9703   3.92566  -1.92566
6            6      6_0    COMPLETED           BoTorch  35.9529   3.92439  -1.92439
7            7      7_0    COMPLETED           BoTorch  108.713  -5.06347  -2.61011
8            8      8_0      RUNNING           BoTorch      nan   2.23012  -0.23012
9            9      9_0      RUNNING           BoTorch      nan   2.42658  -1.55192
Out:

[INFO 02-03 18:46:48] ax.service.ax_client: Completed trial 8 with data: {'result': (18.486629, None)}.

Out:

[INFO 02-03 18:46:48] ax.service.ax_client: Completed trial 9 with data: {'result': (31.152595, None)}.

   trial_index arm_name trial_status generation_method   result         x         y
0            0      0_0    COMPLETED             Sobol  53.9025  -1.45527   -1.8355
1            1      1_0    COMPLETED             Sobol  159.799   4.54069  -8.54693
2            2      2_0    COMPLETED             Sobol  131.452  -7.88968  0.413027
3            3      3_0    COMPLETED             Sobol   219.61  -6.58298  -7.30381
4            4      4_0    COMPLETED             Sobol  210.586  -3.86105  -8.78718
5            5      5_0    COMPLETED           BoTorch  35.9703   3.92566  -1.92566
6            6      6_0    COMPLETED           BoTorch  35.9529   3.92439  -1.92439
7            7      7_0    COMPLETED           BoTorch  108.713  -5.06347  -2.61011
8            8      8_0    COMPLETED           BoTorch  18.4866   2.23012  -0.23012
9            9      9_0    COMPLETED           BoTorch  31.1526   2.42658  -1.55192

Retrieving the Best Parameters

We can retrieve the best parameters and render the response surface.

best_parameters, (means, covariances) = ax_client.get_best_parameters()
print(f'Best set of parameters: {best_parameters}')
print(f'Mean objective value: {means}')
# The covariance is only meaningful when multiple objectives are present.

render(ax_client.get_contour_plot())

Out:

[INFO 02-03 18:47:18] ax.service.ax_client: Retrieving contour plot with parameter 'x' on X-axis and 'y' on Y-axis, for metric 'result'. Remaining parameters are affixed to the middle of their range.

Out:

Best set of parameters: {'x': 2.2301197684191636, 'y': -0.2301197684191637}

Mean objective value: {'result': 19.50516987256279}

[Interactive contour plot of the modeled 'result' metric over parameters x and y]