Scheduler
We recommend reading through the "Developer API" tutorial before getting started with the Scheduler, as using it in this tutorial will require an Ax Experiment and an understanding of the experiment's subcomponents like the search space and the runner.
In this tutorial, we use Scheduler.run_n_trials to run the optimization, and Scheduler.run_trials_and_yield_results to run the optimization via a generator method.
Scheduler and external systems for trial evaluation
Scheduler is a closed-loop manager class in Ax that continuously deploys trial runs to an arbitrary external system in an asynchronous fashion, polls their status from that system, and leverages known trial results to generate more trials.
Key features of the Scheduler:
- an Experiment for optimization setup (an optimization config with metrics, a search space, a runner for trial evaluations),
- a GenerationStrategy for flexible specification of an optimization algorithm used to generate new trials to run.
This scheme summarizes how the scheduler interacts with any external system used to run trial evaluations: new trials are generated, deployed to the external system via the runner, polled for status, and their results are fetched back to inform further generation.
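A toy sketch of that loop follows, for illustration only: none of the function names below are Ax APIs, they merely stand in for the generation strategy, runner, and metrics, and the real Scheduler runs this loop asynchronously.
# Illustrative toy of the closed loop the Scheduler automates:
# generate -> deploy -> poll -> fetch -> generate more.
from typing import Dict, List

def suggest_candidates(n: int) -> List[Dict[str, float]]:
    """Stand-in for the generation strategy proposing new parameterizations."""
    return [{"x1": 0.0, "x2": float(i)} for i in range(n)]

def deploy_to_external_system(parameters: Dict[str, float]) -> int:
    """Stand-in for `Runner.run`, deploying a trial and returning a job ID."""
    return id(tuple(parameters.values()))

def poll_external_system(job_id: int) -> str:
    """Stand-in for `Runner.poll_trial_status`."""
    return "COMPLETED"

def fetch_outcome(job_id: int) -> float:
    """Stand-in for `Metric.fetch_trial_data`."""
    return 0.0

for parameters in suggest_candidates(3):  # 1. generate new trials
    job_id = deploy_to_external_system(parameters)  # 2. deploy to the external system
    while poll_external_system(job_id) != "COMPLETED":  # 3. poll trial status
        pass
    outcome = fetch_outcome(job_id)  # 4. fetch results to inform further generation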
An example of an 'external system' running trial evaluations could be a remote server executing scheduled jobs, a subprocess conducting ML training runs, an engine running physics simulations, etc. For the sake of example here, let us assume a dummy external system with the following client:
from random import randint
from time import time
from typing import Any, Dict, NamedTuple, Union
from ax.core.base_trial import TrialStatus
from ax.utils.measurement.synthetic_functions import branin
class MockJob(NamedTuple):
    """Dummy class to represent a job scheduled on `MockJobQueue`."""

    id: int
    parameters: Dict[str, Union[str, float, int, bool]]


class MockJobQueueClient:
    """Dummy class to represent a job queue where the Ax `Scheduler` will
    deploy trial evaluation runs during optimization.
    """

    jobs: Dict[int, MockJob] = {}

    def schedule_job_with_parameters(
        self, parameters: Dict[str, Union[str, float, int, bool]]
    ) -> int:
        """Schedules an evaluation job with given parameters and returns job ID."""
        # Code to actually schedule the job and produce an ID would go here;
        # using timestamp in microseconds as dummy ID for this example.
        job_id = int(time() * 1e6)
        self.jobs[job_id] = MockJob(job_id, parameters)
        return job_id

    def get_job_status(self, job_id: int) -> TrialStatus:
        """Get status of the job by a given ID. For simplicity of the example,
        return an Ax `TrialStatus`.
        """
        job = self.jobs[job_id]
        # Instead of randomizing trial status, code to check actual job status
        # would go here.
        if randint(0, 3) > 0:
            return TrialStatus.COMPLETED
        return TrialStatus.RUNNING

    def get_outcome_value_for_completed_job(self, job_id: int) -> Dict[str, float]:
        """Get evaluation results for a given completed job."""
        job = self.jobs[job_id]
        # In a real external system, this would retrieve real relevant outcomes and
        # not a synthetic function value.
        return {"branin": branin(job.parameters.get("x1"), job.parameters.get("x2"))}


MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient()


def get_mock_job_queue_client() -> MockJobQueueClient:
    """Obtain the singleton job queue instance."""
    return MOCK_JOB_QUEUE_CLIENT
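As a quick sanity check (purely illustrative; not part of the Scheduler setup), the mock client can be exercised directly:
# Illustrative check of the mock external system defined above.
client = get_mock_job_queue_client()
job_id = client.schedule_job_with_parameters({"x1": 2.5, "x2": 7.5})
print(client.get_job_status(job_id))  # TrialStatus.RUNNING or TrialStatus.COMPLETED (randomized here)
print(client.get_outcome_value_for_completed_job(job_id))  # {'branin': <synthetic function value>}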
As mentioned above, using a Scheduler
requires a fully set up experiment with metrics and a runner. Refer to the "Building Blocks of Ax" tutorial to learn more about those components, as here we assume familiarity with them.
The following runner and metric set up interactions between the Scheduler and the mock external system we assume:
from collections import defaultdict
from typing import Iterable, Set
from ax.core.base_trial import BaseTrial
from ax.core.runner import Runner
from ax.core.trial import Trial
class MockJobRunner(Runner):  # Deploys trials to external system.
    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Deploys a trial based on custom runner subclass implementation.

        Args:
            trial: The trial to deploy.

        Returns:
            Dict of run metadata from the deployment process.
        """
        if not isinstance(trial, Trial):
            raise ValueError("This runner only handles `Trial`.")

        mock_job_queue = get_mock_job_queue_client()
        job_id = mock_job_queue.schedule_job_with_parameters(
            parameters=trial.arm.parameters
        )
        # This run metadata will be attached to trial as `trial.run_metadata`
        # by the base `Scheduler`.
        return {"job_id": job_id}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Checks the status of any non-terminal trials and returns their
        indices as a mapping from TrialStatus to a set of indices. Required
        for runners used with Ax ``Scheduler``.

        NOTE: Does not need to handle waiting between polling calls while trials
        are running; this function should just perform a single poll.

        Args:
            trials: Trials to poll.

        Returns:
            A dictionary mapping TrialStatus to a set of trial indices that have
            the respective status at the time of the polling. This does not need to
            include trials that at the time of polling already have a terminal
            (ABANDONED, FAILED, COMPLETED) status (but it may).
        """
        status_dict = defaultdict(set)
        for trial in trials:
            mock_job_queue = get_mock_job_queue_client()
            status = mock_job_queue.get_job_status(
                job_id=trial.run_metadata.get("job_id")
            )
            status_dict[status].add(trial.index)

        return status_dict
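For reference, the mapping returned by poll_trial_status has the following shape (the trial indices here are made up):
# Illustrative shape of a `poll_trial_status` return value.
example_poll_result = {
    TrialStatus.COMPLETED: {0, 2},
    TrialStatus.RUNNING: {1, 3},
}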
import pandas as pd
from ax.core.metric import Metric, MetricFetchResult, MetricFetchE
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.utils.common.result import Ok, Err
class BraninForMockJobMetric(Metric):  # Pulls data for trial from external system.
    def fetch_trial_data(self, trial: BaseTrial) -> MetricFetchResult:
        """Obtains data via fetching it from the mock external system for a given trial."""
        if not isinstance(trial, Trial):
            raise ValueError("This metric only handles `Trial`.")

        try:
            mock_job_queue = get_mock_job_queue_client()

            # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
            branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                job_id=trial.run_metadata.get("job_id")
            )
            df_dict = {
                "trial_index": trial.index,
                "metric_name": "branin",
                "arm_name": trial.arm.name,
                "mean": branin_data.get("branin"),
                # Can be set to 0.0 if function is known to be noiseless
                # or to an actual value when SEM is known. Setting SEM to
                # `None` results in Ax assuming unknown noise and inferring
                # noise level from data.
                "sem": None,
            }
            return Ok(value=Data(df=pd.DataFrame.from_records([df_dict])))

        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )
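The returned Data object wraps a pandas dataframe with one record per trial, arm, and metric; for a single completed trial it is built from a record like the one below (values are illustrative; "0_0" follows Ax's default "<trial index>_<arm index>" arm naming):
# Illustrative record underlying the `Data` object returned above (values are made up).
example_record = {
    "trial_index": 0,
    "metric_name": "branin",
    "arm_name": "0_0",
    "mean": 24.6,
    "sem": None,
}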
Now we can set up the experiment using the runner and metric we defined. This experiment will have a single-objective optimization config, minimizing the Branin function, and the search space that corresponds to that function.
from ax import *
def make_branin_experiment_with_runner_and_metric() -> Experiment:
    parameters = [
        RangeParameter(
            name="x1",
            parameter_type=ParameterType.FLOAT,
            lower=-5,
            upper=10,
        ),
        RangeParameter(
            name="x2",
            parameter_type=ParameterType.FLOAT,
            lower=0,
            upper=15,
        ),
    ]

    objective = Objective(metric=BraninForMockJobMetric(name="branin"), minimize=True)

    return Experiment(
        name="branin_test_experiment",
        search_space=SearchSpace(parameters=parameters),
        optimization_config=OptimizationConfig(objective=objective),
        runner=MockJobRunner(),
        is_test=True,  # Marking this experiment as a test experiment.
    )
experiment = make_branin_experiment_with_runner_and_metric()
[INFO 09-23 20:32:23] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
Setting up a Scheduler
A Scheduler requires an Ax GenerationStrategy specifying the algorithm to use for the optimization. Here we use the choose_generation_strategy utility that auto-picks a generation strategy based on the search space properties. To construct a custom generation strategy instead, refer to the "Generation Strategy" tutorial.
Importantly, a generation strategy in Ax limits the allowed parallelism level for each generation step it contains. If you would like the Scheduler to enforce parallelism limitations, set max_parallelism on each generation step in your generation strategy.
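For illustration only (this tutorial uses choose_generation_strategy below instead), a custom generation strategy with per-step parallelism limits might look roughly like this:
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

# Illustrative sketch of a custom generation strategy with per-step parallelism limits.
custom_generation_strategy = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # Quasi-random initialization trials.
            max_parallelism=5,  # All initialization trials may run at once.
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limit on the number of model-based trials.
            max_parallelism=3,  # At most 3 model-based trials in flight at a time.
        ),
    ]
)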
from ax.modelbridge.dispatch_utils import choose_generation_strategy
generation_strategy = choose_generation_strategy(
    search_space=experiment.search_space,
    max_parallelism_cap=3,
)
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 09-23 20:32:23] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
Now we have all the components needed to start the scheduler:
from ax.service.scheduler import Scheduler, SchedulerOptions
scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=generation_strategy,
    options=SchedulerOptions(),
)
[INFO 09-23 20:32:23] Scheduler: `Scheduler` requires experiment to have immutable search space and optimization config. Setting property immutable_search_space_and_opt_config to `True` on experiment.
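With the scheduler constructed, the optimization can be run using the methods mentioned at the start of this tutorial. A minimal sketch (the trial count here is arbitrary):
# Run 10 trials end to end: generate, deploy via the runner, poll, and fetch data.
scheduler.run_n_trials(max_trials=10)

# Alternatively, run trials via a generator method to inspect results as they come in:
# for results in scheduler.run_trials_and_yield_results(max_trials=10):
#     print(results)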
import numpy as np
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()
def get_plot():
    best_objectives = np.array(
        [[trial.objective_mean for trial in scheduler.experiment.trials.values()]]
    )
    best_objective_plot = optimization_trace_single_method(
        y=np.minimum.accumulate(best_objectives, axis=1),
        title="Model performance vs. # of iterations",
        ylabel="Y",
    )
    return best_objective_plot
[INFO 09-23 20:32:24] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 09-23 20:32:24] ax.utils.notebook.plotting: Please see (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering) if visualizations are not rendering.
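Once the scheduler has completed some trials, the helper above can be rendered in the notebook:
# Render the best-objective trace over completed trials.
render(get_plot())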