Source code for ax.preview.modelbridge.dispatch_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from ax.core.base_trial import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MinTrials
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig
from botorch.models.transforms.input import Warp
from gpytorch.kernels.linear_kernel import LinearKernel


def _get_sobol_node(
    gs_config: GenerationStrategyConfig,
) -> GenerationNode:
    """Builds the "Sobol" initialization node described by ``gs_config``.

    ``gs_config.initialization_random_seed`` is forwarded to the Sobol model
    as its ``seed`` keyword argument.

    The node carries two ``MinTrials`` transition criteria. Both point to the
    "MBM" node and both block the transition while they are unmet:

    - Initialization budget:
        - Threshold is ``gs_config.initialization_budget``, or 5 if that is None.
        - Blocks further generation from this node once met, unless
            ``allow_exceeding_initialization_budget`` is True.
        - No status filter is passed, so it is not restricted to any
            particular trial statuses.
    - Minimum observed initialization trials:
        - Threshold is ``gs_config.min_observed_initialization_trials``, or
            ``max(1, initialization_budget // 2)`` if that is None.
        - Never blocks generation.
        - Only counts trials in status COMPLETED that have data attached.

    For both criteria, ``use_existing_trials_for_initialization`` decides
    whether trials previously attached to the experiment are counted.

    Args:
        gs_config: The ``GenerationStrategyConfig`` supplying the budget, the
            observation requirement, the seed and the related flags.

    Returns:
        A ``GenerationNode`` named "Sobol" with deduplication enabled.
    """
    count_existing_trials = gs_config.use_existing_trials_for_initialization

    # Resolve the defaults for the options that were left unspecified.
    configured_budget = gs_config.initialization_budget
    budget = 5 if configured_budget is None else configured_budget
    configured_min_observed = gs_config.min_observed_initialization_trials
    min_observed = (
        max(1, budget // 2)
        if configured_min_observed is None
        else configured_min_observed
    )

    # Criterion 1: the initialization budget itself.
    budget_criterion = MinTrials(
        threshold=budget,
        transition_to="MBM",
        block_gen_if_met=not gs_config.allow_exceeding_initialization_budget,
        block_transition_if_unmet=True,
        use_all_trials_in_exp=count_existing_trials,
    )
    # Criterion 2: enough initialization trials must have been observed.
    observed_criterion = MinTrials(
        threshold=min_observed,
        transition_to="MBM",
        block_gen_if_met=False,
        block_transition_if_unmet=True,
        use_all_trials_in_exp=count_existing_trials,
        only_in_statuses=[TrialStatus.COMPLETED],
        count_only_trials_with_data=True,
    )

    sobol_spec = ModelSpec(
        model_enum=Models.SOBOL,
        model_kwargs={"seed": gs_config.initialization_random_seed},
    )
    return GenerationNode(
        node_name="Sobol",
        model_specs=[sobol_spec],
        transition_criteria=[budget_criterion, observed_criterion],
        should_deduplicate=True,
    )


def _get_mbm_node(
    gs_config: GenerationStrategyConfig,
) -> GenerationNode:
    """Builds the "MBM" node for the method selected in ``gs_config``.

    The model configs placed in the ``SurrogateSpec`` depend on the method:
    - FAST: A single config, named "MBM defaults", with no other options set.
    - BALANCED: The "MBM defaults" config followed by a second config that
        uses ``LinearKernel`` as the covariance module and ``Warp`` as the
        input transform.

    Args:
        gs_config: The ``GenerationStrategyConfig`` supplying the generation
            method and the optional torch device.

    Returns:
        A ``GenerationNode`` named "MBM" with deduplication enabled and a
        single ``Models.BOTORCH_MODULAR`` model spec.

    Raises:
        UnsupportedError: If the method is neither FAST nor BALANCED.
    """
    method = gs_config.method

    # Pick the surrogate model configs for the requested method.
    if method == GenerationMethod.FAST:
        model_configs = [ModelConfig(name="MBM defaults")]
    elif method == GenerationMethod.BALANCED:
        mbm_defaults = ModelConfig(name="MBM defaults")
        warped_linear = ModelConfig(
            covar_module_class=LinearKernel,
            input_transform_classes=[Warp],
            name="LinearKernel with Warp",
        )
        model_configs = [mbm_defaults, warped_linear]
    else:
        raise UnsupportedError(f"Unsupported generation method: {gs_config.method}.")

    # The device is resolved only after the method has been validated, so an
    # unsupported method is reported before any device construction happens.
    if gs_config.torch_device is None:
        torch_device = None
    else:
        torch_device = torch.device(gs_config.torch_device)

    mbm_spec = ModelSpec(
        model_enum=Models.BOTORCH_MODULAR,
        model_kwargs={
            "surrogate_spec": SurrogateSpec(model_configs=model_configs),
            "torch_device": torch_device,
        },
    )
    return GenerationNode(
        node_name="MBM",
        model_specs=[mbm_spec],
        should_deduplicate=True,
    )


def choose_generation_strategy(
    gs_config: GenerationStrategyConfig,
) -> GenerationStrategy:
    """Choose a generation strategy based on the properties of the experiment
    and the inputs provided in ``gs_config``.

    NOTE: The behavior of this function is subject to change. It will be
    updated to produce best general purpose generation strategies based on
    benchmarking results.

    Two shapes of strategy are produced:
    - RANDOM_SEARCH: A single "Sobol" node with no transition criteria, seeded
        with ``gs_config.initialization_random_seed``. The strategy is named
        "QuasiRandomSearch".
    - Any other method: The node from ``_get_sobol_node`` followed by the node
        from ``_get_mbm_node``. The strategy is named
        "Sobol+MBM:<method value>".

    Args:
        gs_config: A ``GenerationStrategyConfig`` object that informs
            the choice of generation strategy.

    Returns:
        A generation strategy.

    Raises:
        UnsupportedError: Raised by ``_get_mbm_node`` when the method is not
            RANDOM_SEARCH and is not one it supports.
    """
    # Handle the random search case.
    if gs_config.method == GenerationMethod.RANDOM_SEARCH:
        return GenerationStrategy(
            name="QuasiRandomSearch",
            nodes=[
                GenerationNode(
                    node_name="Sobol",
                    model_specs=[
                        ModelSpec(
                            model_enum=Models.SOBOL,
                            model_kwargs={
                                "seed": gs_config.initialization_random_seed
                            },
                        )
                    ],
                )
            ],
        )
    # Construct the Sobol (initialization) node.
    sobol_node = _get_sobol_node(gs_config)
    # Construct the MBM node.
    mbm_node = _get_mbm_node(gs_config)
    # The method's enum value is embedded in the strategy name.
    method_str = gs_config.method.value
    return GenerationStrategy(
        name=f"Sobol+MBM:{method_str}",
        nodes=[sobol_node, mbm_node],
    )