ax.early_stopping

Strategies

Base Strategies

class ax.early_stopping.strategies.base.BaseEarlyStoppingStrategy(metric_names: Iterable[str] | None = None, seconds_between_polls: int = 300, min_progression: float | None = None, max_progression: float | None = None, min_curves: int | None = None, trial_indices_to_ignore: list[int] | None = None, normalize_progressions: bool = False)

Bases: ABC, Base

Interface for heuristics that halt trials early, typically based on early results from that trial.

estimate_early_stopping_savings(experiment: Experiment, map_key: str | None = None) → float

Estimate early stopping savings using progressions of the MapMetric present on the EarlyStoppingConfig as a proxy for resource usage.

Parameters:
  • experiment – The experiment containing the trials and metrics used to estimate early stopping savings.

  • map_key – The name of the map_key by which to estimate early stopping savings, usually steps. If none is specified, an arbitrary map_key from the experiment’s MapData is used.

Returns:

The estimated resource savings as a fraction of total resource usage (e.g. an estimate of 0.11 means early stopping is expected to have saved about 11% of the resources the experiment would have used without it).

is_eligible(trial_index: int, experiment: Experiment, df: DataFrame, map_key: str) → tuple[bool, str | None]

Perform a series of default checks for a specific trial trial_index and determine whether it is eligible for further stopping logic:

  1. Check for ignored indices based on self.trial_indices_to_ignore

  2. Check that df contains data for the trial trial_index

  3. Check that the trial has reached self.min_progression

  4. Check that the trial hasn’t surpassed self.max_progression

Returns two elements: a boolean indicating if all checks are passed and a str indicating the reason that early stopping is not applied (None if all checks pass).

Parameters:
  • trial_index – The index of the trial to check.

  • experiment – The experiment containing the trial.

  • df – A dataframe containing the time-dependent metrics for the trial. NOTE: df should only contain data with metric_name fields that are associated with the early stopping strategy. This is usually done automatically in _check_validity_and_get_data. is_eligible might otherwise return False even though the trial is eligible, if there are secondary tracking metrics that are in df but shouldn’t be considered in the early stopping decision.

  • map_key – The name of the column containing the progression (e.g. time).

Returns:

A boolean indicating whether the trial is eligible, and an optional string giving the reason for ineligibility.

Return type:

A tuple of two elements
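The four default checks can be pictured in plain Python. This is a simplified sketch, not Ax's implementation: is_eligible_sketch is a hypothetical helper, and it assumes a dict `curves` mapping each trial index to its `(progression, value)` pairs in place of the long-format `df`. Whether Ax treats the progression bounds as strict is also an assumption here.

```python
from __future__ import annotations


def is_eligible_sketch(
    trial_index: int,
    curves: dict[int, list[tuple[float, float]]],
    trial_indices_to_ignore: list[int] | None = None,
    min_progression: float | None = None,
    max_progression: float | None = None,
) -> tuple[bool, str | None]:
    """Hypothetical stand-in for is_eligible.

    `curves` (an assumption) maps a trial index to its (progression, value)
    pairs and replaces the DataFrame used by the real method.
    """
    # 1. Trials listed in trial_indices_to_ignore are never eligible.
    if trial_indices_to_ignore is not None and trial_index in trial_indices_to_ignore:
        return False, "Trial is in trial_indices_to_ignore."
    # 2. The trial must have at least one data point.
    points = curves.get(trial_index, [])
    if not points:
        return False, "No data available for this trial."
    latest = max(progression for progression, _ in points)
    # 3. The trial must have reached min_progression.
    if min_progression is not None and latest < min_progression:
        return False, f"Latest progression {latest} is below min_progression {min_progression}."
    # 4. The trial must not have passed max_progression.
    if max_progression is not None and latest > max_progression:
        return False, f"Latest progression {latest} is above max_progression {max_progression}."
    return True, None
```

For example, `is_eligible_sketch(0, {0: [(5, 1.0), (12, 0.8)]}, min_progression=10)` returns `(True, None)`, because the latest progression (12) has passed the minimum.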

is_eligible_any(trial_indices: set[int], experiment: Experiment, df: DataFrame, map_key: str | None = None) → bool

Perform a series of default checks for a set of trials trial_indices and determine if at least one of them is eligible for further stopping logic:

  1. Check that at least self.min_curves trials have completed

  2. Check that at least one trial has reached self.min_progression

Returns a boolean indicating if all checks are passed.

This is useful because, if no trial is eligible for stopping, costly steps that occur before individual trials are considered for stopping, such as model fitting, can be skipped.
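A minimal sketch of these two gating checks, using hypothetical inputs: `curves` maps trial indices to `(progression, value)` pairs and `num_completed_trials` is passed in directly rather than read from an Experiment. It also assumes the min_progression check looks only at the candidate trials, which may differ from Ax.

```python
from __future__ import annotations


def is_eligible_any_sketch(
    trial_indices: set[int],
    curves: dict[int, list[tuple[float, float]]],
    num_completed_trials: int,
    min_curves: int | None = None,
    min_progression: float | None = None,
) -> bool:
    """Hypothetical stand-in for is_eligible_any."""
    # 1. Require at least min_curves completed trials to compare against.
    if min_curves is not None and num_completed_trials < min_curves:
        return False
    # 2. Require at least one candidate trial to have reached min_progression
    #    (assumption: only the candidate trials are inspected).
    if min_progression is None:
        return True
    return any(
        max((p for p, _ in curves.get(i, [])), default=float("-inf")) >= min_progression
        for i in trial_indices
    )
```

Returning early on the cheap checks is what lets a caller skip expensive work such as model fitting.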

abstract should_stop_trials_early(trial_indices: set[int], experiment: Experiment) → dict[int, str | None]

Decide whether to complete trials before evaluation is fully concluded.

Typical examples include stopping a machine learning model’s training, or halting the gathering of samples before the planned number has been collected.

Parameters:
  • trial_indices – Indices of candidate trials to stop early.

  • experiment – Experiment that contains the trials and other contextual data.

Returns:

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason.
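The return contract (only the trials to stop appear as keys, each mapped to an optional reason) can be illustrated with a self-contained stand-in. EarlyStoppingStrategySketch and CeilingStrategy are hypothetical classes, not part of Ax, and a dict of per-trial loss curves is substituted for the Experiment argument.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class EarlyStoppingStrategySketch(ABC):
    """Hypothetical stand-in mirroring the abstract method's contract."""

    @abstractmethod
    def should_stop_trials_early(
        self, trial_indices: set[int], curves: dict[int, list[float]]
    ) -> dict[int, str | None]:
        """Return {trial_index: reason} for the trials that should stop."""


class CeilingStrategy(EarlyStoppingStrategySketch):
    """Toy rule: stop a trial whose latest loss is above `ceiling`."""

    def __init__(self, ceiling: float) -> None:
        self.ceiling = ceiling

    def should_stop_trials_early(self, trial_indices, curves):
        decisions: dict[int, str | None] = {}
        for index in sorted(trial_indices):
            curve = curves.get(index)
            # Trials without data are simply left out of the result.
            if curve and curve[-1] > self.ceiling:
                decisions[index] = f"Latest loss {curve[-1]} exceeds {self.ceiling}."
        return decisions
```

Trials that should keep running are omitted rather than mapped to a "continue" value, so an empty dict means no changes.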

class ax.early_stopping.strategies.base.EarlyStoppingTrainingData(X: ndarray[Any, dtype[_ScalarType_co]], Y: ndarray[Any, dtype[_ScalarType_co]], Yvar: ndarray[Any, dtype[_ScalarType_co]], arm_names: list[str | None])

Bases: object

Dataclass for keeping data arrays related to model training and arm names together.

Parameters:
  • X – An n x d’ array of training features, where d’ = d + m, d is the dimension of the design space, and m is the number of map keys. For learning curves, m = 1, since the number of steps is the only map key.

  • Y – An n x 1 array of training observations.

  • Yvar – An n x 1 array of observed measurement noise (variances).

  • arm_names – A list of length n of arm names. Useful for understanding which data come from the same arm.

X: ndarray[Any, dtype[_ScalarType_co]]
Y: ndarray[Any, dtype[_ScalarType_co]]
Yvar: ndarray[Any, dtype[_ScalarType_co]]
arm_names: list[str | None]
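The layout of X can be illustrated with plain lists. Each observation of an arm at some step becomes one row: the arm's d parameter values followed by the m = 1 map-key value. TrainingDataSketch and build_training_data are hypothetical names, the input format is an assumption, and Yvar is assumed to be the squared standard error of each observation.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TrainingDataSketch:
    """Hypothetical list-based mirror of EarlyStoppingTrainingData."""

    X: list[list[float]]  # n rows, each of length d + m
    Y: list[list[float]]  # n x 1 observations
    Yvar: list[list[float]]  # n x 1 noise variances
    arm_names: list[str | None]  # length n


def build_training_data(curves):
    """curves maps arm name -> (parameter vector, [(step, mean, sem), ...])."""
    X, Y, Yvar, arm_names = [], [], [], []
    for arm_name, (params, points) in curves.items():
        for step, mean, sem in points:
            X.append(list(params) + [float(step)])  # map key is the trailing column
            Y.append([mean])
            Yvar.append([sem**2])  # assumption: variance = sem squared
            arm_names.append(arm_name)
    return TrainingDataSketch(X, Y, Yvar, arm_names)
```

Repeating the arm name on every row is what makes it possible to tell which rows belong to the same learning curve.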
class ax.early_stopping.strategies.base.ModelBasedEarlyStoppingStrategy(metric_names: Iterable[str] | None = None, seconds_between_polls: int = 300, min_progression: float | None = None, max_progression: float | None = None, min_curves: int | None = None, trial_indices_to_ignore: list[int] | None = None, normalize_progressions: bool = False, min_progression_modeling: float | None = None)

Bases: BaseEarlyStoppingStrategy

A base class for model based early stopping strategies. Includes a helper function for processing MapData into arrays.

get_training_data(experiment: Experiment, map_data: MapData, max_training_size: int | None = None, outcomes: Sequence[str] | None = None, parameters: list[str] | None = None) → EarlyStoppingTrainingData

Processes the raw (untransformed) training data into arrays for use in modeling. The trailing dimensions of X are the map keys, in their originally specified order from map_data.

Parameters:
  • experiment – Experiment that contains the data.

  • map_data – The MapData from the experiment, as can be obtained via _check_validity_and_get_data.

  • max_training_size – Subsample the learning curve to keep the total number of data points less than this threshold. Passed to MapData’s subsample method as limit_rows_per_metric.

Returns:

An EarlyStoppingTrainingData that contains the training data arrays X, Y, and Yvar, plus a list of arm names.
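The max_training_size cap can be pictured with a simple stride-based subsampler: keep every k-th point of each curve, using the smallest k for which the total number of kept points fits the limit. subsample_curves is a hypothetical helper, and MapData's actual subsampling rules may differ from this illustration.

```python
from __future__ import annotations


def subsample_curves(curves: dict[str, list[float]], limit: int) -> dict[str, list[float]]:
    """Keep every k-th point of each curve, for the smallest stride k that
    brings the total number of kept points to at most `limit` (hypothetical)."""
    longest = max((len(c) for c in curves.values()), default=0)
    for keep_every in range(1, longest + 1):
        kept = {name: curve[::keep_every] for name, curve in curves.items()}
        if sum(len(c) for c in kept.values()) <= limit:
            return kept
    # Even one point per curve exceeds the limit: keep only the first points.
    return {name: curve[:1] for name, curve in curves.items()}
```

Using one stride for every curve keeps the relative density of long and short curves comparable.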

ax.early_stopping.strategies.base.get_transform_helper_model(experiment: Experiment, data: Data, transforms: list[type[Transform]] | None = None) → MapTorchModelBridge

Constructs a TorchModelBridge, to be used as a helper for transforming parameters. We perform the default Cont_X_trans for parameters but do not perform any transforms on the observations.

Parameters:
  • experiment – Experiment.

  • data – Data for fitting the model.

Returns: A torch modelbridge.
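The split between transformed parameters and untouched observations can be sketched as follows. This assumes that, for a continuous range parameter, the relevant step in Cont_X_trans is a UnitX-style min-max rescaling to [0, 1], and it ignores any other steps in that list; unit_cube_transform is a hypothetical helper.

```python
from __future__ import annotations


def unit_cube_transform(
    x: list[float], y: float, bounds: list[tuple[float, float]]
) -> tuple[list[float], float]:
    """Rescale each parameter to [0, 1] (assumed UnitX-style step);
    the observation y is returned unchanged."""
    x_unit = [(xi - lo) / (hi - lo) for xi, (lo, hi) in zip(x, bounds)]
    return x_unit, y  # no outcome transform, as described above
```

For example, `unit_cube_transform([0.5, 20.0], 3.2, [(0.0, 1.0), (10.0, 30.0)])` gives `([0.5, 0.5], 3.2)`.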

Logical Strategies

class ax.early_stopping.strategies.logical.AndEarlyStoppingStrategy(left: BaseEarlyStoppingStrategy, right: BaseEarlyStoppingStrategy, seconds_between_polls: int = 300)

Bases: LogicalEarlyStoppingStrategy

should_stop_trials_early(trial_indices: set[int], experiment: Experiment, **kwargs: dict[str, Any]) → dict[int, str | None]

Decide whether to complete trials before evaluation is fully concluded.

Typical examples include stopping a machine learning model’s training, or halting the gathering of samples before the planned number has been collected.

Parameters:
  • trial_indices – Indices of candidate trials to stop early.

  • experiment – Experiment that contains the trials and other contextual data.

Returns:

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason.
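AND semantics can be sketched on the decision dicts that the left and right strategies return: a trial is stopped only if both recommend it. and_decisions and merge_reasons are hypothetical helpers, and joining the two reasons with "; " is an assumption about how messages might be combined.

```python
from __future__ import annotations


def merge_reasons(*reasons: str | None) -> str | None:
    """Join the non-empty reasons (assumed format, not necessarily Ax's)."""
    parts = [reason for reason in reasons if reason]
    return "; ".join(parts) if parts else None


def and_decisions(
    left: dict[int, str | None], right: dict[int, str | None]
) -> dict[int, str | None]:
    """Stop a trial only if both strategies recommend stopping it."""
    return {
        index: merge_reasons(left[index], right[index])
        for index in left
        if index in right
    }
```

Because both sides must agree, AND-combining strategies can only make stopping more conservative.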

class ax.early_stopping.strategies.logical.LogicalEarlyStoppingStrategy(left: BaseEarlyStoppingStrategy, right: BaseEarlyStoppingStrategy, seconds_between_polls: int = 300)

Bases: BaseEarlyStoppingStrategy

class ax.early_stopping.strategies.logical.OrEarlyStoppingStrategy(left: BaseEarlyStoppingStrategy, right: BaseEarlyStoppingStrategy, seconds_between_polls: int = 300)

Bases: LogicalEarlyStoppingStrategy

classmethod from_early_stopping_strategies(strategies: Sequence[BaseEarlyStoppingStrategy]) → BaseEarlyStoppingStrategy

should_stop_trials_early(trial_indices: set[int], experiment: Experiment, **kwargs: dict[str, Any]) → dict[int, str | None]

Decide whether to complete trials before evaluation is fully concluded.

Typical examples include stopping a machine learning model’s training, or halting the gathering of samples before the planned number has been collected.

Parameters:
  • trial_indices – Indices of candidate trials to stop early.

  • experiment – Experiment that contains the trials and other contextual data.

Returns:

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason.
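OR semantics, where a trial is stopped if either strategy recommends it, can be sketched on decision dicts in the same way. The name from_early_stopping_strategies suggests combining a sequence of strategies into one; a left fold of or_decisions over several decision dicts is one way to picture that. All helpers here are hypothetical, and the reason-joining format is an assumption.

```python
from __future__ import annotations

from functools import reduce


def merge_reasons(*reasons: str | None) -> str | None:
    """Join the non-empty reasons (assumed format, not necessarily Ax's)."""
    parts = [reason for reason in reasons if reason]
    return "; ".join(parts) if parts else None


def or_decisions(
    left: dict[int, str | None], right: dict[int, str | None]
) -> dict[int, str | None]:
    """Stop a trial if either strategy recommends stopping it."""
    return {
        index: merge_reasons(left.get(index), right.get(index))
        for index in left.keys() | right.keys()
    }


def any_of(decisions: list[dict[int, str | None]]) -> dict[int, str | None]:
    """Fold a list of decision dicts with OR; an empty list yields {}."""
    return reduce(or_decisions, decisions, {})
```

Starting the fold from an empty dict makes the empty sequence a harmless no-op rather than an error.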

PercentileEarlyStoppingStrategy

class ax.early_stopping.strategies.percentile.PercentileEarlyStoppingStrategy(metric_names: Iterable[str] | None = None, seconds_between_polls: int = 300, percentile_threshold: float = 50.0, min_progression: float | None = 10, max_progression: float | None = None, min_curves: int | None = 5, trial_indices_to_ignore: list[int] | None = None, normalize_progressions: bool = False)

Bases: BaseEarlyStoppingStrategy

Implements the strategy of stopping a trial if its performance is worse than a given percentile (set by percentile_threshold) of the other trials’ performance at the same step.

should_stop_trials_early(trial_indices: set[int], experiment: Experiment) → dict[int, str | None]

Stop a trial if its performance is in the bottom percentile_threshold percent of the trials at the same step.

Parameters:
  • trial_indices – Indices of candidate trials to consider for early stopping.

  • experiment – Experiment that contains the trials and other contextual data.

Returns:

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason. An empty dictionary means no suggested updates to any trial’s status.
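The percentile rule at a single shared step can be sketched in plain Python. percentile and percentile_stop are hypothetical helpers; the percentile uses linear interpolation between order statistics (numpy.percentile's default rule), and flipping the percentile when minimizing, so that the worst trials sit in the upper tail, is an assumption about how the objective direction is handled.

```python
from __future__ import annotations


def percentile(values: list[float], q: float) -> float:
    """Percentile with linear interpolation between order statistics."""
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0
    lower = int(rank)  # floor, since rank >= 0
    upper = min(lower + 1, len(ordered) - 1)
    return ordered[lower] + (ordered[upper] - ordered[lower]) * (rank - lower)


def percentile_stop(
    values_at_step: dict[int, float],
    percentile_threshold: float = 50.0,
    minimize: bool = True,
) -> dict[int, str | None]:
    """Flag trials whose value at a shared step is in the worst
    percentile_threshold percent (hypothetical helper)."""
    # Assumption: when minimizing, the worst trials are the largest values.
    q = 100.0 - percentile_threshold if minimize else percentile_threshold
    cutoff = percentile(list(values_at_step.values()), q)
    decisions: dict[int, str | None] = {}
    for index, value in values_at_step.items():
        worse = value > cutoff if minimize else value < cutoff
        if worse:
            decisions[index] = f"Value {value} is worse than the cutoff {cutoff}."
    return decisions
```

With the default threshold of 50 the cutoff is the median, so roughly the worse half of the trials at that step is flagged.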

ThresholdEarlyStoppingStrategy

class ax.early_stopping.strategies.threshold.ThresholdEarlyStoppingStrategy(metric_names: Iterable[str] | None = None, seconds_between_polls: int = 300, metric_threshold: float = 0.2, min_progression: float | None = 10, max_progression: float | None = None, min_curves: int | None = 5, trial_indices_to_ignore: list[int] | None = None, normalize_progressions: bool = False)

Bases: BaseEarlyStoppingStrategy

Implements the strategy of stopping a trial if its performance doesn’t reach a pre-specified threshold by a certain progression.

should_stop_trials_early(trial_indices: set[int], experiment: Experiment) → dict[int, str | None]

Stop a trial if its performance doesn’t reach a pre-specified threshold by min_progression.

Parameters:
  • trial_indices – Indices of candidate trials to consider for early stopping.

  • experiment – Experiment that contains the trials and other contextual data.

Returns:

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason. An empty dictionary means no suggested updates to any trial’s status.
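A minimal sketch of the threshold rule, assuming the caller has already kept only trials that reached min_progression and looked up each trial's most recent value (which value Ax compares is an assumption here). threshold_stop is a hypothetical helper, and for minimization a value above metric_threshold is treated as not having reached it.

```python
from __future__ import annotations


def threshold_stop(
    latest_values: dict[int, float],
    metric_threshold: float = 0.2,
    minimize: bool = True,
) -> dict[int, str | None]:
    """Flag trials that have not reached metric_threshold (hypothetical helper).

    Assumes every trial in latest_values has already passed min_progression.
    """
    decisions: dict[int, str | None] = {}
    for index, value in latest_values.items():
        # Minimizing: a value above the threshold has not "reached" it.
        missed = value > metric_threshold if minimize else value < metric_threshold
        if missed:
            decisions[index] = f"Value {value} did not reach threshold {metric_threshold}."
    return decisions
```

Unlike the percentile rule, this decision for one trial does not depend on how the other trials are doing.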

Utils

ax.early_stopping.utils.align_partial_results(df: DataFrame, progr_key: str, metrics: list[str], interpolation: str = 'slinear', do_forward_fill: bool = False) → tuple[dict[str, DataFrame], dict[str, DataFrame]]

Helper function to align partial results that have heterogeneous progression indices.

Parameters:
  • df – The DataFrame containing the raw data (in long format).

  • progr_key – The key of the column indexing progression (such as the number of training examples, timestamps, etc.).

  • metrics – The names of the metrics to consider.

  • interpolation – The interpolation method used to fill missing values (if applicable). See pandas.DataFrame.interpolate for available options. Interpolation is applied with limit_area='inside', so only missing values surrounded by valid values are filled.

  • do_forward_fill – If True, performs a forward fill after interpolation. This is useful for scalarizing learning curves when some data is missing. For instance, suppose we obtain a curve for task_1 for progressions in [a, b] and a curve for task_2 for progressions in [c, d], where b < c. Forward-filling task_1 is one possible solution.

Returns:

A two-tuple of dicts, each mapping the provided metric names to index-normalized and interpolated DataFrames: the first holds the means and the second the SEMs.
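The alignment idea can be sketched without pandas for one metric's means. align_curves is a hypothetical helper: it puts every curve on the union of all observed progressions, fills gaps strictly inside a curve by plain piecewise-linear interpolation (standing in for the 'slinear' first-order spline with limit_area='inside'), and omits the SEM dict.

```python
from __future__ import annotations

import math


def align_curves(
    curves: dict[str, list[tuple[float, float]]], do_forward_fill: bool = False
) -> dict[str, list[float]]:
    """Put every curve on the union of all observed progressions (hypothetical).

    Values outside a curve's observed range stay NaN unless do_forward_fill
    carries the last observed value forward.
    """
    grid = sorted({p for points in curves.values() for p, _ in points})
    aligned: dict[str, list[float]] = {}
    for name, points in curves.items():
        pts = sorted(points)
        row = []
        for x in grid:
            value = pts[0][1] if len(pts) == 1 and x == pts[0][0] else math.nan
            # Find the observed segment [x0, x1] that contains x, if any.
            for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
                if x0 <= x <= x1:
                    value = y0 if x1 == x0 else y0 + (y1 - y0) * (x - x0) / (x1 - x0)
                    break
            row.append(value)
        if do_forward_fill:
            last = math.nan
            for i, v in enumerate(row):
                if math.isnan(v):
                    row[i] = last
                else:
                    last = v
        aligned[name] = row
    return aligned
```

For `{"a": [(0, 1.0), (2, 3.0)], "b": [(1, 5.0)]}`, curve "a" becomes `[1.0, 2.0, 3.0]` on the grid `[0, 1, 2]`, while "b" keeps NaN outside progression 1 unless forward filling is requested.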

ax.early_stopping.utils.estimate_early_stopping_savings(experiment: Experiment, map_key: str | None = None) → float

Estimate resource savings due to early stopping by considering COMPLETED and EARLY_STOPPED trials. First, the mean of the final progressions of the completed trials is used as a benchmark for the length of a single trial. The savings are then estimated as:

resource_savings = 1 - actual_resource_usage / (num_trials * length_of_single_trial)

Parameters:
  • experiment – The experiment.

  • map_key – The map_key to use when computing resource savings.

Returns:

The estimated resource savings as a fraction of total resource usage (e.g. an estimate of 0.11 means early stopping is expected to have saved about 11% of the resources the experiment would have used without it).
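The savings formula can be worked through with plain numbers. estimate_savings is a hypothetical helper that takes the final progressions of completed and early-stopped trials directly instead of reading them from an Experiment; returning 0.0 when there are no completed trials (and hence no benchmark length) is an assumption.

```python
from __future__ import annotations


def estimate_savings(completed_final: list[float], stopped_final: list[float]) -> float:
    """resource_savings = 1 - actual usage / (num_trials * mean completed length)."""
    if not completed_final:
        return 0.0  # assumption: no benchmark length without completed trials
    single_trial_length = sum(completed_final) / len(completed_final)
    num_trials = len(completed_final) + len(stopped_final)
    actual_usage = sum(completed_final) + sum(stopped_final)
    return 1.0 - actual_usage / (num_trials * single_trial_length)
```

For three completed trials that ran 100 steps each and two trials stopped at 20 and 40 steps, the benchmark is 5 × 100 = 500 steps against 360 actually used, giving savings of 1 - 360/500 = 0.28.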