Source code for ax.early_stopping.strategies.base

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Any, Dict, Optional, Set, Tuple

import numpy as np
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.observation import observations_from_map_data
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.modelbridge_utils import (
    observation_data_to_array,
    observation_features_to_array,
    _unpack_observations,
)
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none

logger = get_logger(__name__)


@dataclass
class EarlyStoppingTrainingData:
    """Container that keeps the model-training arrays and the matching arm
    names side by side.

    Attributes:
        X: Array of training inputs.
        Y: Array of training outcomes.
        Yvar: Array that accompanies ``Y`` (presumably the observation
            variances -- confirm against the code that builds it).
        arm_names: Arm names associated with the arrays; an entry may be
            ``None``.
    """

    X: np.ndarray
    Y: np.ndarray
    Yvar: np.ndarray
    arm_names: List[Optional[str]]
class BaseEarlyStoppingStrategy(ABC, Base):
    """Interface for heuristics that halt trials early, typically based on early
    results from that trial."""

    def __init__(
        self,
        seconds_between_polls: int = 60,
        true_objective_metric_name: Optional[str] = None,
    ) -> None:
        """A BaseEarlyStoppingStrategy class.

        Args:
            seconds_between_polls: How often to poll the early stopping metric to
                evaluate whether or not the trial should be early stopped.
            true_objective_metric_name: The actual objective to be optimized; used
                in situations where early stopping uses a proxy objective (such as
                training loss instead of eval loss) for stopping decisions.

        Raises:
            ValueError: If ``seconds_between_polls`` is negative.
        """
        if seconds_between_polls < 0:
            raise ValueError("`seconds_between_polls` may not be less than 0.")
        self.seconds_between_polls = seconds_between_polls
        self.true_objective_metric_name = true_objective_metric_name

    @abstractmethod
    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Decide whether to complete trials before evaluation is fully concluded.

        Typical examples include stopping a machine learning model's training, or
        halting the gathering of samples before some planned number are collected.

        Args:
            trial_indices: Indices of candidate trials to stop early.
            experiment: Experiment that contains the trials and other contextual
                data.

        Returns:
            A dictionary mapping trial indices that should be early stopped to
            (optional) messages with the associated reason.
        """
        pass  # pragma: nocover

    def _check_validity_and_get_data(self, experiment: Experiment) -> Optional[MapData]:
        """Run validity checks and return the `MapData` used for early stopping.

        Args:
            experiment: Experiment whose data is looked up and validated.

        Returns:
            The experiment's data as `MapData`, or ``None`` (after logging the
            reason at info level) when the data is empty, contains no rows for
            the objective metric, is not `MapData`, or has more than one map key.

        Raises:
            UnsupportedError: If the experiment has no optimization config.
        """
        if experiment.optimization_config is None:
            raise UnsupportedError(  # pragma: no cover
                "Experiment must have an optimization config in order to use an "
                "early stopping strategy."
            )

        optimization_config = not_none(experiment.optimization_config)
        objective_name = optimization_config.objective.metric.name

        data = experiment.lookup_data()
        if data.df.empty:
            logger.info(
                f"{self.__class__.__name__} received empty data. "
                "Not stopping any trials."
            )
            return None
        if objective_name not in set(data.df["metric_name"]):
            logger.info(
                f"{self.__class__.__name__} did not receive data "
                "from the objective metric. Not stopping any trials."
            )
            return None

        if not isinstance(data, MapData):
            logger.info(
                f"{self.__class__.__name__} expects MapData, but the "
                f"data attached to experiment is of type {type(data)}. "
                "Not stopping any trials."
            )
            return None

        data = checked_cast(MapData, data)
        map_keys = data.map_keys
        if len(list(map_keys)) > 1:
            logger.info(
                f"{self.__class__.__name__} expects MapData with a single "
                "map key, but the data attached to the experiment has multiple: "
                f"{data.map_keys}. Not stopping any trials."
            )
            return None
        return data

    @staticmethod
    def _log_and_return_trial_ignored(
        logger: logging.Logger, trial_index: int
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when a trial
        should be ignored.

        Returns:
            ``(False, reason)``.
        """
        logger.info(
            f"Trial {trial_index} should be ignored and not considered "
            "for early stopping."
        )
        return False, "Specified as a trial to be ignored for early stopping."

    @staticmethod
    def _log_and_return_no_data(
        logger: logging.Logger, trial_index: int
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when there is no data.

        Returns:
            ``(False, reason)``.
        """
        logger.info(
            f"There is not yet any data associated with trial {trial_index}. "
            "Not early stopping this trial."
        )
        return False, "No data available to make an early stopping decision."

    @staticmethod
    def _log_and_return_min_progression(
        logger: logging.Logger,
        trial_index: int,
        trial_last_progression: float,
        min_progression: float,
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min progression
        is not yet reached.

        Returns:
            ``(False, reason)``.
        """
        # Shared tail of both messages. The trailing space is kept so that the
        # returned reason is unchanged and the log sentence that follows is
        # separated by exactly one space.
        detail = (
            f"recent progression ({trial_last_progression}) is less than "
            "the specified minimum progression for early stopping "
            f"({min_progression}). "
        )
        logger.info(
            f"Trial {trial_index}'s most {detail}Not early stopping this trial."
        )
        return False, f"Most {detail}"

    @staticmethod
    def _log_and_return_completed_trials(
        logger: logging.Logger, num_completed: int, min_curves: float
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min number of
        completed trials is not yet reached.

        Returns:
            ``(False, reason)``.
        """
        logger.info(
            f"The number of completed trials ({num_completed}) is less than "
            "the minimum number of curves needed for early stopping "
            f"({min_curves}). Not early stopping this trial."
        )
        reason = (
            f"Need {min_curves} completed trials, but only {num_completed} "
            "completed trials so far."
        )
        return False, reason

    @staticmethod
    def _log_and_return_num_trials_with_data(
        logger: logging.Logger,
        trial_index: int,
        trial_last_progression: float,
        num_trials_with_data: int,
        min_curves: float,
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min number of
        trials with data is not yet reached.

        Returns:
            ``(False, reason)``.
        """
        logger.info(
            f"The number of trials with data ({num_trials_with_data}) "
            f"at trial {trial_index}'s last progression ({trial_last_progression}) "
            "is less than the specified minimum number for early stopping "
            f"({min_curves}). Not early stopping this trial."
        )
        reason = (
            f"Number of trials with data ({num_trials_with_data}) at "
            f"last progression ({trial_last_progression}) is less than the "
            f"specified minimum number for early stopping ({min_curves})."
        )
        return False, reason
class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
    """Base class for early stopping strategies that rely on a model.

    Provides a helper that turns `MapData` into training arrays.
    """

    def get_training_data(
        self,
        experiment: Experiment,
        map_data: MapData,
        keep_every_k_per_arm: Optional[int] = None,
    ) -> EarlyStoppingTrainingData:
        """Convert the raw (untransformed) training data into modeling arrays.

        Args:
            experiment: Experiment that contains the data.
            map_data: The MapData from the experiment, as can be obtained via
                `_check_validity_and_get_data`.
            keep_every_k_per_arm: Subsample the learning curve by keeping every
                kth entry. Useful for limiting training data for modeling.
                When ``None``, no subsampling is applied.

        Returns:
            An `EarlyStoppingTrainingData` that contains training data arrays
            X, Y, and Yvar + a list of arm names.
        """
        # Optionally thin out the data before anything else touches it.
        data_for_model = (
            map_data
            if keep_every_k_per_arm is None
            else _subsample_map_data(
                map_data=map_data, keep_every_k_per_arm=keep_every_k_per_arm
            )
        )

        features, outcomes, arm_names = _unpack_observations(
            observations_from_map_data(
                experiment=experiment,
                map_data=data_for_model,
                map_keys_as_parameters=True,
            )
        )

        # Feature columns are the search space parameters followed by the map keys.
        feature_names = [
            *experiment.search_space.parameters.keys(),
            *data_for_model.map_keys,
        ]
        # Only the first objective metric name is used as the outcome.
        objective_name = not_none(
            experiment.optimization_config
        ).objective.metric_names[0]

        X = observation_features_to_array(parameters=feature_names, obsf=features)
        Y, Yvar = observation_data_to_array(
            outcomes=[objective_name], observation_data=outcomes
        )
        return EarlyStoppingTrainingData(X=X, Y=Y, Yvar=Yvar, arm_names=arm_names)
def _subsample_map_data(map_data: MapData, keep_every_k_per_arm: int) -> MapData:
    """Helper function for keeping every kth row for each arm.

    Args:
        map_data: The `MapData` whose `map_df` is subsampled.
        keep_every_k_per_arm: Keep one row out of every ``k`` consecutive rows
            of each arm, starting with the arm's first row. Must be at least 1.

    Returns:
        A new `MapData` built from the retained rows, carrying over the
        original `map_key_infos` and `description`.

    Raises:
        ValueError: If ``keep_every_k_per_arm`` is less than 1.
    """
    # A step of 0 would make the modulo below undefined and silently discard
    # every row, so reject non-positive steps explicitly.
    if keep_every_k_per_arm < 1:
        raise ValueError(
            "`keep_every_k_per_arm` must be a positive integer, "
            f"got {keep_every_k_per_arm}."
        )

    map_df = map_data.map_df
    # Number the rows of each arm 0, 1, 2, ... in their existing order and keep
    # those whose position is a multiple of k.
    position_within_arm = map_df.groupby(["arm_name"]).cumcount()
    keep = (position_within_arm % keep_every_k_per_arm) == 0
    map_df_filtered = map_df[keep]
    return MapData(
        df=map_df_filtered,  # pyre-ignore[6]
        map_key_infos=map_data.map_key_infos,
        description=map_data.description,
    )