from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models, ModelRegistryBase
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
from ax.utils.testing.core_stubs import get_branin_search_space, get_branin_experiment
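GenerationStrategy (API reference) is a key abstraction in Ax: it allows for specifying multiple optimization algorithms to chain one after another in the course of the optimization, and many higher-level APIs in Ax use generation strategies, including the Service and Loop APIs, the Scheduler, etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).
This tutorial walks through a few examples of generation strategies and discusses their important settings. Before reading it, we recommend familiarizing yourself with how Model and ModelBridge work in Ax: https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.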
Contents:
- GenerationStep as a building block of the generation strategy
- GenerationStep settings
- Chaining GenerationStep-s together
- max_parallelism enforcement and handling the MaxParallelismReachedException
- GenerationStrategy storage
- GenerationStrategy.gen produces GeneratorRun-s, not Trial-s
- model_kwargs elements that don't have associated serialization logic in Ax
- Why prefer Models registry enum entries over a factory function?
- How to request more models to be included in Models and natively supported in Ax?
gs = GenerationStrategy(
steps=[
# 1. Initialization step (does not require pre-existing data and is well-suited for
# initial sampling of the search space)
GenerationStep(
model=Models.SOBOL,
num_trials=5, # How many trials should be produced from this generation step
min_trials_observed=3, # How many trials need to be completed to move to next model
max_parallelism=5, # Max parallelism for this step
model_kwargs={"seed": 999}, # Any kwargs you want passed into the model
model_gen_kwargs={}, # Any kwargs you want passed to `modelbridge.gen`
),
# 2. Bayesian optimization step (requires data obtained from previous phase and learns
# from all data available at the time of each new candidate generation call)
GenerationStep(
model=Models.GPEI,
num_trials=-1, # No limitation on how many trials should be produced from this step
max_parallelism=3, # Parallelism limit for this step, often lower than for Sobol
# More on parallelism vs. required samples in BayesOpt:
# https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
),
]
)
Ax provides a choose_generation_strategy utility, which can auto-select a suitable generation strategy given a search space and an array of other optional settings. The utility is fairly simple at the moment, but additional development (support for multi-objective optimization, multi-fidelity optimization, Bayesian optimization with categorical kernels etc.) is coming soon.
gs = choose_generation_strategy(
# Required arguments:
search_space=get_branin_search_space(), # Ax `SearchSpace`
# Some optional arguments (shown with their defaults), see API docs for more settings:
# https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils
use_batch_trials=False, # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`
no_bayesian_optimization=False, # Use quasi-random candidate generation without BayesOpt
max_parallelism_override=None, # Integer, to which to set the `max_parallelism` setting of all steps in this GS
)
gs
[INFO 09-15 17:23:59] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 09-15 17:23:59] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
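The first INFO message printed above names the rule that drove this particular choice. As a rough illustration only, that single rule can be sketched in plain Python. The function name is hypothetical, and the real choose_generation_strategy weighs many more settings (see the API docs linked in the code cell above), so treat this as a reading aid for the log message rather than a reimplementation.

```python
from typing import Sequence


def prefers_bayesian_optimization(
    num_ordered_params: int, unordered_choice_sizes: Sequence[int]
) -> bool:
    """Hypothetical sketch of the one rule quoted in the log message above.

    Returns True when there are more ordered parameters (e.g. range parameters)
    than there are categories across all unordered categorical parameters.
    """
    return num_ordered_params > sum(unordered_choice_sizes)


# Branin: two range parameters (x1, x2) and no categorical parameters.
print(prefers_bayesian_optimization(2, []))  # True
# One range parameter next to a single 3-category unordered choice parameter.
print(prefers_bayesian_optimization(1, [3]))  # False
```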
While often used through the Service or Loop API or other higher-order abstractions like the Ax Scheduler (where the generation strategy is used to fit models and produce candidates from them under the hood), the GS can also be used directly, in place of a ModelBridge instance. The interface of GenerationStrategy.gen is the same as that of ModelBridge.gen.
experiment = get_branin_experiment()
[INFO 09-15 17:23:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
Note that it's important to pass pending observations to the call to gen to avoid getting the same points re-suggested. Without the pending_observations argument, Ax models are not aware of points that should be excluded from generation. Points are considered "pending" when they belong to STAGED, RUNNING, or ABANDONED trials (the latter are included so that the model does not re-suggest points that are considered "bad").
If the call to get_pending_observation_features becomes slow in your setup (since it performs data-fetching etc.), you can opt for get_pending_observation_features_based_on_trial_status (also from ax.modelbridge.modelbridge_utils), but note the limitations of that utility (detailed in its docstring).
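The status-based idea can be sketched without Ax: collect the parameterizations of all trials whose status marks them as pending (STAGED, RUNNING or ABANDONED, per the list above) and treat them as points not to re-suggest. The TrialStatus enum and ToyTrial class below are hypothetical stand-ins rather than Ax's own classes, and the sketch ignores the limitations described in the real utility's docstring.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


class TrialStatus(Enum):
    """Hypothetical stand-in for Ax's trial statuses."""

    CANDIDATE = "candidate"
    STAGED = "staged"
    RUNNING = "running"
    COMPLETED = "completed"
    ABANDONED = "abandoned"


# Statuses treated as "pending" in the text above.
PENDING_STATUSES = {TrialStatus.STAGED, TrialStatus.RUNNING, TrialStatus.ABANDONED}


@dataclass
class ToyTrial:
    index: int
    status: TrialStatus
    parameters: Dict[str, float]


def pending_parameterizations(trials: List[ToyTrial]) -> List[Dict[str, float]]:
    """Return the parameters of trials that a model should not re-suggest."""
    return [t.parameters for t in trials if t.status in PENDING_STATUSES]


trials = [
    ToyTrial(0, TrialStatus.COMPLETED, {"x1": 0.0, "x2": 1.0}),
    ToyTrial(1, TrialStatus.RUNNING, {"x1": 2.0, "x2": 3.0}),
    ToyTrial(2, TrialStatus.ABANDONED, {"x1": -1.0, "x2": 5.0}),
]
print(pending_parameterizations(trials))
# [{'x1': 2.0, 'x2': 3.0}, {'x1': -1.0, 'x2': 5.0}]
```

Because the check only looks at statuses, it needs no data fetching, which is the trade-off that makes a status-based utility faster.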
generator_run = gs.gen(
experiment=experiment, # Ax `Experiment`, for which to generate new candidates
data=None, # Ax `Data` to use for model training, optional.
n=1, # Number of candidate arms to produce
pending_observations=get_pending_observation_features(experiment), # Points that should not be re-generated
# Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`
)
generator_run
GeneratorRun(1 arms, total weight 1.0)
Then we can add the newly produced GeneratorRun to the experiment as a Trial (or as a BatchTrial if n > 1):
trial = experiment.new_trial(generator_run)
trial
Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': -1.9262760877609253, 'x2': 13.50526750087738}))
Important notes on GenerationStrategy.gen:
- If the data argument above is not specified, GS will pull experiment data from cache via experiment.lookup_data.
- Without specifying pending_observations, the GS (and any model in Ax) could produce the same candidate over and over, since without that argument the model is not 'aware' that the candidate is part of a RUNNING or ABANDONED trial and should not be re-suggested. In cases where get_pending_observation_features is too slow and the experiment consists of 1-arm Trial-s only, it's possible to use get_pending_observation_features_based_on_trial_status instead (found in the same module, ax.modelbridge.modelbridge_utils).
Note that when using the Ax Service API, one of the arguments to AxClient is choose_generation_strategy_kwargs; specifying that argument is a convenient way to influence the choice of generation strategy in AxClient without manually specifying a full GenerationStrategy.
GenerationStep as a building block of generation strategy
Specifying a model for a GenerationStep
There are two ways of specifying a model for a generation step: via an entry in the Models enum or via a 'factory function', i.e. a callable model constructor (e.g. get_GPEI and the other factory functions in ax/modelbridge/factory.py). Note that using the latter path, a factory function, will prohibit GenerationStrategy storage and is generally discouraged.
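The difference between the two paths can be illustrated with a toy registry. The idea: an enum entry can record which entry was used and with what arguments, and that record can be stored as JSON and replayed later, whereas a bare callable gives a serializer nothing comparable to store. ToyModel, ToyModels and get_toy_sobol are hypothetical stand-ins for a ModelBridge, the Models enum and a factory function; this is a simplified analogue of the argument capturing this tutorial attributes to Models.__call__, not Ax's actual implementation.

```python
import json
from enum import Enum


class ToyModel:
    """Hypothetical stand-in for a ModelBridge with an underlying Model."""

    def __init__(self, seed=None):
        self.seed = seed
        self.spec = None  # filled in only when built through the registry


class ToyModels(Enum):
    """Hypothetical registry enum mimicking the idea behind `Models`."""

    SOBOL = "sobol"

    def __call__(self, **kwargs):
        model = ToyModel(**kwargs)
        # Capture what is needed to rebuild this model later: the enum
        # member's name plus the (JSON-serializable) kwargs.
        model.spec = {"model": self.name, "kwargs": kwargs}
        return model


def get_toy_sobol(seed=None):
    """Hypothetical factory function: builds the same model, records nothing."""
    return ToyModel(seed=seed)


def restore(spec_json):
    """Rebuild a model from a stored registry spec."""
    spec = json.loads(spec_json)
    return ToyModels[spec["model"]](**spec["kwargs"])


from_registry = ToyModels.SOBOL(seed=999)
restored = restore(json.dumps(from_registry.spec))
print(restored.seed)  # 999

from_factory = get_toy_sobol(seed=999)
print(from_factory.spec)  # None: nothing was captured to store
try:
    json.dumps(get_toy_sobol)
    factory_serializable = True
except TypeError:
    factory_serializable = False  # functions are not JSON-serializable
print(factory_serializable)  # False
```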
GenerationStep settings
All of the available settings are described in the documentation:
print(GenerationStep.__doc__)
One step in the generation strategy, corresponds to a single model. Describes the model, how many trials will be generated with this model, what minimum number of observations is required to proceed to the next model, etc.
NOTE: Model can be specified either from the model registry (`ax.modelbridge.registry.Models` or using a callable model constructor. Only models from the registry can be saved, and thus optimization can only be resumed if interrupted when using models from the registry.
Args:
    model: A member of `Models` enum or a callable returning an instance of `ModelBridge` with an instantiated underlying `Model`. Refer to `ax/modelbridge/factory.py` for examples of such callables.
    num_trials: How many trials to generate with the model from this step. If set to -1, trials will continue to be generated from this model as long as `generation_strategy.gen` is called (available only for the last of the generation steps).
    min_trials_observed: How many trials must be completed before the generation strategy can proceed to the next step. Defaults to 0. If `num_trials` of a given step have been generated but `min_trials_observed` have not been completed, a call to `generation_strategy.gen` will fail with a `DataRequiredError`.
    max_parallelism: How many trials generated in the course of this step are allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously. If `max_parallelism` trials from this step are already running, a call to `generation_strategy.gen` will fail with a `MaxParallelismReachedException`, indicating that more trials need to be completed before generating and running next trials.
    use_update: Whether to use `model_bridge.update` instead or reinstantiating model + bridge on every call to `gen` within a single generation step. NOTE: use of `update` on stateful models that do not implement `_get_state` may result in inability to correctly resume a generation strategy from a serialized state.
    enforce_num_trials: Whether to enforce that only `num_trials` are generated from the given step. If False and `num_trials` have been generated, but `min_trials_observed` have not been completed, `generation_strategy.gen` will continue generating trials from the current step, exceeding `num_trials` for it. Allows to avoid `DataRequiredError`, but delays proceeding to next generation step.
    model_kwargs: Dictionary of kwargs to pass into the model constructor on instantiation. E.g. if `model` is `Models.SOBOL`, kwargs will be applied as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, `get_sobol(**model_kwargs)`. NOTE: if generation strategy is interrupted and resumed from a stored snapshot and its last used model has state saved on its generator runs, `model_kwargs` is updated with the state dict of the model, retrieved from the last generator run of this generation strategy.
    model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the step's model's `gen` under the hood; `model_gen_kwargs` will be passed to the model's `gen` like so: `model.gen(**model_gen_kwargs)`.
    index: Index of this generation step, for use internally in `GenerationStrategy`. Do not assign as it will be reassigned when instantiating `GenerationStrategy` with a list of its steps.
    should_deduplicate: Whether to deduplicate the parameters of proposed arms against those of previous arms via rejection sampling. If this is True, the generation strategy will discard generator runs produced from the generation step that has `should_deduplicate=True` if they contain arms already present on the experiment and replace them with new generator runs. If no generator run with entirely unique arms could be produced in 5 attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we assume that the optimization converged when the model can no longer suggest unique arms.
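The rules in this docstring that decide whether gen succeeds (num_trials, min_trials_observed, enforce_num_trials, max_parallelism) can be mimicked with a small pure-Python state machine. This is a simplified sketch under stated assumptions, not Ax's implementation: every class below is a hypothetical stand-in (including the two exception classes), trial statuses are plain strings, only trials from the current step count toward its limits, and no model fitting or candidate generation happens at all.

```python
from dataclasses import dataclass, field
from typing import List, Optional


class DataRequiredError(Exception):
    """Hypothetical stand-in for Ax's DataRequiredError."""


class MaxParallelismReachedException(Exception):
    """Hypothetical stand-in for Ax's MaxParallelismReachedException."""


@dataclass
class ToyStep:
    model: str
    num_trials: int  # -1: no limit (only sensible for the last step)
    min_trials_observed: int = 0
    max_parallelism: Optional[int] = None  # None: no parallelism limit
    enforce_num_trials: bool = True


@dataclass
class ToyTrial:
    step_index: int
    status: str = "RUNNING"

    def mark_completed(self) -> None:
        self.status = "COMPLETED"


@dataclass
class ToyGenerationStrategy:
    steps: List[ToyStep]
    trials: List[ToyTrial] = field(default_factory=list)
    current: int = 0  # index of the step currently generating trials

    def _step_trials(self, status: Optional[str] = None) -> List[ToyTrial]:
        # Only trials produced by the current step count toward its limits.
        return [
            t for t in self.trials
            if t.step_index == self.current and (status is None or t.status == status)
        ]

    def gen(self) -> ToyTrial:
        step = self.steps[self.current]
        if step.num_trials != -1 and len(self._step_trials()) >= step.num_trials:
            if len(self._step_trials("COMPLETED")) >= step.min_trials_observed:
                if self.current + 1 == len(self.steps):
                    raise RuntimeError("all generation steps are exhausted")
                self.current += 1  # both conditions met: move to the next step
                return self.gen()
            if step.enforce_num_trials:
                raise DataRequiredError(f"{step.model}: complete more trials first")
            # enforce_num_trials=False: keep generating from the current step.
        if (
            step.max_parallelism is not None
            and len(self._step_trials("RUNNING")) >= step.max_parallelism
        ):
            raise MaxParallelismReachedException(
                f"{step.model}: {step.max_parallelism} trials already running"
            )
        trial = ToyTrial(step_index=self.current)
        self.trials.append(trial)
        return trial


gs = ToyGenerationStrategy(steps=[
    ToyStep("Sobol", num_trials=2, min_trials_observed=1, max_parallelism=2),
    ToyStep("GPEI", num_trials=-1, max_parallelism=1),
])
first, second = gs.gen(), gs.gen()  # both produced by the Sobol step
try:
    gs.gen()
except DataRequiredError:
    first.mark_completed()  # record a result, then retry
third = gs.gen()  # now moves on to the GPEI step
try:
    gs.gen()
except MaxParallelismReachedException:
    third.mark_completed()  # wait for a running trial to finish, then retry
fourth = gs.gen()
print([gs.steps[t.step_index].model for t in gs.trials])
# ['Sobol', 'Sobol', 'GPEI', 'GPEI']
```

The usage at the bottom follows the handling recommended in this tutorial: when either exception is raised, complete running trials first and only then call gen again.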
Chaining GenerationStep-s together
A GenerationStrategy moves from one step to another when:
1) N=num_trials generator runs were produced and attached as trials to the experiment, AND
2) M=min_trials_observed trials have been completed and have data.
Caveat: the enforce_num_trials setting:
- If enforce_num_trials=True for a given generation step and 1) is reached but 2) is not yet reached, the generation strategy will raise a DataRequiredError, indicating that more trials need to be completed before the next step.
- If enforce_num_trials=False, the GS will continue producing generator runs from the current step until 2) is reached.
max_parallelism enforcement
A generation strategy can restrict the number of trials that can be run simultaneously (to encourage sequential optimization, which benefits Bayesian optimization performance). When the parallelism limit is reached, a call to GenerationStrategy.gen will result in a MaxParallelismReachedException.
The correct way to handle this exception:
1) Make sure GenerationStep.max_parallelism is configured correctly for all steps in your generation strategy (to disable it completely, configure GenerationStep.max_parallelism=None).
2) When the exception is raised, wait until more of the running trials complete, mark them as completed via trial.mark_completed, and only then call gen again.
GenerationStrategy storage
When used through the Service API or the Scheduler, the generation strategy will be automatically stored to SQL via specifying DBSettings to either AxClient or Scheduler (details in the respective tutorials on the "Tutorials" page). A generation strategy can also be stored to SQL or JSON individually, as shown below.
More detail on SQL and JSON storage in Ax generally can be found in the "Building Blocks of Ax" tutorial.
For SQL storage setup in Ax, read through the "Storage" documentation page.
Note that unlike an Ax experiment, a generation strategy does not have a unique name or other identifier. Therefore, a generation strategy is stored in association with an experiment and can be retrieved by the associated experiment's name.
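Conceptually, this lookup works because each generation strategy is keyed by the experiment it belongs to. The sketch below illustrates that idea with Python's built-in sqlite3 module and an in-memory database. The table layout and the helpers save_gs and load_gs_by_experiment_name are hypothetical and do not reflect Ax's real SQL schema (which save_generation_strategy and load_generation_strategy_by_experiment_name handle for you), and the plain dict stands in for a serialized GenerationStrategy.

```python
import json
import sqlite3
from typing import Any, Dict

# Hypothetical table, not Ax's actual SQL schema: one generation strategy per
# experiment, looked up by the experiment's unique name.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE generation_strategy ("
    "experiment_name TEXT PRIMARY KEY, gs_json TEXT NOT NULL)"
)


def save_gs(experiment_name: str, gs: Dict[str, Any]) -> None:
    # Saving again for the same experiment overwrites the previous snapshot.
    conn.execute(
        "INSERT OR REPLACE INTO generation_strategy VALUES (?, ?)",
        (experiment_name, json.dumps(gs)),
    )


def load_gs_by_experiment_name(experiment_name: str) -> Dict[str, Any]:
    row = conn.execute(
        "SELECT gs_json FROM generation_strategy WHERE experiment_name = ?",
        (experiment_name,),
    ).fetchone()
    if row is None:
        raise KeyError(f"no generation strategy for {experiment_name!r}")
    return json.loads(row[0])


save_gs("branin_test_experiment", {"name": "Sobol+GPEI", "steps": ["Sobol", "GPEI"]})
print(load_gs_by_experiment_name("branin_test_experiment")["name"])  # Sobol+GPEI
```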
from ax.storage.sqa_store.save import save_generation_strategy, save_experiment
from ax.storage.sqa_store.load import load_experiment, load_generation_strategy_by_experiment_name
from ax.storage.sqa_store.db import init_engine_and_session_factory, get_engine, create_all_tables
init_engine_and_session_factory(url='sqlite:///foo2.db')
engine = get_engine()
create_all_tables(engine)
save_experiment(experiment)
save_generation_strategy(gs)
experiment = load_experiment(experiment_name=experiment.name)
gs = load_generation_strategy_by_experiment_name(
experiment_name=experiment.name,
experiment=experiment, # Can optionally specify experiment object to avoid loading it from database twice
)
gs
[INFO 09-15 17:24:00] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
/home/runner/work/Ax/Ax/ax/storage/sqa_store/load.py:246: SAWarning: TypeDecorator JSONEncodedText() will not produce a cache key because the ``cache_ok`` attribute is not set to True. This can have significant performance implications including some performance degradations in comparison to prior SQLAlchemy versions. Set this attribute to True if this type object's state is safe to use in a cache key, or False to disable this warning. (Background on this error at: https://sqlalche.me/e/14/cprf)
  session.query(exp_sqa_class.properties)
[INFO 09-15 17:24:00] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.decoder import object_from_json
gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.
gs = object_from_json(gs_json) # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)
gs
[INFO 09-15 17:24:00] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
Below is a list of important "gotchas" of using generation strategy (especially outside of the higher-level APIs like the Service API or the Scheduler):
GenerationStrategy.gen produces GeneratorRun-s, not trials
Since GenerationStrategy.gen mimics ModelBridge.gen and allows for a human-in-the-loop usage mode, a call to gen produces a GeneratorRun, which can then be added to a Trial or BatchTrial on a given experiment (or altered before addition, or not added at all). It's therefore important to add the generator run to a trial, since otherwise it will not be attached to the experiment on its own.
generator_run = gs.gen(
experiment=experiment, n=1, pending_observations=get_pending_observation_features(experiment)
)
experiment.new_trial(generator_run)
Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -2.737500830553472, 'x2': 1.095017110928893}))
model_kwargs elements that do not define serialization logic in Ax
Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch Prior object) as part of GenerationStep.model_kwargs or GenerationStep.model_gen_kwargs will prevent correct generation strategy storage. If this becomes a problem, feel free to open an issue on our GitHub: https://github.com/facebook/Ax/issues to get help with adding storage support for a given object.
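Before attaching unusual objects to model_kwargs or model_gen_kwargs, a quick check can flag values that plain JSON cannot encode and that are therefore worth double-checking. This is only a rough proxy under an explicit assumption: Ax serializes through its own encoder (object_to_json, used elsewhere in this tutorial), which supports many Ax and BoTorch types that the standard json module does not, so a flagged value is not necessarily unstorable in Ax. The helper find_non_json_kwargs and the SomePrior class are hypothetical.

```python
import json
from typing import Any, Dict, List


def find_non_json_kwargs(kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys whose values the standard json module cannot encode.

    Rough proxy only: Ax uses its own encoder, which handles more types.
    """
    flagged = []
    for key, value in kwargs.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            flagged.append(key)
    return flagged


class SomePrior:
    """Hypothetical object with no serialization logic."""


model_kwargs = {"seed": 999, "prior": SomePrior()}
print(find_non_json_kwargs(model_kwargs))  # ['prior']
```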
Why prefer Models enum entries over a factory function?
Models enum entries (e.g. Models.GPEI) capture all arguments to the model and model bridge and store them on the generator runs subsequently produced by the model. Since the capturing logic is part of the Models.__call__ function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.
Also, while a factory function is a flexible way to construct a model (it only needs to return a ModelBridge with an underlying Model instance based on its arguments), it is not standard in terms of its inputs. Models introduces a standardized interface, making it easy to adapt any example to one's specific case.
How to request more models to be included in Models and natively supported in Ax?
Please open a GitHub issue to request a new modeling setup in Ax (or for any other questions or requests).