Source code for ax.benchmark.problems.registry

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import Any

from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
from ax.benchmark.problems.hd_embedding import embed_higher_dimension
from ax.benchmark.problems.hpo.torchvision import (
    get_pytorch_cnn_torchvision_benchmark_problem,
)
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
from botorch.test_functions import synthetic
from botorch.test_functions.multi_objective import BraninCurrin


@dataclass
class BenchmarkProblemRegistryEntry:
    """Recipe describing how to construct one ``BenchmarkProblem``.

    The problem is not built when the entry is created; it is built later by
    calling ``factory_fn`` with keyword arguments (``get_problem`` starts
    from a shallow copy of ``factory_kwargs`` and layers any caller overrides
    on top).
    """

    # Callable that builds and returns the ``BenchmarkProblem``.
    factory_fn: Callable[..., BenchmarkProblem]
    # Default keyword arguments forwarded to ``factory_fn``.
    factory_kwargs: dict[str, Any]


def _embedded_botorch_problem(
    n: int,
    num_trials: int,
    *,
    test_problem_class: type[Any],
    test_problem_kwargs: dict[str, Any],
    observe_noise_sd: bool,
) -> BenchmarkProblem:
    """Build a BoTorch-based problem and embed it in a higher dimension.

    Shared factory for the high-dimensional registry entries (e.g.
    ``"hartmann30"``). Registry entries bind ``test_problem_class``,
    ``test_problem_kwargs`` and ``observe_noise_sd`` with ``functools.partial``
    and supply ``n`` / ``num_trials`` through ``factory_kwargs``. Because
    ``observe_noise_sd`` is a keyword, callers may override it through
    ``get_problem(..., observe_noise_sd=...)`` like any other entry.

    Args:
        n: Passed to ``embed_higher_dimension`` as ``total_dimensionality``.
        num_trials: Forwarded to ``create_problem_from_botorch``.
        test_problem_class: BoTorch test-problem class to wrap.
        test_problem_kwargs: Constructor kwargs for ``test_problem_class``.
            A shallow copy is forwarded so the dict bound in the registry is
            never handed to (and possibly mutated by) downstream code.
        observe_noise_sd: Forwarded to ``create_problem_from_botorch``.

    Returns:
        The result of ``embed_higher_dimension`` applied to the base problem.
    """
    return embed_higher_dimension(
        problem=create_problem_from_botorch(
            test_problem_class=test_problem_class,
            test_problem_kwargs=dict(test_problem_kwargs),
            num_trials=num_trials,
            observe_noise_sd=observe_noise_sd,
        ),
        total_dimensionality=n,
    )


# Maps a problem key to the recipe ``get_problem`` uses to build that problem.
BENCHMARK_PROBLEM_REGISTRY: dict[str, BenchmarkProblemRegistryEntry] = {
    "ackley4": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Ackley,
            "test_problem_kwargs": {"dim": 4},
            "num_trials": 40,
            "observe_noise_sd": False,
        },
    ),
    "branin": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Branin,
            "test_problem_kwargs": {},
            "num_trials": 30,
            "observe_noise_sd": False,
        },
    ),
    "branin_currin": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": BraninCurrin,
            "test_problem_kwargs": {},
            "num_trials": 30,
            "observe_noise_sd": False,
        },
    ),
    "branin_currin30": BenchmarkProblemRegistryEntry(
        factory_fn=partial(
            _embedded_botorch_problem,
            test_problem_class=BraninCurrin,
            test_problem_kwargs={},
            observe_noise_sd=False,
        ),
        factory_kwargs={"n": 30, "num_trials": 30},
    ),
    "griewank4": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Griewank,
            "test_problem_kwargs": {"dim": 4},
            "num_trials": 40,
            "observe_noise_sd": False,
        },
    ),
    "hartmann3": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Hartmann,
            "test_problem_kwargs": {"dim": 3},
            "num_trials": 30,
            "observe_noise_sd": False,
        },
    ),
    "hartmann6": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Hartmann,
            "test_problem_kwargs": {"dim": 6},
            "num_trials": 35,
            "observe_noise_sd": False,
        },
    ),
    "hartmann30": BenchmarkProblemRegistryEntry(
        factory_fn=partial(
            _embedded_botorch_problem,
            test_problem_class=synthetic.Hartmann,
            test_problem_kwargs={"dim": 6},
            observe_noise_sd=False,
        ),
        factory_kwargs={"n": 30, "num_trials": 25},
    ),
    "hpo_pytorch_cnn_MNIST": BenchmarkProblemRegistryEntry(
        factory_fn=get_pytorch_cnn_torchvision_benchmark_problem,
        factory_kwargs={
            "name": "MNIST",
            "num_trials": 20,
        },
    ),
    "hpo_pytorch_cnn_FashionMNIST": BenchmarkProblemRegistryEntry(
        factory_fn=get_pytorch_cnn_torchvision_benchmark_problem,
        factory_kwargs={
            "name": "FashionMNIST",
            "num_trials": 50,
        },
    ),
    "jenatton": BenchmarkProblemRegistryEntry(
        factory_fn=get_jenatton_benchmark_problem,
        factory_kwargs={"num_trials": 50, "observe_noise_sd": False},
    ),
    "levy4": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Levy,
            "test_problem_kwargs": {"dim": 4},
            "num_trials": 40,
            "observe_noise_sd": False,
        },
    ),
    "powell4": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Powell,
            "test_problem_kwargs": {"dim": 4},
            "num_trials": 40,
            "observe_noise_sd": False,
        },
    ),
    "rosenbrock4": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Rosenbrock,
            "test_problem_kwargs": {"dim": 4},
            "num_trials": 40,
            "observe_noise_sd": False,
        },
    ),
    "six_hump_camel": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.SixHumpCamel,
            "test_problem_kwargs": {},
            "num_trials": 30,
            "observe_noise_sd": False,
        },
    ),
    "three_hump_camel": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.ThreeHumpCamel,
            "test_problem_kwargs": {},
            "num_trials": 30,
            "observe_noise_sd": False,
        },
    ),
    # Problems where we observe the noise level
    "branin_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Branin,
            "test_problem_kwargs": {},
            "num_trials": 20,
            "observe_noise_sd": True,
        },
    ),
    "branin_currin_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": BraninCurrin,
            "test_problem_kwargs": {},
            "num_trials": 30,
            "observe_noise_sd": True,
        },
    ),
    "branin_currin30_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=partial(
            _embedded_botorch_problem,
            test_problem_class=BraninCurrin,
            test_problem_kwargs={},
            observe_noise_sd=True,
        ),
        factory_kwargs={"n": 30, "num_trials": 30},
    ),
    "hartmann6_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.Hartmann,
            "test_problem_kwargs": {"dim": 6},
            "num_trials": 50,
            "observe_noise_sd": True,
        },
    ),
    "hartmann30_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=partial(
            _embedded_botorch_problem,
            test_problem_class=synthetic.Hartmann,
            test_problem_kwargs={"dim": 6},
            observe_noise_sd=True,
        ),
        factory_kwargs={"n": 30, "num_trials": 25},
    ),
    "jenatton_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=get_jenatton_benchmark_problem,
        factory_kwargs={"num_trials": 25, "observe_noise_sd": True},
    ),
    "constrained_gramacy_observed_noise": BenchmarkProblemRegistryEntry(
        factory_fn=create_problem_from_botorch,
        factory_kwargs={
            "test_problem_class": synthetic.ConstrainedGramacy,
            "test_problem_kwargs": {},
            "num_trials": 50,
            "observe_noise_sd": True,
        },
    ),
}


def get_problem(
    problem_key: str,
    registry: Mapping[str, BenchmarkProblemRegistryEntry] | None = None,
    **additional_kwargs: Any,
) -> BenchmarkProblem:
    """
    Generate a benchmark problem from a key, registry, and additional arguments.

    Args:
        problem_key: The key by which a `BenchmarkProblemRegistryEntry` is
            looked up in the registry; a problem will then be generated from
            that entry and `additional_kwargs`. Note that this is not
            necessarily the same as the `name` attribute of the problem, and
            that one `problem_key` can generate several different
            `BenchmarkProblem`s by passing `additional_kwargs`. However, it is
            good practice to maintain a 1:1 mapping between `problem_key` and
            the name.
        registry: If not provided, uses `BENCHMARK_PROBLEM_REGISTRY` to use
            problems defined within Ax.
        additional_kwargs: Additional kwargs to pass to the factory function of
            the `BenchmarkProblemRegistryEntry`. These override any entry in
            the registry entry's `factory_kwargs` with the same name.

    Returns:
        Whatever the entry's `factory_fn` returns for the merged kwargs.

    Raises:
        KeyError: If `problem_key` is not in the registry. The message lists
            the available keys.
    """
    registry = BENCHMARK_PROBLEM_REGISTRY if registry is None else registry
    try:
        entry = registry[problem_key]
    except KeyError:
        # Re-raise with the known keys so a typo is easy to diagnose; the
        # exception type is unchanged so existing `except KeyError` still works.
        raise KeyError(
            f"Unknown benchmark problem key {problem_key!r}. "
            f"Available keys: {sorted(registry)}."
        ) from None
    # Shallow copy so per-call overrides never mutate the registry entry.
    kwargs = copy.copy(entry.factory_kwargs)
    kwargs.update(additional_kwargs)
    return entry.factory_fn(**kwargs)