Source code for ax.service.utils.dispatch

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
from math import ceil

from ax.core.parameter import ChoiceParameter, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.factory import Models
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.utils.common.logger import get_logger


logger: logging.Logger = get_logger(__name__)


[docs]def choose_generation_strategy( search_space: SearchSpace, arms_per_trial: int = 1, enforce_sequential_optimization: bool = True, ) -> GenerationStrategy: """Select an appropriate generation strategy based on the properties of the search space.""" num_continuous_parameters, num_discrete_choices = 0, 0 for parameter in search_space.parameters: if isinstance(parameter, ChoiceParameter): num_discrete_choices += len(parameter.values) if isinstance(parameter, RangeParameter): num_continuous_parameters += 1 # If there are more discrete choices than continuous parameters, Sobol # will do better than GP+EI. if num_continuous_parameters >= num_discrete_choices: # Ensure that number of arms per model is divisible by batch size. sobol_arms = ( ceil(max(5, len(search_space.parameters)) / arms_per_trial) * arms_per_trial ) logger.info( "Using Bayesian Optimization generation strategy. Iterations after " f"{sobol_arms} will take longer to generate due to model-fitting." ) return GenerationStrategy( name="Sobol+GPEI", steps=[ GenerationStep( model=Models.SOBOL, num_arms=sobol_arms, min_arms_observed=ceil(sobol_arms / 2), enforce_num_arms=enforce_sequential_optimization, ), GenerationStep(model=Models.GPEI, num_arms=-1), ], ) else: logger.info(f"Using Sobol generation strategy.") return GenerationStrategy( name="Sobol", steps=[GenerationStep(model=Models.SOBOL, num_arms=-1)] )