This tutorial illustrates use of a Global Stopping Strategy (GSS) in combination with the Service API. For background on the Service API, see the Service API Tutorial: https://ax.dev/tutorials/gpei_hartmann_service.html
GSS is also supported in the Scheduler API, where it can be provided as part of SchedulerOptions. For more on Scheduler, see the Scheduler tutorial: https://ax.dev/tutorials/scheduler.html
Global Stopping stops an optimization loop when some data-based criteria are met which suggest that future trials will not be very helpful. For example, we might stop when there has been very little improvement in the last five trials. This is as opposed to trial-level early stopping, which monitors the results of expensive evaluations and terminates those that are unlikely to produce promising results, freeing resources to explore more promising configurations. For more on trial-level early stopping, see the tutorial: https://ax.dev/tutorials/early_stopping/early_stopping.html
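To make the idea concrete, here is a minimal, library-free sketch of a global stopping rule wrapped around a generic minimization loop. It is not how Ax implements GSS: the function name run_with_global_stopping and its parameters (window_size, min_improvement) are hypothetical and chosen only for illustration, and the rule uses a fixed absolute threshold as an assumption.

```python
import random


def run_with_global_stopping(
    objective, sample, max_trials=100, window_size=5, min_improvement=1e-3
):
    """Minimize `objective` over points drawn by `sample`.

    Stops early once the best value seen has improved by less than
    `min_improvement` over the last `window_size` trials.
    Returns the best-so-far value after each completed trial.
    """
    best_history = []
    for _ in range(max_trials):
        value = objective(sample())
        best = value if not best_history else min(best_history[-1], value)
        best_history.append(best)
        # Only consider stopping once a full window of trials is available.
        if len(best_history) > window_size:
            gain = best_history[-1 - window_size] - best_history[-1]
            if gain < min_improvement:
                break
    return best_history


# Hypothetical usage: random search on a 1-D quadratic.
rng = random.Random(0)
history = run_with_global_stopping(
    objective=lambda x: (x - 2.0) ** 2,
    sample=lambda: rng.uniform(-5.0, 10.0),
)
print(f"Stopped after {len(history)} trials; best value {history[-1]:.4f}")
```

The decision here depends only on the history of completed trials, which is what distinguishes global stopping from trial-level early stopping: the latter looks inside a single running trial.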
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import Branin, branin
from ax.utils.notebook.plotting import init_notebook_plotting, render
init_notebook_plotting()
This example uses the Branin test problem. We run 25 trials, which turns out to be far more than needed, because we get close to the optimum quite quickly.
def evaluate(parameters):
    x = np.array([parameters.get(f"x{i+1}") for i in range(2)])
    return {"branin": (branin(x), 0.0)}
params = [
    {
        "name": f"x{i + 1}",
        "type": "range",
        "bounds": [*Branin._domain[i]],
        "value_type": "float",
        "log_scale": False,
    }
    for i in range(2)
]
ax_client = AxClient(random_seed=0, verbose_logging=False)
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=params,
    objectives={"branin": ObjectiveProperties(minimize=True)},
    is_test=True,
)
[WARNING 11-12 07:14:26] ax.service.ax_client: Random seed set to 0. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 11-12 07:14:26] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 11-12 07:14:26] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 11-12 07:14:26] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
%%time
for i in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameters)
    )
/tmp/tmp.Lx6ya87xsF/Ax-main/ax/modelbridge/cross_validation.py:464: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
CPU times: user 32.1 s, sys: 71.8 ms, total: 32.2 s Wall time: 16.5 s
render(ax_client.get_optimization_trace())
Rather than running a fixed number of trials, we can use a GlobalStoppingStrategy (GSS), which checks whether some stopping criteria have been met when get_next_trial is called. Here, we use an ImprovementGlobalStoppingStrategy, which checks whether the best objective value has improved by more than some threshold amount over the last window_size trials.
For single-objective optimization, which we are doing here, ImprovementGlobalStoppingStrategy checks if an improvement is "significant" by comparing it to the inter-quartile range (IQR) of the objective values attained so far.
ImprovementGlobalStoppingStrategy also supports multi-objective optimization (MOO), in which case it checks whether the percentage improvement in hypervolume over the last window_size trials exceeds improvement_bar.
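The single-objective criterion described above can be sketched with NumPy as follows. This is a simplified illustration under stated assumptions, not Ax's implementation: the helper name improvement_is_insignificant is hypothetical, it ignores feasibility and trial status, and it uses NumPy's default percentile interpolation to compute the IQR.

```python
import numpy as np


def improvement_is_insignificant(
    objectives, window_size=5, improvement_bar=0.01, minimize=True
):
    """Return True if the best objective improved by less than
    `improvement_bar` times the IQR over the last `window_size` trials."""
    values = np.asarray(objectives, dtype=float)
    if len(values) <= window_size:
        # Not enough history to compare against a full window.
        return False
    # Flip the sign when minimizing so that "larger is better" throughout.
    sign = -1.0 if minimize else 1.0
    best_so_far = np.maximum.accumulate(sign * values)
    improvement = best_so_far[-1] - best_so_far[-1 - window_size]
    # Scale the improvement by the spread of all objective values seen so far.
    q75, q25 = np.percentile(values, [75, 25])
    return bool(improvement < improvement_bar * (q75 - q25))


# Hypothetical objective values: rapid early progress, then a plateau.
plateau = [50.0, 20.0, 10.0, 1.0, 0.5, 0.5, 0.6, 0.7, 0.5, 0.55]
print(improvement_is_insignificant(plateau))
```

Scaling by the IQR makes the threshold relative to the spread of observed objective values, so the same improvement_bar can be reused across problems whose objectives live on very different scales.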
from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy
from ax.exceptions.core import OptimizationShouldStop
# Start considering stopping only after the 5 initialization trials + 5 BoTorch trials.
# Stop if the improvement in the best point in the past 5 trials is less than
# 1% of the IQR thus far.
stopping_strategy = ImprovementGlobalStoppingStrategy(
    min_trials=5 + 5, window_size=5, improvement_bar=0.01
)
ax_client_gss = AxClient(
    global_stopping_strategy=stopping_strategy, random_seed=0, verbose_logging=False
)
ax_client_gss.create_experiment(
    name="branin_test_experiment",
    parameters=params,
    objectives={"branin": ObjectiveProperties(minimize=True)},
    is_test=True,
)
If there has not been much improvement, the ImprovementGlobalStoppingStrategy signals that optimization should stop, and get_next_trial raises an OptimizationShouldStop exception. We catch the exception and terminate optimization.
for i in range(25):
    try:
        parameters, trial_index = ax_client_gss.get_next_trial()
    except OptimizationShouldStop as exc:
        print(exc.message)
        break
    ax_client_gss.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
The improvement in best objective in the past 5 trials (=0.000) is less than 0.01 times the interquartile range (IQR) of objectives attained so far (IQR=34.894).
render(ax_client_gss.get_optimization_trace())
You can write a custom Global Stopping Strategy by subclassing BaseGlobalStoppingStrategy and use it where ImprovementGlobalStoppingStrategy was used above.
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from typing import Tuple
from ax.core.experiment import Experiment
from ax.core.base_trial import TrialStatus
from ax.global_stopping.strategies.improvement import constraint_satisfaction
Here, we define SimpleThresholdGlobalStoppingStrategy, which stops when we observe a point better than a provided threshold. This can be useful when there is a known optimum. For example, the Branin function has a known global minimum of approximately 0.398. When the optimum is not known, this can still be useful from a satisficing perspective: For example, maybe we need a model to take up less than a certain amount of RAM so it doesn't crash our usual hardware, but there is no benefit to further improvements.
class SimpleThresholdGlobalStoppingStrategy(BaseGlobalStoppingStrategy):
    """
    A GSS that stops when we observe a point better than `threshold`.
    """
    def __init__(
        self,
        min_trials: int,
        inactive_when_pending_trials: bool = True,
        threshold: float = 0.1,
    ):
        self.threshold = threshold
        super().__init__(
            min_trials=min_trials,
            inactive_when_pending_trials=inactive_when_pending_trials,
        )

    def _should_stop_optimization(
        self, experiment: Experiment
    ) -> Tuple[bool, str]:
        """
        Check if the best seen is better than `self.threshold`.
        """
        feasible_objectives = [
            trial.objective_mean
            for trial in experiment.trials_by_status[TrialStatus.COMPLETED]
            if constraint_satisfaction(trial)
        ]
        # Do not stop until more than one feasible observation is available.
        if len(feasible_objectives) <= 1:
            message = "There are not enough feasible arms tried yet."
            return False, message
        minimize = experiment.optimization_config.objective.minimize
        if minimize:
            best = np.min(feasible_objectives)
            stop = best < self.threshold
        else:
            best = np.max(feasible_objectives)
            stop = best > self.threshold
        comparison = "less" if minimize else "greater"
        if stop:
            message = (
                f"The best objective seen is {best:.3f}, which is {comparison} "
                f"than the threshold of {self.threshold:.3f}."
            )
        else:
            message = ""
        return stop, message
stopping_strategy = SimpleThresholdGlobalStoppingStrategy(min_trials=5, threshold=1.)
ax_client_custom_gss = AxClient(
    global_stopping_strategy=stopping_strategy,
    random_seed=0,
    verbose_logging=False,
)
ax_client_custom_gss.create_experiment(
    name="branin_test_experiment",
    parameters=params,
    objectives={"branin": ObjectiveProperties(minimize=True)},
    is_test=True,
)
for i in range(25):
    try:
        parameters, trial_index = ax_client_custom_gss.get_next_trial()
    except OptimizationShouldStop as exc:
        print(exc.message)
        break
    ax_client_custom_gss.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameters)
    )
The best objective seen is 0.401, which is less than the threshold of 1.000.
render(ax_client_custom_gss.get_optimization_trace())
Total runtime of script: 36.16 seconds.