Source code for ax.early_stopping.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, List, Tuple

import pandas as pd
from ax.utils.common.logger import get_logger

logger = get_logger(__name__)


def align_partial_results(
    df: pd.DataFrame,
    progr_key: str,  # progression key
    metrics: List[str],
    interpolation: str = "slinear",
    do_forward_fill: bool = False,
    # TODO: Allow normalizing progr_key (e.g. subtract min time stamp)
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Helper function to align partial results with heterogeneous index

    Every (trial, metric) curve is re-indexed onto the union of all observed
    progression values and the gaps between observations are interpolated.

    Args:
        df: The DataFrame containing the raw data (in long format). The columns
            `trial_index`, `arm_name`, `metric_name`, `mean`, `sem` and the
            column named by `progr_key` are accessed and must be present.
        progr_key: The key of the column indexing progression (such as
            the number of training examples, timestamps, etc.).
        metrics: The names of the metrics to consider.
        interpolation: The interpolation method used to fill missing values
            (if applicable). See `pandas.DataFrame.interpolate` for
            available options. Limit area is `inside`.
        do_forward_fill: If True, performs a forward fill after interpolation.
            This is useful for scalarizing learning curves when some data
            is missing. For instance, suppose we obtain a curve for task_1
            for progression in [a, b] and task_2 for progression in [c, d]
            where b < c. Performing the forward fill on task_1 is a possible
            solution.

    Returns:
        A two-tuple containing a dict mapping the provided metric names to the
        index-normalized and interpolated mean (sem). Each DataFrame is indexed
        by the union of progression values and has one column per trial index;
        a trial with no observations for a metric gets an all-NaN column. The
        sem dict is empty if the `sem` column is entirely NaN.

    Raises:
        ValueError: If any of `metrics` does not occur in `df["metric_name"]`.
    """
    missing_metrics = set(metrics) - set(df["metric_name"])
    if missing_metrics:
        raise ValueError(f"Metrics {missing_metrics} not found in input dataframe")
    # select relevant metrics
    df = df[df["metric_name"].isin(metrics)]
    # log some information about raw data
    for m in metrics:
        df_m = df[df["metric_name"] == m]
        if len(df_m) > 0:
            logger.info(
                f"Metric {m} raw data has observations from "
                f"{df_m[progr_key].min()} to {df_m[progr_key].max()}."
            )
        else:
            logger.info(f"No data from metric {m} yet.")
    # drop arm names (assumes 1:1 map between trial indices and arm names)
    df = df.drop("arm_name", axis=1)
    # remove duplicates (same trial, metric, progr_key), which can happen
    # if the same progression is erroneously reported more than once
    df = df.drop_duplicates(
        subset=["trial_index", "metric_name", progr_key], keep="first"
    )
    # set multi-index over trial, metric, and progression key
    df = df.set_index(["trial_index", "metric_name", progr_key])
    # sort index
    df = df.sort_index(level=["trial_index", progr_key])
    # drop sem if all NaN (assumes presence of sem column)
    has_sem = not df["sem"].isnull().all()
    if not has_sem:
        df = df.drop("sem", axis=1)
    # create the common index that every map result will be re-indexed w.r.t.
    index_union = df.index.levels[2].unique()
    # The loops below walk the full cross product of trials and metrics, but a
    # trial need not have reported every metric. Record which combinations
    # actually carry data so that the absent ones do not raise a KeyError.
    observed_pairs = set(
        zip(df.index.get_level_values(0), df.index.get_level_values(1))
    )
    # loop through (trial, metric) combos and align data
    means_by_metric = defaultdict(list)
    sems_by_metric = defaultdict(list)
    for tidx in df.index.levels[0]:  # this could be slow if there are many trials
        for metric in df.index.levels[1]:
            if (tidx, metric) not in observed_pairs:
                # no observations at all: contribute an all-NaN column so that
                # every metric's frame exposes the same set of trial columns
                nan_col = pd.Series(float("nan"), index=index_union, name=tidx)
                means_by_metric[metric].append(nan_col)
                if has_sem:
                    sems_by_metric[metric].append(nan_col.copy())
                continue
            # grab trial+metric sub-df and reindex to common index
            df_ridx = df.loc[(tidx, metric)].reindex(index_union)
            # interpolate / fill missing results
            # TODO: Allow passing of additional kwargs to `interpolate`
            # TODO: Allow using an arbitrary prediction model for this instead
            try:
                df_interp = df_ridx.interpolate(
                    method=interpolation, limit_area="inside"
                )
                if do_forward_fill:
                    # do forward fill (with valid observations) to handle instances
                    # where one task only has data for early progressions;
                    # `ffill` replaces the deprecated `fillna(method="pad")`
                    df_interp = df_interp.ffill()
            except ValueError as e:
                df_interp = df_ridx
                logger.info(
                    f"Got exception `{e}` during interpolation. "
                    "Using uninterpolated values instead."
                )
            # renaming column to trial index, append results
            means_by_metric[metric].append(df_interp["mean"].rename(tidx))
            if has_sem:
                sems_by_metric[metric].append(df_interp["sem"].rename(tidx))

    # combine results into output dataframes
    dfs_mean = {
        metric: pd.concat(cols, axis=1) for metric, cols in means_by_metric.items()
    }
    dfs_sem = {
        metric: pd.concat(cols, axis=1) for metric, cols in sems_by_metric.items()
    }
    return dfs_mean, dfs_sem