Source code for ax.early_stopping.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from logging import Logger
from typing import Dict, List, Optional, Tuple

import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance

logger: Logger = get_logger(__name__)


def align_partial_results(
    df: pd.DataFrame,
    progr_key: str,  # progression key
    metrics: List[str],
    interpolation: str = "slinear",
    do_forward_fill: bool = False,
    # TODO: Allow normalizing progr_key (e.g. subtract min time stamp)
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Helper function to align partial results with heterogeneous index.

    Every (trial, metric) curve is re-indexed onto the union of all progression
    values observed for the selected metrics, and missing entries are filled by
    interpolation so that curves from different trials can be compared at the
    same progressions.

    Args:
        df: The DataFrame containing the raw data (in long format). The columns
            ``trial_index``, ``metric_name``, ``arm_name``, ``mean``, ``sem``
            and ``progr_key`` are read.
        progr_key: The key of the column indexing progression (such as
            the number of training examples, timestamps, etc.).
        metrics: The names of the metrics to consider.
        interpolation: The interpolation method used to fill missing values
            (if applicable). See `pandas.DataFrame.interpolate` for
            available options. Limit area is `inside`.
        do_forward_fill: If True, performs a forward fill after interpolation.
            This is useful for scalarizing learning curves when some data
            is missing. For instance, suppose we obtain a curve for task_1
            for progression in [a, b] and task_2 for progression in [c, d]
            where b < c. Performing the forward fill on task_1 is a possible
            solution.

    Returns:
        A two-tuple ``(dfs_mean, dfs_sem)``. Each maps a metric name to a
        DataFrame indexed by the union of observed progression values, with one
        column per trial index that has data for that metric. ``dfs_sem`` is
        empty when the ``sem`` column is entirely NaN.

    Raises:
        ValueError: If any of ``metrics`` does not appear in
            ``df["metric_name"]``.
    """
    missing_metrics = set(metrics) - set(df["metric_name"])
    if missing_metrics:
        raise ValueError(f"Metrics {missing_metrics} not found in input dataframe")

    # select relevant metrics
    df = df[df["metric_name"].isin(metrics)]

    # log some information about raw data; thanks to the check above every
    # requested metric has at least one row at this point
    for m in metrics:
        df_m = df[df["metric_name"] == m]
        logger.debug(
            f"Metric {m} raw data has observations from "
            f"{df_m[progr_key].min()} to {df_m[progr_key].max()}."
        )

    # drop arm names (assumes 1:1 map between trial indices and arm names)
    df = df.drop("arm_name", axis=1)
    # remove duplicates (same trial, metric, progr_key), which can happen
    # if the same progression is erroneously reported more than once
    df = df.drop_duplicates(
        subset=["trial_index", "metric_name", progr_key], keep="first"
    )
    # set a sorted multi-index over trial, metric, and progression key
    df = df.set_index(["trial_index", "metric_name", progr_key]).sort_index()

    # drop sem if all NaN (assumes presence of sem column)
    has_sem = not df["sem"].isnull().all()
    if not has_sem:
        df = df.drop("sem", axis=1)

    # create the common index that every map result will be re-indexed w.r.t.
    index_union = df.index.levels[2].unique()

    # Align each (trial, metric) combination that is actually present. Looking
    # up the full cross-product of the index levels via `df.loc` would raise a
    # KeyError whenever a trial has no data for one of the metrics.
    means_by_metric: Dict[str, List[pd.Series]] = defaultdict(list)
    sems_by_metric: Dict[str, List[pd.Series]] = defaultdict(list)
    for (tidx, metric), df_group in df.groupby(level=[0, 1]):
        # drop the trial/metric levels and reindex to the common index
        df_ridx = df_group.droplevel([0, 1]).reindex(index_union)
        # interpolate / fill missing results
        # TODO: Allow passing of additional kwargs to `interpolate`
        # TODO: Allow using an arbitrary prediction model for this instead
        try:
            df_interp = df_ridx.interpolate(
                method=interpolation, limit_area="inside"
            )
            if do_forward_fill:
                # do forward fill (with valid observations) to handle instances
                # where one task only has data for early progressions
                df_interp = df_interp.ffill()
        except ValueError as e:
            # e.g. too few points for the requested interpolation method
            df_interp = df_ridx
            logger.info(
                f"Got exception `{e}` during interpolation. "
                "Using uninterpolated values instead."
            )
        # renaming column to trial index, append results
        means_by_metric[metric].append(df_interp["mean"].rename(tidx))
        if has_sem:
            sems_by_metric[metric].append(df_interp["sem"].rename(tidx))

    # combine results into output dataframes (one column per trial)
    dfs_mean = {
        metric: pd.concat(series, axis=1)
        for metric, series in means_by_metric.items()
    }
    dfs_sem = {
        metric: pd.concat(series, axis=1)
        for metric, series in sems_by_metric.items()
    }
    return dfs_mean, dfs_sem


def estimate_early_stopping_savings(
    experiment: Experiment,
    map_key: Optional[str] = None,
) -> float:
    """Estimate resource savings due to early stopping by considering
    COMPLETED and EARLY_STOPPED trials.

    The resource usage of a trial is its final progression (maximum value of
    the map key). Each EARLY_STOPPED trial is assumed to have otherwise run for
    the mean final progression of the COMPLETED trials, so its savings are
    ``max(0, mean_completed_usage - stopped_usage)``. The estimate is:

        resource_savings = total_saved_resources / total_resources_used

    where ``total_resources_used`` sums the final progressions of all trials
    that appear in the experiment's MapData.

    Args:
        experiment: The experiment.
        map_key: The map_key to use when computing resource savings. If None,
            the first map key of the experiment's MapData is used.

    Returns:
        The estimated resource savings as a fraction of total resource usage
        (i.e. 0.11 estimated savings indicates we would expect the experiment to
        have used 11% more resources without early stopping present). Returns
        0.0 if the MapData has no map keys or no resource usage is recorded.
    """
    map_data = assert_is_instance(experiment.lookup_data(), MapData)

    # If no map_key is provided, use some arbitrary map_key in the experiment's MapData
    if map_key is not None:
        step_key = map_key
    elif len(map_data.map_key_infos) > 0:
        step_key = map_data.map_key_infos[0].key
    else:
        return 0.0

    # Get final number of steps of each trial
    trial_resources = (
        map_data.map_df[["trial_index", step_key]]
        .groupby("trial_index")
        .max()
        .reset_index()
    )

    early_stopped_trial_idcs = experiment.trial_indices_by_status[
        TrialStatus.EARLY_STOPPED
    ]
    completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED]

    # Assume that any early stopped trial would have had the mean number of steps of
    # the completed trials (NaN if there are no completed trials, in which case the
    # NaN savings below are skipped by `sum` and the estimate is 0)
    mean_completed_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(completed_trial_idcs)
    ][step_key].mean()

    # Calculate the steps saved per early stopped trial. If savings are estimated to
    # be negative assume no savings
    stopped_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(early_stopped_trial_idcs)
    ][step_key]
    saved_trial_resources = (
        mean_completed_trial_resources - stopped_trial_resources
    ).clip(lower=0)

    # Return the ratio of the total saved resources over the total resources used
    total_resources_used = trial_resources[step_key].sum()
    if total_resources_used == 0:
        # nothing recorded yet; avoid a 0/0 NaN result
        return 0.0
    return saved_trial_resources.sum() / total_resources_used