Source code for ax.early_stopping.strategies.percentile

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Any, Dict, Optional, Set, Tuple

import numpy as np
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.early_stopping.utils import align_partial_results
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none

logger = get_logger(__name__)


class PercentileEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
    """Implements the strategy of stopping a trial if its performance
    falls below that of other trials at the same step."""

    def __init__(
        self,
        seconds_between_polls: int = 60,
        true_objective_metric_name: Optional[str] = None,
        percentile_threshold: float = 50.0,
        min_progression: float = 0.1,
        min_curves: float = 5,
        trial_indices_to_ignore: Optional[List[int]] = None,
    ) -> None:
        """Construct a PercentileEarlyStoppingStrategy instance.

        Args:
            seconds_between_polls: Forwarded unchanged to the base strategy
                constructor.
            true_objective_metric_name: The actual objective to be optimized; used
                in situations where early stopping uses a proxy objective (such as
                training loss instead of eval loss) for stopping decisions.
            percentile_threshold: Falling below this threshold compared to other
                trials at the same step will stop the run. Must be between 0.0 and
                100.0. e.g. if percentile_threshold=25.0, the bottom 25% of trials
                are stopped. Note that "bottom" here is determined based on
                performance, not absolute values; if `minimize` is True, then
                "bottom" actually refers to the top trials in terms of metric
                value.
            min_progression: Only stop trials if the latest progression value
                (e.g. timestamp, epochs, training data used) is greater than this
                threshold. Prevents stopping prematurely before enough data is
                gathered to make a decision. Defaults to 0.1; the value is
                compared directly against the progression key, so express it in
                the same units (e.g. pass 10 to start early stopping after 10
                epochs).
            min_curves: There must be `min_curves` number of completed trials and
                `min_curves` number of trials with curve data to make a stopping
                decision (i.e., even if there are enough completed trials but not
                all of them are correctly returning data, then do not apply early
                stopping).
            trial_indices_to_ignore: Trial indices that should not be early
                stopped.
        """
        super().__init__(
            seconds_between_polls=seconds_between_polls,
            true_objective_metric_name=true_objective_metric_name,
        )
        self.percentile_threshold = percentile_threshold
        self.min_progression = min_progression
        self.min_curves = min_curves
        self.trial_indices_to_ignore = trial_indices_to_ignore

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Stop a trial if its performance is in the bottom `percentile_threshold`
        of the trials at the same step.

        Args:
            trial_indices: Indices of candidate trials to consider for early
                stopping.
            experiment: Experiment that contains the trials and other contextual
                data.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A dictionary mapping trial indices that should be early stopped to
            (optional) messages with the associated reason. An empty dictionary
            means no suggested updates to any trial's status.
        """
        data = self._check_validity_and_get_data(experiment=experiment)
        if data is None:
            # don't stop any trials if we don't get data back
            return {}

        optimization_config = not_none(experiment.optimization_config)
        objective_name = optimization_config.objective.metric.name
        minimize = optimization_config.objective.minimize

        # The first map key is used as the progression axis for alignment.
        map_key = next(iter(data.map_keys))
        df = data.map_df

        # Alignment failures are deliberately non-fatal: log and stop nothing.
        try:
            metric_to_aligned_means, _ = align_partial_results(
                df=df,
                progr_key=map_key,
                metrics=[objective_name],
            )
        except Exception as e:
            logger.warning(
                f"Encountered exception while aligning data: {e}. "
                "Not early stopping any trials."
            )
            return {}

        aligned_means = metric_to_aligned_means[objective_name]
        decisions = {
            trial_index: self.should_stop_trial_early(
                trial_index=trial_index,
                experiment=experiment,
                df=aligned_means,
                minimize=minimize,
            )
            for trial_index in trial_indices
        }
        # Report only the trials for which the decision is to stop.
        return {
            trial_index: reason
            for trial_index, (should_stop, reason) in decisions.items()
            if should_stop
        }

    def should_stop_trial_early(
        self,
        trial_index: int,
        experiment: Experiment,
        df: pd.DataFrame,
        minimize: bool,
    ) -> Tuple[bool, Optional[str]]:
        """Stop a trial if its performance is in the bottom `percentile_threshold`
        of the trials at the same step.

        Args:
            trial_index: Index of the candidate trial to stop early.
            experiment: Experiment that contains the trials and other contextual
                data.
            df: Dataframe of partial results after applying interpolation,
                filtered to objective metric. It is indexed by progression and
                has one column per trial index.
            minimize: Whether objective value is being minimized.

        Returns:
            A tuple `(should_stop, reason)`, where `should_stop` is `True` iff the
            trial should be stopped, and `reason` is an (optional) string providing
            information on why the trial should or should not be stopped.
        """
        logger.info(f"Considering trial {trial_index} for early stopping.")

        # check for ignored indices
        if (
            self.trial_indices_to_ignore is not None
            and trial_index in self.trial_indices_to_ignore
        ):
            return self._log_and_return_trial_ignored(
                logger=logger, trial_index=trial_index
            )

        # check for no data
        if trial_index not in df:
            return self._log_and_return_no_data(logger=logger, trial_index=trial_index)
        # Compute the trial's non-null curve once and reuse it below.
        trial_curve = not_none(df[trial_index].dropna())
        if len(trial_curve) == 0:
            return self._log_and_return_no_data(logger=logger, trial_index=trial_index)

        # check for min progression
        trial_last_progression = trial_curve.index.max()
        logger.info(
            f"Last progression of Trial {trial_index} is {trial_last_progression}."
        )
        if trial_last_progression < self.min_progression:
            return self._log_and_return_min_progression(
                logger=logger,
                trial_index=trial_index,
                trial_last_progression=trial_last_progression,
                min_progression=self.min_progression,
            )

        # dropna() here will exclude trials that have not made it to the
        # last progression of the trial under consideration, and therefore
        # can't be included in the comparison
        data_at_last_progression = df.loc[trial_last_progression].dropna()
        logger.info(
            "Early stopping objective at last progression is:\n"
            f"{data_at_last_progression}."
        )

        # check for enough completed trials
        num_completed = len(experiment.trial_indices_by_status[TrialStatus.COMPLETED])
        if num_completed < self.min_curves:
            return self._log_and_return_completed_trials(
                logger=logger, num_completed=num_completed, min_curves=self.min_curves
            )

        # check for enough number of trials with data
        num_trials_with_data = len(data_at_last_progression)
        if num_trials_with_data < self.min_curves:
            return self._log_and_return_num_trials_with_data(
                logger=logger,
                trial_index=trial_index,
                trial_last_progression=trial_last_progression,
                num_trials_with_data=num_trials_with_data,
                min_curves=self.min_curves,
            )

        # percentile early stopping logic
        # When minimizing, the worst performers have the largest values, so the
        # cut-off is mirrored to the (100 - p)-th percentile.
        percentile_threshold = (
            100.0 - self.percentile_threshold if minimize else self.percentile_threshold
        )
        percentile_value = np.percentile(data_at_last_progression, percentile_threshold)
        trial_objective_value = data_at_last_progression[trial_index]
        # Coerce to a builtin bool so the annotated return type actually holds;
        # the raw comparison yields a numpy boolean scalar.
        should_early_stop = bool(
            trial_objective_value > percentile_value
            if minimize
            else trial_objective_value < percentile_value
        )
        comp = "worse" if should_early_stop else "better"
        reason = (
            f"Trial objective value {trial_objective_value} is {comp} than "
            f"{percentile_threshold:.1f}-th percentile ({percentile_value}) "
            "across comparable trials."
        )
        logger.info(
            f"Early stopping decision for {trial_index}: {should_early_stop}. "
            f"Reason: {reason}"
        )
        return should_early_stop, reason