Source code for ax.runners.simulated_backend

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any

import numpy as np
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.runner import Runner
from ax.utils.testing.backend_simulator import BackendSimulator


class SimulatedBackendRunner(Runner):
    """Runner that deploys, polls, and stops trials on a ``BackendSimulator``."""

    def __init__(
        self,
        simulator: BackendSimulator,
        sample_runtime_func: Callable[[BaseTrial], float] | None = None,
    ) -> None:
        """Create a runner bound to a ``BackendSimulator``.

        Args:
            simulator: The backend simulator that trials are dispatched to.
            sample_runtime_func: Callable that returns a simulated runtime for
                a given trial. When omitted, ``sample_runtime_unif`` is used.
        """
        self.simulator: BackendSimulator = simulator
        self.sample_runtime_func: Callable[[BaseTrial], float] = (
            sample_runtime_unif if sample_runtime_func is None else sample_runtime_func
        )

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> dict[TrialStatus, set[int]]:
        """Group the given trials by their current status on the simulator.

        The simulator is advanced with ``update()`` before any status lookups
        are made.

        NOTE: The ``Scheduler`` currently marks trials as running when they are
        created, but some of these trials may actually still be queued on the
        ``BackendSimulator``.

        Args:
            trials: Trials whose statuses should be looked up.

        Returns:
            A dict mapping each status reported by the simulator to the set of
            indices of the trials in that status.
        """
        self.simulator.update()
        indices_by_status = defaultdict(set)
        for index in (trial.index for trial in trials):
            current = self.simulator.lookup_trial_index_status(index)
            indices_by_status[current].add(index)
        return dict(indices_by_status)

    def run(self, trial: BaseTrial) -> dict[str, float]:
        """Launch ``trial`` on the simulator with a freshly sampled runtime.

        Args:
            trial: Trial to deploy.

        Returns:
            Run metadata of the form ``{"runtime": <sampled runtime>}``.
        """
        sampled_runtime = self.sample_runtime_func(trial)
        self.simulator.run_trial(trial_index=trial.index, runtime=sampled_runtime)
        return {"runtime": sampled_runtime}

    def stop(self, trial: BaseTrial, reason: str | None = None) -> dict[str, Any]:
        """Stop ``trial`` on the simulator.

        Args:
            trial: Trial to stop.
            reason: Optional explanation of why the trial is being stopped.

        Returns:
            ``{"reason": reason}`` when a non-empty reason is given; otherwise
            an empty dict.
        """
        self.simulator.stop_trial(trial.index)
        if not reason:
            return {}
        return {"reason": reason}
def sample_runtime_unif(
    trial: BaseTrial,
    low: float = 1.0,
    high: float = 5.0,
    *,
    rng: np.random.Generator | None = None,
) -> float:
    """Sample a trial runtime uniformly from the half-open interval [low, high).

    The ``trial`` argument is accepted so that this function fits the
    ``Callable[[BaseTrial], float]`` signature that ``SimulatedBackendRunner``
    expects for its runtime sampler; it does not influence the sampled value.

    Args:
        trial: Trial for which to sample runtime (unused by this sampler).
        low: Lower bound (inclusive) of the uniform runtime distribution.
        high: Upper bound (exclusive) of the uniform runtime distribution.
        rng: Optional NumPy ``Generator`` to draw from, e.g. for reproducible
            simulations. If omitted, NumPy's global random state is used via
            ``np.random.uniform``.

    Returns:
        A float representing the simulated trial runtime.
    """
    # Keep the legacy global-state path when no generator is supplied so that
    # callers relying on ``np.random.seed`` get exactly the same draws.
    if rng is None:
        sample = np.random.uniform(low, high)
    else:
        sample = rng.uniform(low, high)
    # NumPy returns ``np.float64``; convert to a built-in ``float`` so the value
    # matches the annotated return type (it ends up in run metadata via ``run``).
    return float(sample)