Source code for ax.utils.testing.backend_scheduler

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections import defaultdict
from typing import Dict, Optional, Set

from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.runners.simulated_backend import SimulatedBackendRunner
from ax.service.scheduler import (
    Scheduler,
    SchedulerOptions,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from ax.utils.testing.backend_simulator import BackendSimulator

logger = get_logger(__name__)


class AsyncSimulatedBackendScheduler(Scheduler):
    """A Scheduler that uses a simulated backend for Ax asynchronous benchmarks.

    The experiment passed in must have a ``SimulatedBackendRunner`` attached;
    trial statuses and running times are read from that runner's
    ``BackendSimulator``.
    """

    def __init__(
        self,
        experiment: Experiment,
        generation_strategy: GenerationStrategy,
        max_pending_trials: int,
        options: SchedulerOptions,
    ) -> None:
        """A Scheduler for Ax asynchronous benchmarks.

        Args:
            experiment: Experiment, in which results of the optimization
                will be recorded.
            generation_strategy: Generation strategy for the optimization,
                describes models that will be used in optimization.
            max_pending_trials: The maximum number of pending trials allowed.
            options: `SchedulerOptions` for this Scheduler instance.

        Raises:
            ValueError: If the runner attached to ``experiment`` is not a
                ``SimulatedBackendRunner``.
        """
        if not isinstance(experiment.runner, SimulatedBackendRunner):
            raise ValueError(
                "experiment must have runner of type SimulatedBackendRunner attached"
            )
        super().__init__(
            experiment=experiment,
            generation_strategy=generation_strategy,
            options=options,
            _skip_experiment_save=True,
        )
        # Upper bound used by ``poll_available_capacity``.
        self.max_pending_trials = max_pending_trials

    @property
    def backend_simulator(self) -> BackendSimulator:
        """Get the ``BackendSimulator`` stored on the runner of the experiment.

        Returns:
            The backend simulator.
        """
        return self.experiment.runner.simulator  # pyre-ignore[16]

    def poll_trial_status(self) -> Dict[TrialStatus, Set[int]]:
        """Poll trial status from the ``BackendSimulator``.

        NOTE: The ``Scheduler`` currently marks trials as running when they
        are created, but some of these trials may actually be queued on the
        ``BackendSimulator``.

        Returns:
            A Dict mapping statuses to sets of trial indices. Only trials
            that the experiment lists as candidate, staged or running are
            looked up, and only statuses with at least one such trial appear
            as keys.
        """
        # Bring the simulator's state up to date before querying it.
        self.backend_simulator.update()
        trials_by_status = self.experiment.trials_by_status
        trial_status = defaultdict(set)
        for ts in (TrialStatus.CANDIDATE, TrialStatus.STAGED, TrialStatus.RUNNING):
            for trial in trials_by_status[ts]:
                t_index = trial.index
                # The status reported by the simulator, not the one recorded
                # on the experiment, decides which bucket the trial lands in.
                status = self.backend_simulator.lookup_trial_index_status(t_index)
                trial_status[status].add(t_index)
        # Return a plain dict so that callers do not create keys on lookup.
        return dict(trial_status)

    def has_capacity(self, n: int = 1) -> bool:
        """Whether or not there is available capacity for ``n`` trials.

        Args:
            n: The number of trials

        Returns:
            A boolean representing whether or not there is available capacity.
        """
        return not_none(self.poll_available_capacity()) >= n

    def poll_available_capacity(self) -> Optional[int]:
        """Get the capacity remaining after accounting for staged and running
        trials, with the maximum being ``max_pending_trials``.

        Returns:
            The available capacity, i.e. ``max_pending_trials`` minus the
            number of staged and running trials on the experiment. The value
            is not clamped, so it is negative when more trials are staged or
            running than ``max_pending_trials`` allows.
        """
        trials_by_status = self.experiment.trials_by_status
        num_staged = len(trials_by_status[TrialStatus.STAGED])
        num_running = len(trials_by_status[TrialStatus.RUNNING])
        return self.max_pending_trials - (num_staged + num_running)

    def should_stop_trials_early(
        self, trial_indices: Set[int]
    ) -> Dict[int, Optional[str]]:
        """Given a set of trial indices, decide whether or not to early-stop
        running trials using the ``early_stopping_strategy``.

        Only trials that have been running on the backend simulator for a
        positive amount of simulator time are forwarded to the base
        implementation; the remaining ones are logged and left out.

        Args:
            trial_indices: Indices of trials to consider for early stopping.

        Returns:
            The result of the base ``should_stop_trials_early`` for the
            forwarded trials: a Dict keyed by trial index (a subset of the
            initially-passed trials) with an optional string as value
            (presumably the reason for stopping; see the base class).
        """
        # TODO: The status on the experiment does not distinguish between
        # running and queued trials, so here we check status on the
        # ``backend_simulator`` directly to make sure it is running.
        running_trials = set()
        skipped_trials = set()
        for trial_index in trial_indices:
            sim_trial = self.backend_simulator.get_sim_trial_by_index(trial_index)
            # The simulator lookup may yield no trial (presumably when the
            # index is unknown to it); treat that like a trial that has not
            # started instead of failing on the attribute access.
            start_time = None if sim_trial is None else sim_trial.sim_start_time
            if (
                start_time is not None
                and self.backend_simulator.time - start_time > 0
            ):
                running_trials.add(trial_index)
            else:
                skipped_trials.add(trial_index)
        if skipped_trials:
            logger.info(
                f"Not sending {skipped_trials} to base `should_stop_trials_early` "
                "because they have not been running for a positive amount of time "
                "on the backend simulator."
            )
        return super().should_stop_trials_early(trial_indices=running_trials)