Source code for ax.runners.simulated_backend
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any
import numpy as np
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.runner import Runner
from ax.utils.testing.backend_simulator import BackendSimulator
[docs]
class SimulatedBackendRunner(Runner):
    """Runner that deploys, polls and stops trials on a ``BackendSimulator``."""

    def __init__(
        self,
        simulator: BackendSimulator,
        sample_runtime_func: Callable[[BaseTrial], float] | None = None,
    ) -> None:
        """Set up a runner backed by a ``BackendSimulator``.

        Args:
            simulator: The backend simulator that trials are run on.
            sample_runtime_func: A callable that, given a trial, returns the
                runtime to simulate for it. When ``None``,
                ``sample_runtime_unif`` is used.
        """
        self.simulator: BackendSimulator = simulator
        self.sample_runtime_func: Callable[[BaseTrial], float] = (
            sample_runtime_unif if sample_runtime_func is None else sample_runtime_func
        )

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> dict[TrialStatus, set[int]]:
        """Group the given trials by their status on the ``BackendSimulator``.

        NOTE: The ``Scheduler`` currently marks trials as running when they are
        created, but some of these trials may actually be queued on the
        ``BackendSimulator``.

        Args:
            trials: The trials whose statuses should be looked up.

        Returns:
            A dict mapping each observed status to the set of trial indices
            that the simulator reports with that status.
        """
        # Refresh the simulator once, before any status is looked up.
        self.simulator.update()
        indices_by_status = defaultdict(set)
        for trial in trials:
            index = trial.index
            indices_by_status[self.simulator.lookup_trial_index_status(index)].add(
                index
            )
        # Return a plain dict rather than the defaultdict used for grouping.
        return dict(indices_by_status)

    def run(self, trial: BaseTrial) -> dict[str, float]:
        """Start a trial on the ``BackendSimulator``.

        Args:
            trial: Trial to deploy via the runner.

        Returns:
            A dict with the single key ``"runtime"`` holding the runtime that
            was sampled for the trial.
        """
        sampled_runtime = self.sample_runtime_func(trial)
        self.simulator.run_trial(trial_index=trial.index, runtime=sampled_runtime)
        return {"runtime": sampled_runtime}

    def stop(self, trial: BaseTrial, reason: str | None = None) -> dict[str, Any]:
        """Stop a trial on the ``BackendSimulator``.

        Args:
            trial: Trial to stop on the simulator.
            reason: Optional message explaining why the trial is being stopped.

        Returns:
            ``{"reason": reason}`` when a non-empty reason was given, otherwise
            an empty dict.
        """
        self.simulator.stop_trial(trial.index)
        if not reason:
            return {}
        return {"reason": reason}
[docs]
def sample_runtime_unif(trial: BaseTrial, low: float = 1.0, high: float = 5.0) -> float:
    """Draw a simulated trial runtime uniformly between ``low`` and ``high``.

    Sampling is delegated to ``np.random.uniform``, which draws from the
    half-open interval ``[low, high)``.

    Args:
        trial: Trial for which to sample a runtime. It is not used by this
            sampler; it is part of the signature so the function matches the
            ``Callable[[BaseTrial], float]`` expected by
            ``SimulatedBackendRunner``.
        low: Lower bound of the uniform runtime distribution.
        high: Upper bound of the uniform runtime distribution.

    Returns:
        The sampled runtime.
    """
    runtime = np.random.uniform(low=low, high=high)
    return runtime