Source code for ax.early_stopping.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections import defaultdict
from logging import Logger

import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance

# Module-level logger; the helpers in this file use it for debug/info diagnostics.
logger: Logger = get_logger(__name__)


def align_partial_results(
    df: pd.DataFrame,
    progr_key: str,  # progression key
    metrics: list[str],
    interpolation: str = "slinear",
    do_forward_fill: bool = False,
    # TODO: Allow normalizing progr_key (e.g. subtract min time stamp)
) -> tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame]]:
    """Helper function to align partial results with heterogeneous index

    Args:
        df: The DataFrame containing the raw data (in long format). It is
            expected to have the columns ``trial_index``, ``arm_name``,
            ``metric_name``, ``mean``, ``sem`` and ``progr_key``.
        progr_key: The key of the column indexing progression (such as
            the number of training examples, timestamps, etc.).
        metrics: The names of the metrics to consider.
        interpolation: The interpolation method used to fill missing values
            (if applicable). See `pandas.DataFrame.interpolate` for
            available options. Limit area is `inside`.
        do_forward_fill: If True, performs a forward fill after interpolation.
            This is useful for scalarizing learning curves when some data
            is missing. For instance, suppose we obtain a curve for task_1
            for progression in [a, b] and task_2 for progression in [c, d]
            where b < c. Performing the forward fill on task_1 is a possible
            solution.

    Returns:
        A two-tuple containing a dict mapping the provided metric names to the
        index-normalized and interpolated mean (sem). Each DataFrame is indexed
        by the union of all observed progressions and has one column per trial
        index. The sem dict is empty if the ``sem`` column is entirely NaN.

    Raises:
        ValueError: If any of ``metrics`` does not occur in
            ``df["metric_name"]``.
    """
    missing_metrics = set(metrics) - set(df["metric_name"])
    if missing_metrics:
        raise ValueError(f"Metrics {missing_metrics} not found in input dataframe")
    # select relevant metrics
    df = df[df["metric_name"].isin(metrics)]
    # log some information about raw data
    for m in metrics:
        df_m = df[df["metric_name"] == m]
        if len(df_m) > 0:
            logger.debug(
                f"Metric {m} raw data has observations from "
                f"{df_m[progr_key].min()} to {df_m[progr_key].max()}."
            )
        else:
            # Defensive only: the validation above already guarantees that
            # every requested metric has at least one row.
            logger.info(f"No data from metric {m} yet.")
    # drop arm names (assumes 1:1 map between trial indices and arm names)
    df = df.drop("arm_name", axis=1)
    # remove duplicates (same trial, metric, progr_key), which can happen
    # if the same progression is erroneously reported more than once
    df = df.drop_duplicates(
        subset=["trial_index", "metric_name", progr_key], keep="first"
    )
    # set multi-index over trial, metric, and progression key
    df = df.set_index(["trial_index", "metric_name", progr_key])
    # sort index
    df = df.sort_index()
    # drop sem if all NaN (assumes presence of sem column)
    has_sem = not df["sem"].isnull().all()
    if not has_sem:
        df = df.drop("sem", axis=1)
    # create the common index that every map result will be re-indexed w.r.t.
    index_union = df.index.levels[2].unique()
    # loop through (trial, metric) combos and align data
    dfs_mean = defaultdict(list)
    dfs_sem = defaultdict(list)
    for tidx in df.index.levels[0]:  # this could be slow if there are many trials
        for metric in df.index.levels[1]:
            # grab trial+metric sub-df and reindex to common index
            # NOTE(review): this lookup raises KeyError when a trial has no
            # rows for one of the metrics -- confirm callers guarantee that
            # every trial reports every metric.
            df_ridx = df.loc[(tidx, metric)].reindex(index_union)
            # interpolate / fill missing results
            # TODO: Allow passing of additional kwargs to `interpolate`
            # TODO: Allow using an arbitrary prediction model for this instead
            try:
                df_interp = df_ridx.interpolate(
                    method=interpolation, limit_area="inside"
                )
                if do_forward_fill:
                    # do forward fill (with valid observations) to handle instances
                    # where one task only has data for early progressions.
                    # `ffill()` replaces `fillna(method="pad")`, whose `method`
                    # argument is deprecated in recent pandas releases.
                    df_interp = df_interp.ffill()
            except ValueError as e:
                # Best effort: fall back to the raw (re-indexed) values when
                # the chosen interpolation method cannot be applied.
                df_interp = df_ridx
                logger.info(
                    f"Got exception `{e}` during interpolation. "
                    "Using uninterpolated values instead."
                )
            # renaming column to trial index, append results
            dfs_mean[metric].append(df_interp["mean"].rename(tidx))
            if has_sem:
                dfs_sem[metric].append(df_interp["sem"].rename(tidx))

    # combine results into output dataframes
    dfs_mean = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_mean.items()}
    dfs_sem = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_sem.items()}
    return dfs_mean, dfs_sem
def estimate_early_stopping_savings(
    experiment: Experiment,
    map_key: str | None = None,
) -> float:
    """Estimate resource savings due to early stopping by considering
    COMPLETED and EARLY_STOPPED trials.

    The mean of the final progressions of the COMPLETED trials is used as the
    benchmark length of a single trial. Each EARLY_STOPPED trial is assumed to
    have saved ``benchmark - its own final progression`` (floored at zero).
    The returned value is::

        total_saved_resources / total_resources_actually_used

    where the denominator sums the final progression of every trial present
    in the experiment's map data.

    Args:
        experiment: The experiment.
        map_key: The map_key to use when computing resource savings. If None,
            the first map key found in the experiment's ``MapData`` is used.

    Returns:
        The estimated resource savings as a fraction of total resource usage
        (i.e. 0.11 estimated savings indicates we would expect the experiment
        to have used 11% more resources without early stopping present).
        Returns 0.0 if no map key is available or if no resources were
        recorded at all.
    """
    map_data = assert_is_instance(experiment.lookup_data(), MapData)

    # If no map_key is provided, use some arbitrary map_key in the experiment's MapData
    if map_key is not None:
        step_key = map_key
    elif len(map_data.map_key_infos) > 0:
        step_key = map_data.map_key_infos[0].key
    else:
        return 0.0

    # Get final number of steps of each trial
    trial_resources = (
        map_data.map_df[["trial_index", step_key]]
        .groupby("trial_index")
        .max()
        .reset_index()
    )

    early_stopped_trial_idcs = experiment.trial_indices_by_status[
        TrialStatus.EARLY_STOPPED
    ]
    completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED]

    # Assume that any early stopped trial would have had the mean number of steps of
    # the completed trials
    mean_completed_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(completed_trial_idcs)
    ][step_key].mean()

    # Calculate the steps saved per early stopped trial. If savings are estimated to be
    # negative assume no savings
    stopped_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(early_stopped_trial_idcs)
    ][step_key]
    saved_trial_resources = (
        mean_completed_trial_resources - stopped_trial_resources
    ).clip(0)

    # Guard against an empty / all-zero progression column, which would
    # otherwise turn the ratio below into NaN or a division error.
    total_resources_used = trial_resources[step_key].sum()
    if total_resources_used == 0:
        return 0.0

    # Return the ratio of the total saved resources over the total resources
    # actually used
    return float(saved_trial_resources.sum() / total_resources_used)