Source code for ax.early_stopping.strategies.logical

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from functools import reduce
from typing import Any, Dict, Optional, Sequence, Set

from ax.core.experiment import Experiment
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.exceptions.core import UserInputError


class LogicalEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
    """Common base for early-stopping strategies built from two operands.

    Instances hold two sub-strategies, ``left`` and ``right``; subclasses
    define how the two sub-strategies' stopping decisions are combined.
    """

    def __init__(
        self,
        left: BaseEarlyStoppingStrategy,
        right: BaseEarlyStoppingStrategy,
        seconds_between_polls: int = 300,
    ) -> None:
        """Store both operand strategies and initialize the base class.

        Args:
            left: First operand strategy, stored as ``self.left``.
            right: Second operand strategy, stored as ``self.right``.
            seconds_between_polls: Forwarded unchanged to the base-class
                constructor; judging by the name it is a polling interval in
                seconds (confirm against ``BaseEarlyStoppingStrategy``).
        """
        super().__init__(seconds_between_polls=seconds_between_polls)
        # Operands consulted by the subclasses' ``should_stop_trials_early``.
        self.left, self.right = left, right
class AndEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
    """Recommend stopping a trial only when BOTH sub-strategies recommend it."""

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Return the trials that both ``self.left`` and ``self.right`` would stop.

        Args:
            trial_indices: Indices of the candidate trials.
            experiment: The experiment the trials belong to.
            **kwargs: Forwarded unchanged to both sub-strategies.

        Returns:
            A dict keyed by every trial index present in BOTH sub-strategies'
            results (in the left result's order). Each value joins the two
            sides' reasons with ``", "``; missing (``None``) or empty reasons
            are skipped, and the value is ``None`` if neither side gave one.
        """
        # Both sides are always evaluated (no short-circuit), matching the
        # previous behavior for any side effects the sub-strategies have.
        left = self.left.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )
        right = self.right.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )

        stop_decisions: Dict[int, Optional[str]] = {}
        for trial, left_reason in left.items():
            if trial not in right:
                continue
            # Drop absent reasons so they do not render as the text "None".
            reasons = [reason for reason in (left_reason, right[trial]) if reason]
            stop_decisions[trial] = ", ".join(reasons) if reasons else None
        return stop_decisions
class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
    """Recommend stopping a trial when EITHER sub-strategy recommends it."""

    @classmethod
    def from_early_stopping_strategies(
        cls,
        strategies: Sequence[BaseEarlyStoppingStrategy],
    ) -> BaseEarlyStoppingStrategy:
        """Combine any number of strategies with logical OR.

        The strategies are folded left-to-right into nested
        ``OrEarlyStoppingStrategy`` instances, i.e. ``[a, b, c]`` becomes
        ``Or(Or(a, b), c)``. A single strategy is returned as-is.

        Args:
            strategies: Non-empty sequence of strategies to combine.

        Returns:
            The combined strategy.

        Raises:
            UserInputError: If ``strategies`` is empty.
        """
        if not strategies:
            raise UserInputError("strategies must not be empty")
        return reduce(
            lambda left, right: OrEarlyStoppingStrategy(left=left, right=right),
            strategies[1:],
            strategies[0],
        )

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Return the trials that ``self.left`` or ``self.right`` would stop.

        Args:
            trial_indices: Indices of the candidate trials.
            experiment: The experiment the trials belong to.
            **kwargs: Forwarded unchanged to both sub-strategies.

        Returns:
            A dict keyed by every trial index in either sub-strategy's result
            (left keys first, then right-only keys). A trial reported by only
            one side keeps that side's reason. A trial reported by both sides
            gets both reasons joined with ``", "`` (missing or empty reasons
            skipped; ``None`` if neither side gave one).
        """
        left = self.left.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )
        right = self.right.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )

        # Seed with the left results so key order stays left-then-right.
        stop_decisions: Dict[int, Optional[str]] = dict(left)
        for trial, right_reason in right.items():
            if trial in stop_decisions:
                # Previously the right reason overwrote the left one; keep both.
                reasons = [
                    reason
                    for reason in (stop_decisions[trial], right_reason)
                    if reason
                ]
                stop_decisions[trial] = ", ".join(reasons) if reasons else None
            else:
                stop_decisions[trial] = right_reason
        return stop_decisions