Source code for ax.early_stopping.strategies.logical
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Set
from ax.core.experiment import Experiment
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
[docs]class LogicalEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
def __init__(
self,
left: BaseEarlyStoppingStrategy,
right: BaseEarlyStoppingStrategy,
seconds_between_polls: int = 60,
true_objective_metric_name: Optional[str] = None,
) -> None:
super().__init__(
seconds_between_polls=seconds_between_polls,
true_objective_metric_name=true_objective_metric_name,
)
self.left = left
self.right = right
[docs]class AndEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
[docs] def should_stop_trials_early(
self,
trial_indices: Set[int],
experiment: Experiment,
**kwargs: Dict[str, Any],
) -> Dict[int, Optional[str]]:
left = self.left.should_stop_trials_early(
trial_indices=trial_indices, experiment=experiment, **kwargs
)
right = self.right.should_stop_trials_early(
trial_indices=trial_indices, experiment=experiment, **kwargs
)
return {
trial: f"{left[trial]}, {right[trial]}" for trial in left if trial in right
}
[docs]class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
[docs] def should_stop_trials_early(
self,
trial_indices: Set[int],
experiment: Experiment,
**kwargs: Dict[str, Any],
) -> Dict[int, Optional[str]]:
return {
**self.left.should_stop_trials_early(
trial_indices=trial_indices, experiment=experiment, **kwargs
),
**self.right.should_stop_trials_early(
trial_indices=trial_indices, experiment=experiment, **kwargs
),
}