# Source code for ax.early_stopping.utils
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from logging import Logger
from typing import Dict, List, Optional, Tuple
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance
# Module-level logger used for the debug/info messages emitted by the helpers
# in this module.
logger: Logger = get_logger(__name__)
def align_partial_results(
    df: pd.DataFrame,
    progr_key: str,  # progression key
    metrics: List[str],
    interpolation: str = "slinear",
    do_forward_fill: bool = False,
    # TODO: Allow normalizing progr_key (e.g. subtract min time stamp)
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Helper function to align partial results with heterogeneous index.

    Every (trial, metric) series is re-indexed onto the union of all observed
    progression values. Missing values strictly inside each series' observed
    range are then interpolated, and optionally forward-filled.

    Args:
        df: The DataFrame containing the raw data (in long format). It must
            contain the columns ``trial_index``, ``metric_name``,
            ``arm_name``, ``mean``, ``sem`` and ``progr_key``.
        progr_key: The key of the column indexing progression (such as
            the number of training examples, timestamps, etc.).
        metrics: The names of the metrics to consider.
        interpolation: The interpolation method used to fill missing values
            (if applicable). See `pandas.DataFrame.interpolate` for
            available options. Limit area is `inside`.
        do_forward_fill: If True, performs a forward fill after interpolation
            (or on the uninterpolated values if interpolation failed).
            This is useful for scalarizing learning curves when some data
            is missing. For instance, suppose we obtain a curve for task_1
            for progression in [a, b] and task_2 for progression in [c, d]
            where b < c. Performing the forward fill on task_1 is a possible
            solution.

    Returns:
        A two-tuple of dicts. Each dict maps a metric name to a DataFrame
        indexed by the common progression index, with one column per trial
        index. The first dict holds the aligned mean and the second the
        aligned sem. The sem dict is empty if every sem value is NaN. A trial
        that has reported no data for a metric has no column in that metric's
        frames.

    Raises:
        ValueError: If any of ``metrics`` does not appear in ``df``.
    """
    missing_metrics = set(metrics) - set(df["metric_name"])
    if missing_metrics:
        raise ValueError(f"Metrics {missing_metrics} not found in input dataframe")
    # select relevant metrics
    df = df[df["metric_name"].isin(metrics)]
    # log some information about raw data; every requested metric is guaranteed
    # to have at least one row here because missing metrics raised above
    for m in metrics:
        progr_m = df.loc[df["metric_name"] == m, progr_key]
        logger.debug(
            f"Metric {m} raw data has observations from "
            f"{progr_m.min()} to {progr_m.max()}."
        )
    # drop arm names (assumes 1:1 map between trial indices and arm names)
    df = df.drop("arm_name", axis=1)
    # remove duplicates (same trial, metric, progr_key), which can happen
    # if the same progression is erroneously reported more than once
    df = df.drop_duplicates(
        subset=["trial_index", "metric_name", progr_key], keep="first"
    )
    # set a sorted multi-index over trial, metric, and progression key
    df = df.set_index(["trial_index", "metric_name", progr_key]).sort_index()
    # drop sem if all NaN (assumes presence of sem column)
    has_sem = not df["sem"].isnull().all()
    if not has_sem:
        df = df.drop("sem", axis=1)
    # create the common index that every map result will be re-indexed w.r.t.
    index_union = df.index.levels[2].unique()
    dfs_mean = defaultdict(list)
    dfs_sem = defaultdict(list)
    # Iterate only over (trial, metric) combinations that actually have data.
    # Looking up every pairing of index levels with `.loc` raises a KeyError
    # as soon as one trial has not (yet) reported one of the metrics.
    for (tidx, metric), df_tm in df.groupby(level=[0, 1], sort=True):
        # strip the trial/metric levels and reindex to the common index
        df_ridx = df_tm.droplevel([0, 1]).reindex(index_union)
        # interpolate missing results
        # TODO: Allow passing of additional kwargs to `interpolate`
        # TODO: Allow using an arbitrary prediction model for this instead
        try:
            df_interp = df_ridx.interpolate(
                method=interpolation, limit_area="inside"
            )
        except ValueError as e:
            df_interp = df_ridx
            logger.info(
                f"Got exception `{e}` during interpolation. "
                "Using uninterpolated values instead."
            )
        if do_forward_fill:
            # Forward fill (with valid observations) to handle instances where
            # one task only has data for early progressions. This is kept
            # outside the `try` so that it also applies when interpolation
            # failed.
            df_interp = df_interp.ffill()
        # rename column to trial index, append results
        dfs_mean[metric].append(df_interp["mean"].rename(tidx))
        if has_sem:
            dfs_sem[metric].append(df_interp["sem"].rename(tidx))
    # combine results into output dataframes (one column per trial)
    dfs_mean = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_mean.items()}
    dfs_sem = {metric: pd.concat(dfs, axis=1) for metric, dfs in dfs_sem.items()}
    return dfs_mean, dfs_sem
def estimate_early_stopping_savings(
    experiment: Experiment,
    map_key: Optional[str] = None,
) -> float:
    """Estimate resource savings due to early stopping.

    Each trial's consumed resources are taken to be its final progression,
    i.e. the maximum value of the map key over that trial's rows in the
    experiment's MapData. Each EARLY_STOPPED trial is assumed to have
    otherwise run for the mean final progression of the COMPLETED trials. Its
    saving is that mean minus its own final progression, clipped at zero. The
    estimate is:

        resource_savings = total_saved_resources / total_resources_used

    where ``total_resources_used`` sums the final progressions of every trial
    present in the MapData.

    Args:
        experiment: The experiment.
        map_key: The map_key to use when computing resource savings. If None,
            the first map key of the experiment's MapData is used.

    Returns:
        The estimated resource savings relative to the resources actually used
        (i.e. 0.11 estimated savings indicates we would expect the experiment
        to have used 11% more resources without early stopping present).
        Returns 0.0 if the MapData has no map key or no resources were used.
    """
    map_data = assert_is_instance(experiment.lookup_data(), MapData)
    # If no map_key is provided, use some arbitrary map_key in the experiment's MapData
    if map_key is not None:
        step_key = map_key
    elif len(map_data.map_key_infos) > 0:
        step_key = map_data.map_key_infos[0].key
    else:
        return 0.0
    # Get final number of steps of each trial
    trial_resources = (
        map_data.map_df[["trial_index", step_key]]
        .groupby("trial_index")
        .max()
        .reset_index()
    )
    early_stopped_trial_idcs = experiment.trial_indices_by_status[
        TrialStatus.EARLY_STOPPED
    ]
    completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED]
    # Assume that any early stopped trial would have had the mean number of steps of
    # the completed trials (NaN if there are no completed trials, which makes every
    # per-trial saving NaN and the summed savings 0)
    mean_completed_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(completed_trial_idcs)
    ][step_key].mean()
    # Calculate the steps saved per early stopped trial. If savings are estimated to be
    # negative assume no savings
    stopped_trial_resources = trial_resources[
        trial_resources["trial_index"].isin(early_stopped_trial_idcs)
    ][step_key]
    saved_trial_resources = (
        mean_completed_trial_resources - stopped_trial_resources
    ).clip(0)
    total_used_resources = trial_resources[step_key].sum()
    # Guard against 0 / 0 (e.g. MapData without rows), which would yield NaN
    if total_used_resources == 0:
        return 0.0
    # Return the ratio of the total saved resources over the total resources used
    return float(saved_trial_resources.sum() / total_used_resources)