import sys
import plotly.io as pio
if 'google.colab' in sys.modules:
    pio.renderers.default = "colab"
%pip install ax-platform
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
from ax.modelbridge.registry import ModelRegistryBase, Models
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_search_space
Generation Strategy (GS) Tutorial
GenerationStrategy (API reference) is a key abstraction in Ax:
- It allows for specifying multiple optimization algorithms to chain one after another in the course of the optimization.
- Many higher-level APIs in Ax use generation strategies: Service and Loop APIs, Scheduler, etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).
- Generation strategy allows for storage and resumption of modeling setups, making optimization resumable from SQL or JSON snapshots.
This tutorial walks through a few examples of generation strategies and discusses their
important settings. Before reading it, we recommend familiarizing yourself with how
Model and ModelBridge work in Ax:
https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.
Contents:
- Quick-start examples
  - Manually configured GS
  - Auto-selected GS
  - Candidate generation from a GS
- Deep dive: GenerationStep, a building block of the generation strategy
  - Describing a model
  - Other GenerationStep settings
  - Chaining GenerationStep-s together
  - max_parallelism enforcement and handling the MaxParallelismReachedException
- GenerationStrategy storage
  - JSON storage
  - SQL storage
- Advanced considerations / "gotchas"
  - Generation strategy produces GeneratorRun-s, not Trial-s
  - model_kwargs elements that don't have associated serialization logic in Ax
  - Why prefer Models registry enum entries over a factory function?
  - How to request more modeling setups in Models?
1. Quick-start examples
1A. Manually configured generation strategy
Below is a typical generation strategy used for most single-objective optimization cases in Ax:
gs = GenerationStrategy(
    steps=[
        # 1. Initialization step (does not require pre-existing data and is well-suited for
        # initial sampling of the search space)
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
            min_trials_observed=3,  # How many trials need to be completed to move to next model
            max_parallelism=5,  # Max parallelism for this step
            model_kwargs={"seed": 999},  # Any kwargs you want passed into the model
            model_gen_kwargs={},  # Any kwargs you want passed to `modelbridge.gen`
        ),
        # 2. Bayesian optimization step (requires data obtained from previous phase and learns
        # from all data available at the time of each new candidate generation call)
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            max_parallelism=3,  # Parallelism limit for this step, often lower than for Sobol
            # More on parallelism vs. required samples in BayesOpt:
            # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
        ),
    ]
)
1B. Auto-selected generation strategy
Ax provides a choose_generation_strategy utility, which can auto-select a suitable
generation strategy given a search space and an array of other optional settings. The
utility is fairly simple at the moment, but additional development (support for
multi-objective optimization, multi-fidelity optimization, Bayesian optimization with
categorical kernels etc.) is coming soon.
gs = choose_generation_strategy(
    # Required arguments:
    search_space=get_branin_search_space(),  # Ax `SearchSpace`
    # Some optional arguments (shown with their defaults), see API docs for more settings:
    # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils
    use_batch_trials=False,  # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`
    no_bayesian_optimization=False,  # Use quasi-random candidate generation without BayesOpt
    max_parallelism_override=None,  # Integer, to which to set the `max_parallelism` setting of all steps in this GS
)
gs
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: verbose, disable_progbar, and jit_compile are not yet supported when using choose_generation_strategy with ModularBoTorchModel, dropping these arguments.
[INFO 02-03 18:39:18] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
1C. Candidate generation from a generation strategy
While a generation strategy is often used through the Service or Loop API or other
higher-order abstractions like the Ax Scheduler (where the generation strategy is used
to fit models and produce candidates from them under the hood), it's also possible to
use the GS directly, in place of a ModelBridge instance. The interface of
GenerationStrategy.gen is the same as that of ModelBridge.gen.
experiment = get_branin_experiment()
[INFO 02-03 18:39:18] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
Note that it's important to specify pending observations in the call to gen to avoid
getting the same points re-suggested. Without the pending_observations argument, Ax
models are not aware of points that should be excluded from generation. Points are
considered "pending" when they belong to STAGED, RUNNING, or ABANDONED trials (with the
latter included so that the model does not re-suggest points that are considered "bad").
If the call to get_pending_observation_features becomes slow in your setup (since it
performs data-fetching etc.), you can opt for
get_pending_observation_features_based_on_trial_status (also from
ax.modelbridge.modelbridge_utils), but note the limitations of that utility (detailed
in its docstring).
generator_run = gs.gen(
    experiment=experiment,  # Ax `Experiment`, for which to generate new candidates
    data=None,  # Ax `Data` to use for model training, optional.
    n=1,  # Number of candidate arms to produce
    pending_observations=get_pending_observation_features(
        experiment
    ),  # Points that should not be re-generated
    # Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`
)
generator_run
/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
warn("Encountered exception in computing model fit quality: " + str(e))
GeneratorRun(1 arms, total weight 1.0)
Then we can add the newly produced GeneratorRun to the experiment as a Trial (or
BatchTrial if n > 1):
trial = experiment.new_trial(generator_run)
trial
Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': 3.461330533027649, 'x2': 9.43772703409195}))
Important notes on GenerationStrategy.gen:
- If the data argument above is not specified, the GS will pull experiment data from cache via experiment.lookup_data.
- Without pending_observations specified, the GS (and any model in Ax) could produce the same candidate over and over, as without that argument the model is not 'aware' that the candidate is part of a RUNNING or ABANDONED trial and should not be re-suggested.
In cases where get_pending_observation_features is too slow and the experiment
consists of 1-arm Trial-s only, it's possible to use
get_pending_observation_features_based_on_trial_status instead (found in the same
file).
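To make the notion of "pending" points concrete, here is a minimal pure-Python sketch. It is not Ax's implementation: ToyStatus, ToyTrial, and pending_points are hypothetical stand-ins invented for illustration, and the set of pending statuses simply mirrors the three listed in this tutorial. In real code, use get_pending_observation_features.

```python
from dataclasses import dataclass
from enum import Enum


class ToyStatus(Enum):  # hypothetical stand-in for Ax's TrialStatus
    CANDIDATE = "candidate"
    STAGED = "staged"
    RUNNING = "running"
    COMPLETED = "completed"
    ABANDONED = "abandoned"


# Statuses this tutorial lists as "pending".
PENDING = {ToyStatus.STAGED, ToyStatus.RUNNING, ToyStatus.ABANDONED}


@dataclass(frozen=True)
class ToyTrial:  # hypothetical stand-in for a 1-arm Ax Trial
    index: int
    status: ToyStatus
    parameters: tuple  # e.g. (("x1", 0.5), ("x2", 1.0))


def pending_points(trials):
    """Return the parameterizations a generator should avoid re-suggesting."""
    return [dict(t.parameters) for t in trials if t.status in PENDING]


trials = [
    ToyTrial(0, ToyStatus.COMPLETED, (("x1", 0.1), ("x2", 0.2))),
    ToyTrial(1, ToyStatus.RUNNING, (("x1", 0.3), ("x2", 0.4))),
    ToyTrial(2, ToyStatus.ABANDONED, (("x1", 0.5), ("x2", 0.6))),
]
print(pending_points(trials))
```

Only the RUNNING and ABANDONED trials contribute points here; the COMPLETED trial already has data the model can learn from, so it does not need to be excluded explicitly.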
Note that when using the Ax Service API, one of the arguments to
AxClient.create_experiment is choose_generation_strategy_kwargs; specifying that
argument is a convenient way to influence the choice of generation strategy in AxClient
without manually specifying a full GenerationStrategy.
2. GenerationStep as a building block of generation strategy
2A. Describing a model to use in a given GenerationStep
There are two ways of specifying a model for a generation step: via an entry in the
Models enum or via a 'factory function', i.e. a callable model constructor (e.g.
get_GPEI and other factory functions in ax/modelbridge/factory.py). Note that using
the latter path, a factory function, will prohibit GenerationStrategy storage and is
generally discouraged.
2B. Other GenerationStep settings
All of the available settings are described in the documentation:
print(GenerationStep.__doc__)
One step in the generation strategy, corresponds to a single model.
Describes the model, how many trials will be generated with this model, what
minimum number of observations is required to proceed to the next model, etc.
NOTE: Model can be specified either from the model registry
(ax.modelbridge.registry.Models or using a callable model constructor. Only
models from the registry can be saved, and thus optimization can only be
resumed if interrupted when using models from the registry.
Args:
model: A member of Models enum or a callable returning an instance of
ModelBridge with an instantiated underlying Model. Refer to
ax/modelbridge/factory.py for examples of such callables.
num_trials: How many trials to generate with the model from this step.
If set to -1, trials will continue to be generated from this model
as long as generation_strategy.gen is called (available only for
the last of the generation steps).
min_trials_observed: How many trials must be completed before the
generation strategy can proceed to the next step. Defaults to 0.
If num_trials of a given step have been generated but min_trials_
observed have not been completed, a call to generation_strategy.gen
will fail with a DataRequiredError.
max_parallelism: How many trials generated in the course of this step are
allowed to be run (i.e. have trial.status of RUNNING) simultaneously.
If max_parallelism trials from this step are already running, a call
to generation_strategy.gen will fail with a MaxParallelismReached
Exception, indicating that more trials need to be completed before
generating and running next trials.
use_update: DEPRECATED.
enforce_num_trials: Whether to enforce that only num_trials are generated
from the given step. If False and num_trials have been generated, but
min_trials_observed have not been completed, generation_strategy.gen
will continue generating trials from the current step, exceeding num_
trials for it. Allows to avoid DataRequiredError, but delays
proceeding to next generation step.
model_kwargs: Dictionary of kwargs to pass into the model constructor on
instantiation. E.g. if model is Models.SOBOL, kwargs will be applied
as Models.SOBOL(**model_kwargs); if model is get_sobol, get_sobol(
**model_kwargs). NOTE: if generation strategy is interrupted and
resumed from a stored snapshot and its last used model has state saved on
its generator runs, model_kwargs is updated with the state dict of the
model, retrieved from the last generator run of this generation strategy.
model_gen_kwargs: Each call to generation_strategy.gen performs a call to the
step's model's gen under the hood; model_gen_kwargs will be passed to
the model's gen like so: model.gen(**model_gen_kwargs).
completion_criteria: List of TransitionCriterion. All is_met must evaluate
True for the GenerationStrategy to move on to the next Step
index: Index of this generation step, for use internally in Generation
Strategy. Do not assign as it will be reassigned when instantiating
GenerationStrategy with a list of its steps.
should_deduplicate: Whether to deduplicate the parameters of proposed arms
against those of previous arms via rejection sampling. If this is True,
the generation strategy will discard generator runs produced from the
generation step that has should_deduplicate=True if they contain arms
already present on the experiment and replace them with new generator runs.
If no generator run with entirely unique arms could be produced in 5
attempts, a GenerationStrategyRepeatedPoints error will be raised, as we
assume that the optimization converged when the model can no longer suggest
unique arms.
model_name: Optional name of the model. If not specified, defaults to the
model key of the model spec.
Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
on explaining the logic of GenerationStep and GenerationStrategy.
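The should_deduplicate behavior described in the docstring above is a form of rejection sampling. The sketch below is a simplified illustration under stated assumptions, not Ax's code: gen_unique and RepeatedPointsError are hypothetical names, arms are modeled as plain tuples, and the limit of 5 attempts is taken from the docstring.

```python
MAX_ATTEMPTS = 5  # the docstring above mentions 5 attempts


class RepeatedPointsError(Exception):
    """Hypothetical stand-in for GenerationStrategyRepeatedPoints."""


def gen_unique(propose, existing_arms, max_attempts=MAX_ATTEMPTS):
    """Call `propose()` until it returns an arm not already in `existing_arms`."""
    for _ in range(max_attempts):
        arm = propose()
        if arm not in existing_arms:
            return arm  # first proposal that is new to the experiment
    raise RepeatedPointsError(f"no unique arm found in {max_attempts} attempts")


existing = {(0.0, 1.0), (2.0, 3.0)}

# A scripted proposer: the first two proposals repeat existing arms.
proposals = iter([(0.0, 1.0), (2.0, 3.0), (4.0, 5.0)])
arm = gen_unique(lambda: next(proposals), existing)
print(arm)

# A proposer that can only repeat itself exhausts the attempts.
try:
    gen_unique(lambda: (0.0, 1.0), existing)
except RepeatedPointsError as err:
    print("gave up:", err)
```

Raising after a fixed number of attempts, rather than looping forever, matches the docstring's reading of repeated points as a sign that the optimization has converged.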
2C. Chaining GenerationStep-s together
A GenerationStrategy moves from one step to another when:
1. N=num_trials generator runs were produced and attached as trials to the experiment AND
2. M=min_trials_observed trials have been completed and have data.
Caveat: enforce_num_trials setting:
- If enforce_num_trials=True for a given generation step, if 1) is reached but 2) is not yet reached, the generation strategy will raise a DataRequiredError, indicating that more trials need to be completed before the next step.
- If enforce_num_trials=False, the GS will continue producing generator runs from the current step until 2) is reached.
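The transition rule and the enforce_num_trials caveat can be summarized as a small decision function. This is a simplified model of the rule as described in this section, not Ax's implementation (which, per the GenerationStep docstring, also supports completion_criteria). The names step_action and ToyDataRequiredError are hypothetical.

```python
class ToyDataRequiredError(Exception):
    """Hypothetical stand-in for Ax's DataRequiredError."""


def step_action(
    num_generated,
    num_completed,
    num_trials,
    min_trials_observed=0,
    enforce_num_trials=True,
):
    """Decide what happens on the next call to gen for the current step.

    Returns "generate" (stay in this step) or "advance" (move to the next step).
    """
    if num_trials == -1:
        return "generate"  # unlimited step: keeps generating indefinitely
    enough_generated = num_generated >= num_trials  # condition 1
    enough_observed = num_completed >= min_trials_observed  # condition 2
    if enough_generated and enough_observed:
        return "advance"
    if enough_generated and enforce_num_trials:
        missing = min_trials_observed - num_completed
        raise ToyDataRequiredError(f"{missing} more trial(s) must complete")
    return "generate"


# Sobol step from the quick-start: num_trials=5, min_trials_observed=3.
print(step_action(3, 0, num_trials=5, min_trials_observed=3))
print(step_action(5, 3, num_trials=5, min_trials_observed=3))
print(step_action(5, 2, num_trials=5, min_trials_observed=3, enforce_num_trials=False))
try:
    step_action(5, 2, num_trials=5, min_trials_observed=3)
except ToyDataRequiredError as err:
    print("blocked:", err)
```

With enforce_num_trials=False the step over-generates instead of raising, which trades a possible error for a delayed transition to the next step.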
2D. max_parallelism enforcement
Generation strategy can restrict the number of trials that can be run simultaneously (to
encourage sequential optimization, which benefits Bayesian optimization performance).
When the parallelism limit is reached, a call to GenerationStrategy.gen will result in
a MaxParallelismReachedException.
The correct way to handle this exception:
- Make sure that GenerationStep.max_parallelism is configured correctly for all steps in your generation strategy (to disable it completely, configure GenerationStep.max_parallelism=None).
- When encountering the exception, wait to produce more generator runs until more trial evaluations complete, and log each trial completion via trial.mark_completed.
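The handling pattern above (catch the exception, wait for a running trial to finish, then retry) can be sketched with a toy generator. Everything in this block is a hypothetical stand-in written for illustration: ToyStrategy and ToyMaxParallelismReached are not Ax classes, and "waiting" is simulated by immediately completing the oldest running trial.

```python
class ToyMaxParallelismReached(Exception):
    """Hypothetical stand-in for MaxParallelismReachedException."""


class ToyStrategy:
    """Hypothetical generator that enforces a parallelism limit."""

    def __init__(self, max_parallelism):
        self.max_parallelism = max_parallelism  # None disables the check
        self.running = set()
        self._next_index = 0

    def gen(self):
        limit = self.max_parallelism
        if limit is not None and len(self.running) >= limit:
            raise ToyMaxParallelismReached(f"{len(self.running)} trials running")
        index = self._next_index
        self._next_index += 1
        self.running.add(index)
        return index

    def mark_completed(self, index):
        self.running.discard(index)


def run(strategy, total_trials):
    """Generate `total_trials` trials, freeing a slot whenever the limit is hit."""
    events = []
    generated = 0
    while generated < total_trials:
        try:
            index = strategy.gen()
            events.append(("gen", index))
            generated += 1
        except ToyMaxParallelismReached:
            # In practice: block until a real evaluation finishes.
            oldest = min(strategy.running)
            strategy.mark_completed(oldest)
            events.append(("done", oldest))
    return events


events = run(ToyStrategy(max_parallelism=2), total_trials=4)
print(events)
```

The important design point is that the exception is treated as a signal to wait, not as a failure: no generator run is lost, and generation resumes as soon as a slot frees up.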
3. SQL and JSON storage of a generation strategy
When used through the Service API or Scheduler, a generation strategy will be
automatically stored to SQL or JSON when DBSettings are specified to either AxClient or
Scheduler (details in the respective tutorials on the "Tutorials" page).
Generation strategy can also be stored to SQL or JSON individually, as shown below.
More detail on SQL and JSON storage in Ax generally can be found in "Building Blocks of Ax" tutorial.
3A. SQL storage
For SQL storage setup in Ax, read through the "Storage" documentation page.
Note that unlike an Ax experiment, a generation strategy does not have a name or another unique identifier. Therefore, a generation strategy is stored in association with experiment and can be retrieved by the associated experiment's name.
from ax.storage.sqa_store.db import (
    create_all_tables,
    get_engine,
    init_engine_and_session_factory,
)
from ax.storage.sqa_store.load import (
    load_experiment,
    load_generation_strategy_by_experiment_name,
)
from ax.storage.sqa_store.save import save_experiment, save_generation_strategy
init_engine_and_session_factory(url="sqlite:///foo2.db")
engine = get_engine()
create_all_tables(engine)
save_experiment(experiment)
save_generation_strategy(gs)
experiment = load_experiment(experiment_name=experiment.name)
gs = load_generation_strategy_by_experiment_name(
    experiment_name=experiment.name,
    experiment=experiment,  # Can optionally specify experiment object to avoid loading it from database twice
)
gs
[ERROR 02-03 18:39:19] ax.storage.sqa_store.encoder: ATTENTION: The Ax team is considering deprecating SQLAlchemy storage. If you are currently using SQLAlchemy storage, please reach out to us via GitHub Issues here: https://github.com/facebook/Ax/issues/2975
[INFO 02-03 18:39:19] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
[INFO 02-03 18:39:19] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
3B. JSON storage
from ax.storage.json_store.decoder import object_from_json
from ax.storage.json_store.encoder import object_to_json
gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.
gs = object_from_json(
    gs_json
)  # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)
gs
[INFO 02-03 18:39:19] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials])
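The code comments above mention writing the encoded strategy to a file with json.dump. A minimal sketch of that round trip with the standard library follows. The helper names save_snapshot and load_snapshot are hypothetical, and the placeholder dictionary only stands in for the output of object_to_json(gs); its keys are made up and do not reflect the structure Ax actually produces.

```python
import json
import os
import tempfile


def save_snapshot(json_ready, path):
    """Write a JSON-serializable object (e.g. the output of object_to_json) to disk."""
    with open(path, "w") as f:
        json.dump(json_ready, f)


def load_snapshot(path):
    """Read the snapshot back; the result is what object_from_json would decode."""
    with open(path) as f:
        return json.load(f)


# Placeholder payload standing in for `gs_json` (keys are illustrative only).
placeholder = {"name": "Sobol+BoTorch", "steps": ["Sobol", "BoTorch"]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gs_snapshot.json")
    save_snapshot(placeholder, path)
    restored = load_snapshot(path)

print(restored == placeholder)
```

With Ax installed, the same helpers would be called as save_snapshot(object_to_json(gs), path) and object_from_json(load_snapshot(path)).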
4. Advanced considerations
Below is a list of important "gotchas" of using generation strategy (especially outside
of the higher-level APIs like the Service API or the Scheduler):
4A. GenerationStrategy.gen produces GeneratorRun-s, not trials
Since GenerationStrategy.gen mimics ModelBridge.gen and allows for human-in-the-loop
usage mode, a call to gen produces a GeneratorRun, which can then be added (or altered
before addition or not added at all) to a Trial or BatchTrial on a given experiment. So
it's important to add the generator run to a trial, since otherwise it will not be
attached to the experiment on its own.
generator_run = gs.gen(
    experiment=experiment,
    n=1,
    pending_observations=get_pending_observation_features(experiment),
)
experiment.new_trial(generator_run)
/home/runner/work/Ax/Ax/ax/modelbridge/cross_validation.py:439: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
warn("Encountered exception in computing model fit quality: " + str(e))
Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -0.0489088986068964, 'x2': 1.4531691698357463}))
4B. model_kwargs elements that do not define serialization logic in Ax
Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch Prior
object) as part of GenerationStep.model_kwargs or GenerationStep.model_gen_kwargs will
prevent correct generation strategy storage. If this becomes a problem, feel free to
open an issue on our GitHub: https://github.com/facebook/Ax/issues to get help with
adding storage support for a given object.
4C. Why prefer Models enum entries over a factory function?
- Storage potential: a call to, for example, Models.GPEI captures all arguments to the model and model bridge and stores them on the generator runs subsequently produced by the model. Since the capturing logic is part of the Models.__call__ function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.
- Standardization: While a 'factory function' is by default more flexible (accepts any specified inputs and produces a ModelBridge with an underlying Model instance based on them), it is not standard in terms of its inputs. Models introduces a standardized interface, making it easy to adapt any example to one's specific case.
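The storage argument can be illustrated without Ax. In the sketch below, ToyModels is a hypothetical stand-in for the Models registry and toy_factory for a factory function. An enum member has a stable name that can be written to JSON and looked up again, while a plain Python callable has no JSON representation.

```python
import json
from enum import Enum


class ToyModels(Enum):  # hypothetical stand-in for the Models registry
    SOBOL = "Sobol"
    BOTORCH = "BoTorch"


def toy_factory(**kwargs):  # hypothetical factory function
    return ("some model", kwargs)


def describe_step(model, model_kwargs):
    """Serialize a step description; only registry entries have a stable name."""
    if isinstance(model, ToyModels):
        return json.dumps({"model": model.name, "model_kwargs": model_kwargs})
    # A plain callable is passed through as-is, which json cannot encode.
    return json.dumps({"model": model, "model_kwargs": model_kwargs})


encoded = describe_step(ToyModels.SOBOL, {"seed": 999})
print(encoded)

# The enum member can be recovered from its stored name.
restored = ToyModels[json.loads(encoded)["model"]]
print(restored is ToyModels.SOBOL)

try:
    describe_step(toy_factory, {"seed": 999})
except TypeError as err:
    print("cannot serialize factory:", type(err).__name__)
```

The same reasoning applies to the keyword arguments: they round-trip only as long as every value in them is itself serializable.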
4D. How can I request more modeling setups added to Models and natively supported in Ax?
Please open a GitHub issue to request a new modeling setup in Ax (or for any other questions or requests).