Source code for ax.utils.testing.mock

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from contextlib import contextmanager, ExitStack
from functools import wraps
from typing import Any, Callable, Dict, Generator, Optional
from unittest import mock

from import fit_fully_bayesian_model_nuts
from botorch.generation.gen import minimize_with_timeout
from botorch.optim.initializers import (
from scipy.optimize import OptimizeResult
from torch import Tensor

@contextmanager
def fast_botorch_optimize_context_manager(
    force: bool = False,
) -> Generator[None, None, None]:
    """A context manager to force botorch to speed up optimization.

    Currently, the primary tactic is to force the underlying scipy methods to
    stop after just one iteration. In addition, initial-condition generation is
    pinned to ``num_restarts=2`` / ``raw_samples=4`` and fully Bayesian NUTS
    fitting is invoked with the reduced kwargs from ``_get_minimal_mcmc_kwargs``.

    Args:
        force: If True will not raise an AssertionError if no mocks are called.
            USE RESPONSIBLY.

    Raises:
        AssertionError: If ``force`` is False and none of the patched callables
            was invoked inside the context.
    """

    def one_iteration_minimize(
        *args: Any, **kwargs: Any
    ) -> OptimizeResult:  # pyre-ignore[11]
        # Tolerate a missing or ``None`` ``options`` entry, and work on a copy
        # so the dict owned by the caller is never mutated.
        options = dict(kwargs.get("options") or {})
        options["maxiter"] = 1
        kwargs["options"] = options
        return minimize_with_timeout(*args, **kwargs)

    def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4
        return gen_batch_initial_conditions(*args, **kwargs)

    def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Optional[Tensor]:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4
        return gen_one_shot_kg_initial_conditions(*args, **kwargs)

    def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None:
        fit_fully_bayesian_model_nuts(*args, **_get_minimal_mcmc_kwargs(**kwargs))

    with ExitStack() as es:
        mock_generation = es.enter_context(
            mock.patch(
                "botorch.generation.gen.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )
        mock_fit = es.enter_context(
            mock.patch(
                "botorch.optim.core.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )
        mock_gen_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_batch_initial_conditions",
                wraps=minimal_gen_ics,
            )
        )
        mock_gen_os_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
                wraps=minimal_gen_os_ics,
            )
        )
        mock_mcmc_mbm = es.enter_context(
            mock.patch(
                "ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts",
                wraps=minimal_fit_fully_bayesian,
            )
        )
        yield

    # Only reached when the wrapped body finished without raising: flag usages
    # of the context manager that never exercised any of the patched callables.
    if (not force) and all(
        mock_.call_count < 1
        for mock_ in [
            mock_generation,
            mock_fit,
            mock_gen_ics,
            mock_gen_os_ics,
            mock_mcmc_mbm,
        ]
    ):
        raise AssertionError(
            "No mocks were called in the context manager. Please remove unused "
            "fast_botorch_optimize_context_manager()."
        )
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def fast_botorch_optimize(f: Callable) -> Callable:
    """Decorator form of ``fast_botorch_optimize_context_manager``.

    Every call to the decorated callable runs inside a fresh
    ``fast_botorch_optimize_context_manager()`` (with its default arguments),
    and the callable's return value is passed through unchanged.
    """

    @wraps(f)
    # pyre-fixme[3]: Return type must be annotated.
    def wrapper(*call_args: Any, **call_kwargs: Any):
        with fast_botorch_optimize_context_manager():
            result = f(*call_args, **call_kwargs)
        return result

    return wrapper
@contextmanager
def skip_fit_gpytorch_mll_context_manager() -> Generator[None, None, None]:
    """A context manager that makes `fit_gpytorch_mll` a no-op.

    This should only be used to speed up slow tests.

    Raises:
        AssertionError: If the patched object was never called inside the
            context.
    """
    # NOTE(review): the patch target here was an empty string, which
    # `mock.patch` rejects (it needs a dotted "package.module.attr" path), so
    # the context could never be entered. The target below is presumed to be
    # the dispatcher that `fit_gpytorch_mll` delegates to -- confirm against
    # the botorch version in use.
    with mock.patch(
        "botorch.fit.FitGPyTorchMLL", side_effect=lambda *args, **kwargs: args[0]
    ) as mock_fit:
        yield
    # Only reached when the wrapped body finished without raising.
    if mock_fit.call_count < 1:
        raise AssertionError(
            "No mocks were called in the context manager. Please remove unused "
            "skip_fit_gpytorch_mll_context_manager()."
        )
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def skip_fit_gpytorch_mll(f: Callable) -> Callable:
    """Decorator form of ``skip_fit_gpytorch_mll_context_manager``.

    Every call to the decorated callable runs inside a fresh
    ``skip_fit_gpytorch_mll_context_manager()``, and the callable's return
    value is passed through unchanged.
    """

    @wraps(f)
    # pyre-fixme[3]: Return type must be annotated.
    def wrapper(*call_args: Any, **call_kwargs: Any):
        with skip_fit_gpytorch_mll_context_manager():
            result = f(*call_args, **call_kwargs)
        return result

    return wrapper
def _get_minimal_mcmc_kwargs(**kwargs: Any) -> Dict[str, Any]: kwargs["warmup_steps"] = 0 # Just get as many samples as otherwise expected. kwargs["num_samples"] = kwargs.get("num_samples", 256) // kwargs.get("thinning", 16) kwargs["thinning"] = 1 return kwargs