Source code for ax.utils.testing.benchmark_stubs

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator

import numpy as np
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import (
    BenchmarkProblem,
    create_problem_from_botorch,
    get_moo_opt_config,
    get_soo_opt_config,
)
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_search_space
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.parameter import ChoiceParameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.trial import BaseTrial, Trial
from ax.core.types import TParameterization, TParamValue
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.modelbridge.external_generation_node import ExternalGenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.torch import TorchModelBridge
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.testing.core_stubs import (
    get_branin_experiment,
    get_branin_experiment_with_multi_objective,
)
from botorch.models.gp_regression import SingleTaskGP
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import Branin
from pyre_extensions import assert_is_instance
from torch.utils.data import Dataset


def get_single_objective_benchmark_problem(
    observe_noise_sd: bool = False,
    num_trials: int = 4,
    test_problem_kwargs: dict[str, Any] | None = None,
    report_inference_value_as_trace: bool = False,
    noise_std: float | list[float] = 0.0,
) -> BenchmarkProblem:
    return create_problem_from_botorch(
        test_problem_class=Branin,
        test_problem_kwargs=test_problem_kwargs or {},
        num_trials=num_trials,
        observe_noise_sd=observe_noise_sd,
        report_inference_value_as_trace=report_inference_value_as_trace,
        noise_std=noise_std,
    )
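

# Illustrative sketch (hypothetical helper, not part of the original stubs): the
# defaults above yield a noiseless, 4-trial Branin problem; any keyword argument
# can be overridden, e.g. to get a noisy variant with observed noise level.
def _example_single_objective_problem() -> BenchmarkProblem:
    return get_single_objective_benchmark_problem(
        observe_noise_sd=True,
        num_trials=10,
        noise_std=1.0,
    )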


def get_multi_objective_benchmark_problem(
    observe_noise_sd: bool = False,
    num_trials: int = 4,
    test_problem_class: type[BraninCurrin] = BraninCurrin,
    report_inference_value_as_trace: bool = False,
) -> BenchmarkProblem:
    return create_problem_from_botorch(
        test_problem_class=test_problem_class,
        test_problem_kwargs={},
        num_trials=num_trials,
        observe_noise_sd=observe_noise_sd,
        report_inference_value_as_trace=report_inference_value_as_trace,
    )


def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction:
    experiment = get_branin_experiment(with_completed_trial=True)
    surrogate = TorchModelBridge(
        experiment=experiment,
        search_space=experiment.search_space,
        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
        data=experiment.lookup_data(),
        transforms=[],
    )
    if lazy:
        test_function = SurrogateTestFunction(
            outcome_names=["branin"],
            name="test",
            get_surrogate_and_datasets=lambda: (surrogate, []),
        )
    else:
        test_function = SurrogateTestFunction(
            outcome_names=["branin"],
            name="test",
            _surrogate=surrogate,
            _datasets=[],
        )
    return test_function


def get_soo_surrogate() -> BenchmarkProblem:
    experiment = get_branin_experiment(with_completed_trial=True)
    test_function = get_soo_surrogate_test_function()
    optimization_config = get_soo_opt_config(
        outcome_names=test_function.outcome_names,
        observe_noise_sd=True,
    )
    return BenchmarkProblem(
        name="test",
        search_space=experiment.search_space,
        optimization_config=optimization_config,
        num_trials=6,
        optimal_value=0.0,
        test_function=test_function,
    )


def get_moo_surrogate() -> BenchmarkProblem:
    experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
    surrogate = TorchModelBridge(
        experiment=experiment,
        search_space=experiment.search_space,
        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
        data=experiment.lookup_data(),
        transforms=[],
    )
    outcome_names = ["branin_a", "branin_b"]
    test_function = SurrogateTestFunction(
        name="test",
        outcome_names=outcome_names,
        get_surrogate_and_datasets=lambda: (surrogate, []),
    )
    optimization_config = get_moo_opt_config(
        outcome_names=outcome_names,
        ref_point=[0.0, 0.0],
        observe_noise_sd=True,
    )
    return BenchmarkProblem(
        name="test",
        search_space=experiment.search_space,
        optimization_config=optimization_config,
        num_trials=10,
        optimal_value=1.0,
        test_function=test_function,
    )


def get_benchmark_result() -> BenchmarkResult:
    problem = get_single_objective_benchmark_problem()
    return BenchmarkResult(
        name="test_benchmarking_result",
        seed=0,
        experiment=Experiment(
            name="test_benchmarking_experiment",
            search_space=problem.search_space,
            optimization_config=problem.optimization_config,
            is_test=True,
        ),
        inference_trace=np.ones(4),
        oracle_trace=np.zeros(4),
        optimization_trace=np.array([3, 2, 1, 0.1]),
        score_trace=np.array([3, 2, 1, 0.1]),
        fit_time=0.1,
        gen_time=0.2,
    )


def get_aggregated_benchmark_result() -> AggregatedBenchmarkResult:
    result = get_benchmark_result()
    return AggregatedBenchmarkResult.from_benchmark_results([result, result])


@dataclass(kw_only=True)
class DummyTestFunction(BenchmarkTestFunction):
    outcome_names: list[str] = field(default_factory=list)
    num_outcomes: int = 1
    dim: int = 6

    def __post_init__(self) -> None:
        self.outcome_names = [f"objective_{i}" for i in range(self.num_outcomes)]

    # pyre-fixme[14]: Inconsistent override, as dict[str, float] is not a
    # `TParameterization`
    def evaluate_true(self, params: dict[str, float]) -> torch.Tensor:
        value = sum(elt**2 for elt in params.values())
        return value * torch.ones(self.num_outcomes, dtype=torch.double)


class TestDataset(Dataset):
    def __init__(
        self,
        root: str = "",
        train: bool = True,
        download: bool = True,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        transform: Any = None,
    ) -> None:
        torch.manual_seed(0)
        self.data: torch.Tensor = torch.randint(
            low=0, high=256, size=(32, 1, 28, 28), dtype=torch.float32
        )
        self.targets: torch.Tensor = torch.randint(
            low=0, high=10, size=(32,), dtype=torch.uint8
        )

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        target = assert_is_instance(self.targets[idx].item(), int)
        return self.data[idx], target
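

# Illustrative sketch (hypothetical helper, not part of the original module):
# TestDataset implements `__len__` and `__getitem__`, so it can be wrapped in a
# standard torch DataLoader.
def _example_test_dataset_loader() -> tuple[torch.Tensor, torch.Tensor]:
    from torch.utils.data import DataLoader

    loader = DataLoader(TestDataset(), batch_size=8, shuffle=False)
    images, targets = next(iter(loader))
    # images: (8, 1, 28, 28) float32; targets: (8,) integer class labels in [0, 10)
    return images, targets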


def get_jenatton_arm(i: int) -> Arm:
    """
    Args:
        i: Non-negative int.
    """
    jenatton_x_params = {f"x{j}": j % (i + 1) for j in range(1, 8)}
    jenatton_r_params = {"r8": 0.0, "r9": 0.0}
    return Arm(parameters={**jenatton_x_params, **jenatton_r_params}, name=f"0_{i}")


def get_jenatton_experiment() -> Experiment:
    experiment = Experiment(
        search_space=get_jenatton_search_space(),
        name="test_jenatton",
        is_test=True,
    )
    return experiment


def get_jenatton_trials(n_trials: int) -> dict[int, Trial]:
    experiment = get_jenatton_experiment()
    for i in range(n_trials):
        trial = experiment.new_trial()
        trial.add_arm(get_jenatton_arm(i=i))
    # pyre-fixme: Incompatible return type [7]: Expected `Dict[int, Trial]` but
    # got `Dict[int, BaseTrial]`.
    return experiment.trials


def get_jenatton_batch_trial() -> BatchTrial:
    experiment = get_jenatton_experiment()
    trial = experiment.new_batch_trial()
    trial.add_arm(get_jenatton_arm(0))
    trial.add_arm(get_jenatton_arm(1))
    return trial


class DeterministicGenerationNode(ExternalGenerationNode):
    """
    A GenerationNode that deterministically explores a discrete search space
    with one parameter.
    """

    def __init__(
        self,
        search_space: SearchSpace,
    ) -> None:
        if len(search_space.parameters) != 1:
            raise ValueError(
                "DeterministicGenerationNode only supports search spaces with one "
                "parameter."
            )
        param = list(search_space.parameters.values())[0]
        if not isinstance(param, ChoiceParameter):
            raise ValueError(
                "DeterministicGenerationNode only supports ChoiceParameters."
            )
        super().__init__(node_name="Deterministic")
        self.param_name: str = param.name
        self.iterator: Iterator[TParamValue] = iter(param.values)

    def update_generator_state(self, experiment: Experiment, data: Data) -> None:
        return

    def get_next_candidate(
        self, pending_parameters: list[TParameterization]
    ) -> TParameterization:
        return {self.param_name: next(self.iterator)}
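

# Illustrative sketch (hypothetical helper, not part of the original module): the
# node walks the single ChoiceParameter's values in order, one candidate per call.
def _example_deterministic_node_candidates() -> list[TParameterization]:
    node = DeterministicGenerationNode(search_space=get_discrete_search_space())
    # First two candidates for the "x0" parameter defined in
    # get_discrete_search_space() below: {"x0": 0}, then {"x0": 1}.
    return [node.get_next_candidate(pending_parameters=[]) for _ in range(2)]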


@dataclass(kw_only=True)
class IdentityTestFunction(BenchmarkTestFunction):
    outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
    n_time_intervals: int = 1

    # pyre-fixme[14]: Inconsistent override
    def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
        """
        Args:
            params: A dictionary with key "x0".
        """
        value = params["x0"]
        return torch.full(
            (len(self.outcome_names), self.n_time_intervals),
            value,
            dtype=torch.float64,
        )
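

# Illustrative sketch (hypothetical helper, not part of the original module):
# IdentityTestFunction echoes the "x0" value across every outcome and time interval.
def _example_identity_test_function() -> torch.Tensor:
    f = IdentityTestFunction(n_time_intervals=3)
    # With the default single outcome, this is a (1, 3) tensor filled with 7.0.
    return f.evaluate_true({"x0": 7.0})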


def get_discrete_search_space() -> SearchSpace:
    return SearchSpace(
        parameters=[
            ChoiceParameter(
                name="x0",
                parameter_type=ParameterType.INT,
                # pyre-fixme: Incompatible parameter type [6]: In call
                # `ChoiceParameter.__init__`, for argument `values`, expected
                # `List[Union[None, bool, float, int, str]]` but got `List[int]`.
                values=list(range(20)),
            )
        ]
    )


def get_async_benchmark_method(
    early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
) -> BenchmarkMethod:
    gs = GenerationStrategy(
        nodes=[DeterministicGenerationNode(search_space=get_discrete_search_space())]
    )
    return BenchmarkMethod(
        generation_strategy=gs,
        distribute_replications=False,
        max_pending_trials=2,
        batch_size=1,
        early_stopping_strategy=early_stopping_strategy,
    )


def get_async_benchmark_problem(
    map_data: bool,
    trial_runtime_func: Callable[[BaseTrial], int],
    n_time_intervals: int = 1,
    lower_is_better: bool = False,
) -> BenchmarkProblem:
    search_space = get_discrete_search_space()
    test_function = IdentityTestFunction(n_time_intervals=n_time_intervals)
    optimization_config = get_soo_opt_config(
        outcome_names=["objective"],
        use_map_metric=map_data,
        observe_noise_sd=True,
        lower_is_better=lower_is_better,
    )
    return BenchmarkProblem(
        name="test",
        search_space=search_space,
        optimization_config=optimization_config,
        test_function=test_function,
        num_trials=4,
        optimal_value=19.0,
        trial_runtime_func=trial_runtime_func,
    )
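

# Illustrative sketch (hypothetical helper, not part of the original module):
# pairing the async method and problem stubs; the constant runtime lambda stands
# in for a per-trial runtime model.
def _example_async_method_and_problem() -> tuple[BenchmarkMethod, BenchmarkProblem]:
    method = get_async_benchmark_method()
    problem = get_async_benchmark_problem(
        map_data=True,
        trial_runtime_func=lambda trial: 1,
        n_time_intervals=3,
    )
    return method, problem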