from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
from ax.modelbridge.registry import ModelRegistryBase, Models
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_search_space
GenerationStrategy (API reference) is a key abstraction in Ax. Many higher-level APIs in Ax use generation strategies: the Service and Loop APIs, the Scheduler, etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).
This tutorial walks through a few examples of generation strategies and discusses their important settings. Before reading it, we recommend familiarizing yourself with how Model and ModelBridge work in Ax: https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.
Contents:
- GenerationStep as a building block of the generation strategy
- GenerationStep settings
- Chaining GenerationStep-s together
- max_parallelism enforcement and handling the MaxParallelismReachedException
- GenerationStrategy storage
- GenerationStrategy.gen produces GeneratorRun-s, not Trial-s
- model_kwargs elements that don't have associated serialization logic in Ax
- Why prefer Models registry enum entries over a factory function?
- How to request more modeling setups added to Models?

gs = GenerationStrategy(
    steps=[
        # 1. Initialization step (does not require pre-existing data and is well-suited for
        # initial sampling of the search space)
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
            min_trials_observed=3,  # How many trials need to be completed to move to next model
            max_parallelism=5,  # Max parallelism for this step
            model_kwargs={"seed": 999},  # Any kwargs you want passed into the model
            model_gen_kwargs={},  # Any kwargs you want passed to `modelbridge.gen`
        ),
        # 2. Bayesian optimization step (requires data obtained from previous phase and learns
        # from all data available at the time of each new candidate generation call)
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            max_parallelism=3,  # Parallelism limit for this step, often lower than for Sobol
            # More on parallelism vs. required samples in BayesOpt:
            # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
        ),
    ]
)
Ax provides a choose_generation_strategy utility, which can auto-select a suitable generation strategy given a search space and an array of other optional settings. The utility is fairly simple at the moment, but additional development (support for multi-objective optimization, multi-fidelity optimization, Bayesian optimization with categorical kernels, etc.) is coming soon.
gs = choose_generation_strategy(
    # Required arguments:
    search_space=get_branin_search_space(),  # Ax `SearchSpace`
    # Some optional arguments (shown with their defaults), see API docs for more settings:
    # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils
    use_batch_trials=False,  # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`
    no_bayesian_optimization=False,  # Use quasi-random candidate generation without BayesOpt
    max_parallelism_override=None,  # Integer, to which to set the `max_parallelism` setting of all steps in this GS
)
gs
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 09-23 20:32:15] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
While a generation strategy is often used through the Service or Loop API or other higher-order abstractions like the Ax Scheduler (where it is used to fit models and produce candidates from them under the hood), it's also possible to use the GS directly, in place of a ModelBridge instance. The interface of GenerationStrategy.gen is the same as that of ModelBridge.gen.
experiment = get_branin_experiment()
[INFO 09-23 20:32:15] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
Note that it's important to pass pending observations to the call to gen to avoid getting the same points re-suggested. Without the pending_observations argument, Ax models are not aware of points that should be excluded from generation. Points are considered "pending" when they belong to STAGED, RUNNING, or ABANDONED trials (the latter are included so that the model does not re-suggest points that are considered "bad").
If the call to get_pending_observation_features becomes slow in your setup (since it performs data-fetching etc.), you can opt for get_pending_observation_features_based_on_trial_status (also from ax.modelbridge.modelbridge_utils), but note the limitations of that utility (detailed in its docstring).
generator_run = gs.gen(
    experiment=experiment,  # Ax `Experiment`, for which to generate new candidates
    data=None,  # Ax `Data` to use for model training, optional.
    n=1,  # Number of candidate arms to produce
    pending_observations=get_pending_observation_features(
        experiment
    ),  # Points that should not be re-generated
    # Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`
)
generator_run
/tmp/tmp.QqcA7fo0ui/Ax-main/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction. warn("Encountered exception in computing model fit quality: " + str(e))
GeneratorRun(1 arms, total weight 1.0)
Then we can add the newly produced GeneratorRun to the experiment as a Trial (or BatchTrial if n > 1):
trial = experiment.new_trial(generator_run)
trial
Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': 5.9256744384765625, 'x2': 13.495488166809082}))
Important notes on GenerationStrategy.gen:
- If the data argument above is not specified, GS will pull experiment data from cache via experiment.lookup_data.
- Without pending_observations, the GS (and any model in Ax) could produce the same candidate over and over, as without that argument the model is not 'aware' that the candidate is part of a RUNNING or ABANDONED trial and should not be re-suggested. In cases where get_pending_observation_features is too slow and the experiment consists of 1-arm Trial-s only, it's possible to use get_pending_observation_features_based_on_trial_status instead (found in the same file).
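To make the role of pending points more concrete, here is a minimal sketch in plain Python. It is a toy model, not Ax code: the Status enum, pending_points, and suggest are hypothetical names introduced only for this illustration, and the "model" is simply a function that returns its first acceptable candidate.

```python
from enum import Enum


class Status(Enum):
    """Toy trial statuses (a hypothetical stand-in for Ax's TrialStatus)."""

    CANDIDATE = "candidate"
    STAGED = "staged"
    RUNNING = "running"
    COMPLETED = "completed"
    ABANDONED = "abandoned"


# Statuses that the text above describes as "pending".
PENDING_STATUSES = {Status.STAGED, Status.RUNNING, Status.ABANDONED}


def pending_points(trials):
    """Collect the points of all trials whose status counts as pending."""
    return [point for point, status in trials if status in PENDING_STATUSES]


def suggest(candidates, pending):
    """Return the first candidate that is not pending (stand-in for a model's gen)."""
    for point in candidates:
        if point not in pending:
            return point
    return None


trials = [
    ((0.1, 0.2), Status.RUNNING),
    ((0.5, 0.5), Status.COMPLETED),
    ((0.9, 0.9), Status.ABANDONED),
]
# The toy model's ranked candidates; its favorite is already running.
candidates = [(0.1, 0.2), (0.9, 0.9), (0.3, 0.7)]

without_pending = suggest(candidates, pending=[])
with_pending = suggest(candidates, pending=pending_points(trials))
```

Without the pending list, the toy model proposes a point that is already running; with it, both the running and the abandoned points are skipped.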
Note that when using the Ax Service API, one of the arguments to AxClient.create_experiment is choose_generation_strategy_kwargs; specifying that argument is a convenient way to influence the choice of generation strategy in AxClient without manually specifying a full GenerationStrategy.
GenerationStep as a building block of generation strategy
Specifying a model for a GenerationStep
There are two ways of specifying a model for a generation step: via an entry in a Models enum or via a 'factory function', a callable model constructor (e.g. get_GPEI and other factory functions in ax/modelbridge/factory.py). Note that using the latter path, a factory function, will prohibit GenerationStrategy storage and is generally discouraged.
GenerationStep settings
All of the available settings are described in the documentation:
print(GenerationStep.__doc__)
One step in the generation strategy, corresponds to a single model.
Describes the model, how many trials will be generated with this model, what
minimum number of observations is required to proceed to the next model, etc.

NOTE: Model can be specified either from the model registry
(`ax.modelbridge.registry.Models`) or using a callable model constructor. Only
models from the registry can be saved, and thus optimization can only be
resumed if interrupted when using models from the registry.

Args:
    model: A member of `Models` enum or a callable returning an instance of
        `ModelBridge` with an instantiated underlying `Model`. Refer to
        `ax/modelbridge/factory.py` for examples of such callables.
    num_trials: How many trials to generate with the model from this step.
        If set to -1, trials will continue to be generated from this model
        as long as `generation_strategy.gen` is called (available only for
        the last of the generation steps).
    min_trials_observed: How many trials must be completed before the
        generation strategy can proceed to the next step. Defaults to 0.
        If `num_trials` of a given step have been generated but
        `min_trials_observed` have not been completed, a call to
        `generation_strategy.gen` will fail with a `DataRequiredError`.
    max_parallelism: How many trials generated in the course of this step are
        allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously.
        If `max_parallelism` trials from this step are already running, a call
        to `generation_strategy.gen` will fail with a
        `MaxParallelismReachedException`, indicating that more trials need to
        be completed before generating and running next trials.
    use_update: DEPRECATED.
    enforce_num_trials: Whether to enforce that only `num_trials` are generated
        from the given step. If False and `num_trials` have been generated, but
        `min_trials_observed` have not been completed,
        `generation_strategy.gen` will continue generating trials from the
        current step, exceeding `num_trials` for it. Allows to avoid
        `DataRequiredError`, but delays proceeding to next generation step.
    model_kwargs: Dictionary of kwargs to pass into the model constructor on
        instantiation. E.g. if `model` is `Models.SOBOL`, kwargs will be applied
        as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`,
        `get_sobol(**model_kwargs)`. NOTE: if generation strategy is
        interrupted and resumed from a stored snapshot and its last used
        model has state saved on its generator runs, `model_kwargs` is
        updated with the state dict of the model, retrieved from the last
        generator run of this generation strategy.
    model_gen_kwargs: Each call to `generation_strategy.gen` performs a call
        to the step's model's `gen` under the hood; `model_gen_kwargs` will be
        passed to the model's `gen` like so: `model.gen(**model_gen_kwargs)`.
    completion_criteria: List of TransitionCriterion. All `is_met` must evaluate
        True for the GenerationStrategy to move on to the next Step
    index: Index of this generation step, for use internally in
        `GenerationStrategy`. Do not assign as it will be reassigned when
        instantiating `GenerationStrategy` with a list of its steps.
    should_deduplicate: Whether to deduplicate the parameters of proposed arms
        against those of previous arms via rejection sampling. If this is True,
        the generation strategy will discard generator runs produced from the
        generation step that has `should_deduplicate=True` if they contain arms
        already present on the experiment and replace them with new generator
        runs. If no generator run with entirely unique arms could be produced
        in 5 attempts, a `GenerationStrategyRepeatedPoints` error will be
        raised, as we assume that the optimization converged when the model
        can no longer suggest unique arms.

Note for developers: by "model" here we really mean an Ax ModelBridge object,
which contains an Ax Model under the hood. We call it "model" here to simplify
and focus on explaining the logic of GenerationStep and GenerationStrategy.
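As a rough illustration of how num_trials, min_trials_observed, and enforce_num_trials interact according to the docstring above, here is a minimal sketch in plain Python. It is a simplified model, not Ax's implementation: StepSettings, next_action, and the DataRequiredError class defined here are hypothetical stand-ins written for this example.

```python
from dataclasses import dataclass


class DataRequiredError(Exception):
    """Hypothetical stand-in for the error raised when more completed trials are needed."""


@dataclass
class StepSettings:
    """The three settings of a step that govern when the strategy moves on."""

    num_trials: int  # -1 means "no limit" (only meaningful for the last step)
    min_trials_observed: int = 0
    enforce_num_trials: bool = True


def next_action(step: StepSettings, generated: int, completed: int) -> str:
    """Return "stay" to keep generating from this step or "advance" to move on."""
    # Still within the trial budget of this step (or the budget is unlimited).
    if step.num_trials == -1 or generated < step.num_trials:
        return "stay"
    # Budget used up and enough trials have completed with data: move on.
    if completed >= step.min_trials_observed:
        return "advance"
    # Budget used up, but too few completed trials.
    if step.enforce_num_trials:
        raise DataRequiredError(
            f"{step.min_trials_observed - completed} more trial(s) must complete"
        )
    # Not enforcing the budget: keep generating from the current step.
    return "stay"


sobol = StepSettings(num_trials=5, min_trials_observed=3)
relaxed = StepSettings(num_trials=5, min_trials_observed=3, enforce_num_trials=False)
```

With these settings, next_action(sobol, generated=5, completed=3) returns "advance", next_action(sobol, generated=5, completed=2) raises the stand-in DataRequiredError, and next_action(relaxed, generated=5, completed=2) returns "stay".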
Chaining GenerationStep-s together
A GenerationStrategy moves from one step to another when:
1. N=num_trials generator runs were produced and attached as trials to the experiment AND
2. M=min_trials_observed trials have been completed and have data.
Caveat: enforce_num_trials setting:
1. If enforce_num_trials=True for a given generation step and 1) is reached but 2) is not yet reached, the generation strategy will raise a DataRequiredError, indicating that more trials need to be completed before the next step.
2. If enforce_num_trials=False, the GS will continue producing generator runs from the current step until 2) is reached.
max_parallelism enforcement
Generation strategy can restrict the number of trials that can be run simultaneously (to encourage sequential optimization, which benefits Bayesian optimization performance). When the parallelism limit is reached, a call to GenerationStrategy.gen will result in a MaxParallelismReachedException.
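The behavior of such a limit can be sketched with a small simulation in plain Python. This is a toy model under stated assumptions, not Ax code: MaxParallelismReached, ToyStep, and generate_with_backoff are hypothetical names, and "completing the oldest running trial" stands in for waiting on a real evaluation to finish.

```python
class MaxParallelismReached(Exception):
    """Hypothetical stand-in for Ax's MaxParallelismReachedException."""


class ToyStep:
    """Toy bookkeeping of running trials for one generation step."""

    def __init__(self, max_parallelism):
        self.max_parallelism = max_parallelism  # None disables the limit
        self.running = set()
        self._next_index = 0

    def gen(self):
        """Start a new trial, or raise if the parallelism limit is reached."""
        limit = self.max_parallelism
        if limit is not None and len(self.running) >= limit:
            raise MaxParallelismReached(f"{limit} trials are already running")
        index = self._next_index
        self._next_index += 1
        self.running.add(index)
        return index

    def mark_completed(self, index):
        """Record that a trial finished, freeing one parallelism slot."""
        self.running.discard(index)


def generate_with_backoff(step, total):
    """Generate `total` trials, completing the oldest running one when the limit is hit."""
    generated = []
    while len(generated) < total:
        try:
            generated.append(step.gen())
        except MaxParallelismReached:
            # In a real setup this is where one would wait for an evaluation to finish.
            step.mark_completed(min(step.running))
    return generated


limited = ToyStep(max_parallelism=3)
order = generate_with_backoff(limited, total=5)
unlimited = ToyStep(max_parallelism=None)
generate_with_backoff(unlimited, total=10)
```

In this simulation the limited step never has more than three trials running at once, while the step with max_parallelism=None starts all ten without ever raising.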
The correct way to handle this exception:
1. Make sure that GenerationStep.max_parallelism is configured correctly for all steps in your generation strategy (to disable it completely, configure GenerationStep.max_parallelism=None).
2. When the exception is raised, wait for more trials to complete before generating new ones, and log each completion via trial.mark_completed.
GenerationStrategy storage
When used through Service API or Scheduler, generation strategy will be automatically stored to SQL or JSON when DBSettings are specified to either AxClient or Scheduler (details in respective tutorials in the "Tutorials" page). Generation strategy can also be stored to SQL or JSON individually, as shown below.
More detail on SQL and JSON storage in Ax generally can be found in "Building Blocks of Ax" tutorial.
For SQL storage setup in Ax, read through the "Storage" documentation page.
Note that unlike an Ax experiment, a generation strategy does not have a name or another unique identifier. Therefore, a generation strategy is stored in association with experiment and can be retrieved by the associated experiment's name.
from ax.storage.sqa_store.db import (
    create_all_tables,
    get_engine,
    init_engine_and_session_factory,
)
from ax.storage.sqa_store.load import (
    load_experiment,
    load_generation_strategy_by_experiment_name,
)
from ax.storage.sqa_store.save import save_experiment, save_generation_strategy

# Set up a local SQLite database and create the Ax tables in it
init_engine_and_session_factory(url="sqlite:///foo2.db")
engine = get_engine()
create_all_tables(engine)
# The generation strategy is stored in association with its experiment
save_experiment(experiment)
save_generation_strategy(gs)
experiment = load_experiment(experiment_name=experiment.name)
gs = load_generation_strategy_by_experiment_name(
    experiment_name=experiment.name,
    experiment=experiment,  # Can optionally specify experiment object to avoid loading it from database twice
)
gs
[INFO 09-23 20:32:16] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
[INFO 09-23 20:32:16] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
from ax.storage.json_store.decoder import object_from_json
from ax.storage.json_store.encoder import object_to_json
gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.
gs = object_from_json(
    gs_json
)  # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)
gs
[INFO 09-23 20:32:16] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
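The comments in the cell above mention writing the encoded object to a file via json.dump and reading it back via json.load. A minimal sketch of that round trip with the standard library follows. The gs_json dictionary here is only a placeholder with made-up fields so that the example runs without Ax; the real output of object_to_json has a different, richer structure. The helper names save_snapshot and load_snapshot are hypothetical.

```python
import json
import os
import tempfile

# Placeholder for the output of `object_to_json(gs)`; the fields are made up for this sketch.
gs_json = {"name": "Sobol+BoTorch", "steps": [{"num_trials": 5}, {"num_trials": -1}]}


def save_snapshot(obj, path):
    """Write a JSON-compatible object to a file."""
    with open(path, "w") as f:
        json.dump(obj, f)


def load_snapshot(path):
    """Read a JSON-compatible object back from a file."""
    with open(path) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    snapshot_path = os.path.join(tmp_dir, "gs_snapshot.json")
    save_snapshot(gs_json, snapshot_path)
    restored = load_snapshot(snapshot_path)

# With Ax available, `restored` would then be passed to `object_from_json`.
```

The restored dictionary is equal to the one that was saved, which is what allows the decoding step to rebuild the generation strategy from a file.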
Below is a list of important "gotchas" of using generation strategy (especially outside of the higher-level APIs like the Service API or the Scheduler):
GenerationStrategy.gen produces GeneratorRun-s, not trials
Since GenerationStrategy.gen mimics ModelBridge.gen and allows for human-in-the-loop usage mode, a call to gen produces a GeneratorRun, which can then be added (or altered before addition or not added at all) to a Trial or BatchTrial on a given experiment. So it's important to add the generator run to a trial, since otherwise it will not be attached to the experiment on its own.
generator_run = gs.gen(
    experiment=experiment,
    n=1,
    pending_observations=get_pending_observation_features(experiment),
)
experiment.new_trial(generator_run)
/tmp/tmp.QqcA7fo0ui/Ax-main/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction. warn("Encountered exception in computing model fit quality: " + str(e))
Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -3.437549672089517, 'x2': 3.9023718936368823}))
model_kwargs elements that do not define serialization logic in Ax
Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch Prior object) as part of GenerationStep.model_kwargs or GenerationStep.model_gen_kwargs will prevent correct generation strategy storage. If this becomes a problem, feel free to open an issue on our GitHub: https://github.com/facebook/Ax/issues to get help with adding storage support for a given object.
Why prefer Models enum entries over a factory function?
1. A call to a registry entry such as Models.GPEI captures all arguments to the model and model bridge and stores them on the generator runs subsequently produced by the model. Since the capturing logic is part of the Models.__call__ function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.
2. Although a factory function is flexible (it can accept arbitrary inputs and produce a ModelBridge with an underlying Model instance based on them), it is not standard in terms of its inputs. Models introduces a standardized interface, making it easy to adapt any example to one's specific case.
How to request more modeling setups added to Models and natively supported in Ax?
Please open a GitHub issue to request a new modeling setup in Ax (or for any other questions or requests).
Total runtime of script: 7.03 seconds.