# Source code for ax.service.utils.dispatch
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
from math import ceil
from typing import Optional
from ax.core.parameter import ChoiceParameter, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.utils.common.logger import get_logger
# Module-level logger, obtained from Ax's `get_logger` helper and named after this module.
logger: logging.Logger = get_logger(__name__)
def choose_generation_strategy(
    search_space: SearchSpace,
    arms_per_trial: int = 1,
    enforce_sequential_optimization: bool = True,
    random_seed: Optional[int] = None,
) -> GenerationStrategy:
    """Select an appropriate generation strategy based on the properties of
    the search space.

    Parameters are tallied as follows: each ``RangeParameter`` counts as one
    continuous parameter, and each ``ChoiceParameter`` adds the number of its
    values to the discrete-choice count (other parameter types are not
    tallied). If the continuous count is at least the discrete-choice count,
    a two-step strategy is returned: Sobol for an initial batch of arms, then
    GP+EI. Otherwise a Sobol-only strategy is returned.

    Args:
        search_space: The search space the strategy will generate arms in.
        arms_per_trial: Number of arms per trial. In the two-step strategy the
            number of Sobol arms, ``max(5, number of parameters)``, is rounded
            up to a multiple of this value. Must be at least 1.
        enforce_sequential_optimization: Passed as ``enforce_num_arms`` to the
            Sobol step of the two-step (Sobol + GP+EI) strategy.
        random_seed: If not None, passed to the Sobol step as the ``seed``
            model kwarg.

    Returns:
        The selected ``GenerationStrategy``.

    Raises:
        ValueError: If ``arms_per_trial`` is less than 1.
    """
    # Validate up front. Otherwise 0 surfaces as an opaque ZeroDivisionError,
    # and negative values silently produce a meaningless Sobol arm count.
    if arms_per_trial < 1:
        raise ValueError(
            f"`arms_per_trial` must be a positive integer, got {arms_per_trial}."
        )
    model_kwargs = {"seed": random_seed} if random_seed is not None else None

    num_continuous_parameters, num_discrete_choices = 0, 0
    for parameter in search_space.parameters.values():
        if isinstance(parameter, ChoiceParameter):
            num_discrete_choices += len(parameter.values)
        if isinstance(parameter, RangeParameter):
            num_continuous_parameters += 1

    # If there are more discrete choices than continuous parameters, Sobol
    # will do better than GP+EI.
    if num_continuous_parameters >= num_discrete_choices:
        # Ensure that number of arms per model is divisible by batch size.
        sobol_arms = (
            ceil(max(5, len(search_space.parameters)) / arms_per_trial)
            * arms_per_trial
        )
        gs = GenerationStrategy(
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_arms=sobol_arms,
                    # Require half of the Sobol arms (rounded up) to be
                    # observed; presumably gates the move to GP+EI — confirm
                    # against GenerationStep semantics.
                    min_arms_observed=ceil(sobol_arms / 2),
                    enforce_num_arms=enforce_sequential_optimization,
                    model_kwargs=model_kwargs,
                ),
                # NOTE(review): num_arms=-1 presumably means "no limit" for
                # the final step — confirm against GenerationStep docs.
                GenerationStep(
                    model=Models.GPEI, num_arms=-1, recommended_max_parallelism=3
                ),
            ]
        )
        logger.info(
            f"Using Bayesian Optimization generation strategy: {gs}. Iterations "
            f"after {sobol_arms} will take longer to generate due to model-fitting."
        )
        return gs
    else:
        logger.info("Using Sobol generation strategy.")
        return GenerationStrategy(
            steps=[
                GenerationStep(
                    model=Models.SOBOL, num_arms=-1, model_kwargs=model_kwargs
                )
            ]
        )