Source code for ax.early_stopping.utils
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from typing import List, Dict, Tuple
import pandas as pd
from ax.utils.common.logger import get_logger
# Module-level logger; `align_partial_results` reports raw-data progression
# ranges and interpolation fallbacks through it (at INFO level).
logger = get_logger(__name__)
def align_partial_results(
    df: pd.DataFrame,
    progr_key: str,  # progression key
    metrics: List[str],
    interpolation: str = "slinear",
    do_forward_fill: bool = False,
    # TODO: Allow normalizing progr_key (e.g. subtract min time stamp)
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Helper function to align partial results with heterogeneous index.

    Every (trial, metric) series is reindexed onto the union of all observed
    values of ``progr_key`` and then interpolated (and optionally
    forward-filled), so that series from different trials can be compared at
    common progressions.

    Args:
        df: The DataFrame containing the raw data (in long format). It must
            contain the columns ``trial_index``, ``arm_name``,
            ``metric_name``, ``mean``, ``sem`` and ``progr_key``.
        progr_key: The key of the column indexing progression (such as
            the number of training examples, timestamps, etc.).
        metrics: The names of the metrics to consider.
        interpolation: The interpolation method used to fill missing values
            (if applicable). See `pandas.DataFrame.interpolate` for
            available options. Limit area is `inside`.
        do_forward_fill: If True, performs a forward fill after interpolation.
            This is useful for scalarizing learning curves when some data
            is missing. For instance, suppose we obtain a curve for task_1
            for progression in [a, b] and task_2 for progression in [c, d]
            where b < c. Performing the forward fill on task_1 is a possible
            solution.

    Returns:
        A two-tuple ``(means, sems)``. Each maps a metric name to a DataFrame
        indexed by the union of all observed ``progr_key`` values, with one
        column per trial index that has data for that metric. ``sems`` is an
        empty dict when the ``sem`` column is entirely NaN.

    Raises:
        ValueError: If any of ``metrics`` has no rows in ``df``.
    """
    missing_metrics = set(metrics) - set(df["metric_name"])
    if missing_metrics:
        raise ValueError(f"Metrics {missing_metrics} not found in input dataframe")
    # select relevant metrics
    df = df[df["metric_name"].isin(metrics)]
    # log some information about raw data
    for m in metrics:
        df_m = df[df["metric_name"] == m]
        if len(df_m) > 0:
            logger.info(
                f"Metric {m} raw data has observations from "
                f"{df_m[progr_key].min()} to {df_m[progr_key].max()}."
            )
        else:
            logger.info(f"No data from metric {m} yet.")
    # drop arm names (assumes 1:1 map between trial indices and arm names)
    df = df.drop("arm_name", axis=1)
    # remove duplicates (same trial, metric, progr_key), which can happen
    # if the same progression is erroneously reported more than once
    df = df.drop_duplicates(
        subset=["trial_index", "metric_name", progr_key], keep="first"
    )
    # set multi-index over trial, metric, and progression key
    df = df.set_index(["trial_index", "metric_name", progr_key])
    # Fully lexsort the index (trial, metric, progression) so the partial
    # `.loc[(trial, metric)]` lookups below do not index past the lexsort
    # depth (pandas PerformanceWarning). Output is unaffected: every slice is
    # reindexed onto `index_union` anyway.
    df = df.sort_index()
    # drop sem if all NaN (assumes presence of sem column)
    has_sem = not df["sem"].isnull().all()
    if not has_sem:
        df = df.drop("sem", axis=1)
    # create the common index that every map result will be re-indexed w.r.t.
    index_union = df.index.levels[2].unique()
    # A trial may have data for only a subset of `metrics`; looking up an
    # unobserved (trial, metric) pair with `.loc` would raise a KeyError.
    observed_pairs = set(
        zip(
            df.index.get_level_values("trial_index"),
            df.index.get_level_values("metric_name"),
        )
    )
    # loop through (trial, metric) combos and align data
    dfs_mean = defaultdict(list)
    dfs_sem = defaultdict(list)
    for tidx in df.index.levels[0]:  # this could be slow if there are many trials
        for metric in df.index.levels[1]:
            if (tidx, metric) not in observed_pairs:
                # no rows for this trial/metric combination; omit the column
                continue
            # grab trial+metric sub-df and reindex to common index
            df_ridx = df.loc[(tidx, metric)].reindex(index_union)
            # interpolate / fill missing results
            # TODO: Allow passing of additional kwargs to `interpolate`
            # TODO: Allow using an arbitrary prediction model for this instead
            try:
                df_interp = df_ridx.interpolate(
                    method=interpolation, limit_area="inside"
                )
                if do_forward_fill:
                    # do forward fill (with valid observations) to handle instances
                    # where one task only has data for early progressions
                    df_interp = df_interp.ffill()
            except ValueError as e:
                df_interp = df_ridx
                logger.info(
                    f"Got exception `{e}` during interpolation. "
                    "Using uninterpolated values instead."
                )
            # renaming column to trial index, append results
            dfs_mean[metric].append(df_interp["mean"].rename(tidx))
            if has_sem:
                dfs_sem[metric].append(df_interp["sem"].rename(tidx))
    # combine results into output dataframes
    dfs_mean = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_mean.items()}
    dfs_sem = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_sem.items()}
    return dfs_mean, dfs_sem