#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Any, Dict, Optional, Set, Tuple
import numpy as np
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.observation import observations_from_map_data
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.modelbridge_utils import (
observation_data_to_array,
observation_features_to_array,
_unpack_observations,
)
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
# Module-level logger, created via Ax's `get_logger` helper and keyed by this
# module's dotted name; used by the early stopping strategies below.
logger = get_logger(__name__)
@dataclass
class EarlyStoppingTrainingData:
    """Dataclass for keeping data arrays related to model training and
    arm names together.

    Attributes:
        X: Array of training inputs.
        Y: Array of training outcomes.
        Yvar: Array accompanying ``Y`` (presumably the observation
            variances -- confirm against the helper that produces it).
        arm_names: One entry per observation with the name of the arm it
            belongs to; entries may be ``None``.
    """

    X: np.ndarray
    Y: np.ndarray
    Yvar: np.ndarray
    arm_names: List[Optional[str]]
class BaseEarlyStoppingStrategy(ABC, Base):
    """Interface for heuristics that halt trials early, typically based on early
    results from that trial."""

    def __init__(
        self,
        seconds_between_polls: int = 60,
        true_objective_metric_name: Optional[str] = None,
    ) -> None:
        """A BaseEarlyStoppingStrategy class.

        Args:
            seconds_between_polls: How often to poll the early stopping metric to
                evaluate whether or not the trial should be early stopped.
            true_objective_metric_name: The actual objective to be optimized; used in
                situations where early stopping uses a proxy objective (such as
                training loss instead of eval loss) for stopping decisions.

        Raises:
            ValueError: If ``seconds_between_polls`` is negative.
        """
        if seconds_between_polls < 0:
            raise ValueError("`seconds_between_polls` may not be less than 0.")
        self.seconds_between_polls = seconds_between_polls
        self.true_objective_metric_name = true_objective_metric_name

    @abstractmethod
    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Decide whether to complete trials before evaluation is fully concluded.

        Typical examples include stopping a machine learning model's training, or
        halting the gathering of samples before some planned number are collected.

        Args:
            trial_indices: Indices of candidate trials to stop early.
            experiment: Experiment that contains the trials and other contextual data.

        Returns:
            A dictionary mapping trial indices that should be early stopped to
            (optional) messages with the associated reason.
        """
        pass  # pragma: nocover

    def _check_validity_and_get_data(self, experiment: Experiment) -> Optional[MapData]:
        """Validity checks and returns the `MapData` used for early stopping.

        Args:
            experiment: Experiment whose attached data should be inspected.

        Returns:
            The experiment's data if it is non-empty, contains the objective
            metric, is a ``MapData`` and has at most one map key; otherwise
            ``None`` (the reason is logged at INFO level).

        Raises:
            UnsupportedError: If the experiment has no optimization config.
        """
        # Read the property once so the `None` check and the use agree.
        optimization_config = experiment.optimization_config
        if optimization_config is None:
            raise UnsupportedError(  # pragma: no cover
                "Experiment must have an optimization config in order to use an "
                "early stopping strategy."
            )
        objective_name = optimization_config.objective.metric.name

        data = experiment.lookup_data()
        if data.df.empty:
            logger.info(
                f"{self.__class__.__name__} received empty data. "
                "Not stopping any trials."
            )
            return None

        if objective_name not in set(data.df["metric_name"]):
            logger.info(
                f"{self.__class__.__name__} did not receive data "
                "from the objective metric. Not stopping any trials."
            )
            return None

        if not isinstance(data, MapData):
            logger.info(
                f"{self.__class__.__name__} expects MapData, but the "
                f"data attached to experiment is of type {type(data)}. "
                "Not stopping any trials."
            )
            return None

        # Runtime no-op after the isinstance check; kept to narrow the static type.
        data = checked_cast(MapData, data)
        map_keys = data.map_keys
        if len(list(map_keys)) > 1:
            logger.info(
                f"{self.__class__.__name__} expects MapData with a single "
                "map key, but the data attached to the experiment has multiple: "
                f"{data.map_keys}. Not stopping any trials."
            )
            return None
        return data

    @staticmethod
    def _log_and_return_trial_ignored(
        logger: logging.Logger, trial_index: int
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when a trial
        should be ignored.

        Returns:
            ``(False, reason)`` -- the trial is not stopped.
        """
        logger.info(
            f"Trial {trial_index} should be ignored and not considered "
            "for early stopping."
        )
        return False, "Specified as a trial to be ignored for early stopping."

    @staticmethod
    def _log_and_return_no_data(
        logger: logging.Logger, trial_index: int
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when there is no data.

        Returns:
            ``(False, reason)`` -- the trial is not stopped.
        """
        logger.info(
            f"There is not yet any data associated with trial {trial_index}. "
            "Not early stopping this trial."
        )
        return False, "No data available to make an early stopping decision."

    @staticmethod
    def _log_and_return_min_progression(
        logger: logging.Logger,
        trial_index: int,
        trial_last_progression: float,
        min_progression: float,
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min progression
        is not yet reached.

        Returns:
            ``(False, reason)`` -- the trial is not stopped.
        """
        reason = (
            f"Most recent progression ({trial_last_progression}) is less than "
            "the specified minimum progression for early stopping "
            f"({min_progression}). "
        )
        # The log line reuses `reason` with a lower-cased first letter. Its
        # trailing space is stripped so the joined sentence has single spacing;
        # the returned `reason` itself is left unchanged for callers.
        logger.info(
            f"Trial {trial_index}'s m{reason[1:].rstrip()} "
            "Not early stopping this trial."
        )
        return False, reason

    @staticmethod
    def _log_and_return_completed_trials(
        logger: logging.Logger, num_completed: int, min_curves: float
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min number of
        completed trials is not yet reached.

        Returns:
            ``(False, reason)`` -- the trial is not stopped.
        """
        logger.info(
            f"The number of completed trials ({num_completed}) is less than "
            "the minimum number of curves needed for early stopping "
            f"({min_curves}). Not early stopping this trial."
        )
        reason = (
            f"Need {min_curves} completed trials, but only {num_completed} "
            "completed trials so far."
        )
        return False, reason

    @staticmethod
    def _log_and_return_num_trials_with_data(
        logger: logging.Logger,
        trial_index: int,
        trial_last_progression: float,
        num_trials_with_data: int,
        min_curves: float,
    ) -> Tuple[bool, str]:
        """Helper function for logging/constructing a reason when min number of
        trials with data is not yet reached.

        Returns:
            ``(False, reason)`` -- the trial is not stopped.
        """
        logger.info(
            f"The number of trials with data ({num_trials_with_data}) "
            f"at trial {trial_index}'s last progression ({trial_last_progression}) "
            "is less than the specified minimum number for early stopping "
            f"({min_curves}). Not early stopping this trial."
        )
        reason = (
            f"Number of trials with data ({num_trials_with_data}) at "
            f"last progression ({trial_last_progression}) is less than the "
            f"specified minimum number for early stopping ({min_curves})."
        )
        return False, reason
class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
    """A base class for model based early stopping strategies. Includes
    a helper function for processing MapData into arrays."""

    def get_training_data(
        self,
        experiment: Experiment,
        map_data: MapData,
        keep_every_k_per_arm: Optional[int] = None,
    ) -> EarlyStoppingTrainingData:
        """Processes the raw (untransformed) training data into arrays for
        use in modeling.

        Args:
            experiment: Experiment that contains the data.
            map_data: The MapData from the experiment, as can be obtained
                via `_check_validity_and_get_data`.
            keep_every_k_per_arm: Subsample the learning curve by keeping every
                kth entry. Useful for limiting training data for modeling.
                If ``None``, no subsampling is performed.

        Returns:
            An `EarlyStoppingTrainingData` that contains training data arrays X, Y,
            and Yvar + a list of arm names.
        """
        if keep_every_k_per_arm is not None:
            map_data = _subsample_map_data(
                map_data=map_data, keep_every_k_per_arm=keep_every_k_per_arm
            )
        observations = observations_from_map_data(
            experiment=experiment, map_data=map_data, map_keys_as_parameters=True
        )
        obs_features, obs_data, arm_names = _unpack_observations(observations)

        # Feature columns: the search space parameters followed by the map keys.
        parameters = list(experiment.search_space.parameters.keys())
        # Only the first objective metric name is used as the modeled outcome.
        outcome = not_none(experiment.optimization_config).objective.metric_names[0]

        X = observation_features_to_array(
            parameters=parameters + list(map_data.map_keys), obsf=obs_features
        )
        Y, Yvar = observation_data_to_array(
            outcomes=[outcome], observation_data=obs_data
        )
        return EarlyStoppingTrainingData(X=X, Y=Y, Yvar=Yvar, arm_names=arm_names)
def _subsample_map_data(map_data: MapData, keep_every_k_per_arm: int) -> MapData:
    """Helper function for keeping every kth row for each arm.

    Args:
        map_data: The MapData whose ``map_df`` should be subsampled.
        keep_every_k_per_arm: Step size; within each arm, rows at positions
            0, k, 2k, ... (in ``map_df`` order) are kept.

    Returns:
        A new MapData built from the filtered rows, reusing the original
        ``map_key_infos`` and ``description``.

    Raises:
        ValueError: If ``keep_every_k_per_arm`` is less than 1.
    """
    # A non-positive step has no meaningful "every kth row" interpretation;
    # fail loudly instead of feeding it into the modulo below.
    if keep_every_k_per_arm < 1:
        raise ValueError(
            "`keep_every_k_per_arm` must be a positive integer, got "
            f"{keep_every_k_per_arm}."
        )
    map_df = map_data.map_df
    # Zero-based position of each row within its arm's group.
    row_position_in_arm = map_df.groupby(["arm_name"]).cumcount()
    keep = (row_position_in_arm % keep_every_k_per_arm) == 0
    map_df_filtered = map_df[keep]
    return MapData(
        df=map_df_filtered,  # pyre-ignore[6]
        map_key_infos=map_data.map_key_infos,
        description=map_data.description,
    )