#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Metrics that allow retrieving curves of partial results.
Typically used to retrieve partial learning curves of ML training jobs.
"""
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import Tuple, Any, Dict, Iterable, List, Optional, Union
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapKeyInfo, MapData
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
from ax.core.trial import Trial
from ax.utils.common.logger import get_logger
# Module-level logger; used below for debug messages when ids or curves are missing.
logger = get_logger(__name__)
class AbstractCurveMetric(MapMetric, ABC):
    """Metric representing (partial) learning curves of ML model training jobs.

    Subclasses implement `get_ids_from_trials` and `get_curves_from_ids` to
    associate trials with backend runs and to fetch the curves of those runs.
    This base class assembles the fetched curves into `MapData`.
    """

    # Map key under which the index (progression) of each curve is stored.
    MAP_KEY = MapKeyInfo(key="training_rows", default_value=0.0)

    def __init__(
        self,
        name: str,
        curve_name: str,
        lower_is_better: bool = True,
    ) -> None:
        """Inits Metric.

        Args:
            name: The name of the metric.
            curve_name: The name of the learning curve in the training output
                (there may be multiple outputs e.g. for MTML models).
            lower_is_better: If True, lower curve values are considered better.
        """
        super().__init__(name=name, lower_is_better=lower_is_better)
        self.curve_name = curve_name

    @classmethod
    def is_available_while_running(cls) -> bool:
        """Report that data for this metric can be fetched while a trial runs."""
        return True

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
        """Fetch data for one trial.

        Delegates to `fetch_trial_data_multi` with this metric as the only
        metric; `kwargs` are forwarded unchanged.
        """
        return self.fetch_trial_data_multi(trial=trial, metrics=[self], **kwargs)

    @classmethod
    def fetch_trial_data_multi(
        cls, trial: BaseTrial, metrics: Iterable[Metric], **kwargs: Any
    ) -> Data:
        """Fetch multiple metrics data for one trial.

        Delegates to `fetch_experiment_data_multi` on the trial's experiment,
        restricted to this single trial; `kwargs` are forwarded unchanged.
        """
        return cls.fetch_experiment_data_multi(
            experiment=trial.experiment, metrics=metrics, trials=[trial], **kwargs
        )

    @classmethod
    def fetch_experiment_data_multi(
        cls,
        experiment: Experiment,
        metrics: Iterable[Metric],
        trials: Optional[Iterable[BaseTrial]] = None,
        **kwargs: Any,
    ) -> Data:
        """Fetch multiple metrics data for an experiment.

        Args:
            experiment: The experiment whose trials are used if `trials`
                is None.
            metrics: The metrics to fetch. Each one is read for its
                `curve_name` and `name` attributes.
            trials: The trials to fetch data for. If None, all trials of
                the experiment are used. Trials whose status does not
                expect data are skipped.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A `MapData` with one row per (trial, metric, progression value),
            using `mean` for the curve value and NaN for `sem`. The data is
            empty if no ids or no curves could be retrieved.

        Raises:
            RuntimeError: If any trial expecting data is not a (non-batch)
                `Trial`.
        """
        # `metrics` is iterated once per trial below. Materialize it so that a
        # one-shot iterable (e.g. a generator) is not exhausted after the
        # first trial, which would silently drop data for all later trials.
        metrics = list(metrics)
        if trials is None:
            trials = list(experiment.trials.values())
        trials = [trial for trial in trials if trial.status.expecting_data]
        if any(not isinstance(trial, Trial) for trial in trials):
            raise RuntimeError(
                f"Only (non-batch) Trials are supported by {cls.__name__}"
            )
        ids = cls.get_ids_from_trials(trials=trials)
        trials_filtered, ids_filtered = _filter_trials_with_ids(trials, ids)
        if not ids_filtered:
            logger.debug("Could not get ids from trials. Returning empty data.")
            return MapData(map_key_infos=[cls.MAP_KEY])
        all_curve_series = cls.get_curves_from_ids(ids=ids_filtered)
        if all(id_ not in all_curve_series for id_ in ids_filtered):
            logger.debug("Could not get curves from ids. Returning empty data.")
            return MapData(map_key_infos=[cls.MAP_KEY])
        dfs = []
        for trial, id_ in zip(trials_filtered, ids_filtered):
            if id_ not in all_curve_series:
                logger.debug(f"Could not get curve data for id {id_}. Ignoring.")
                continue
            curve_series = all_curve_series[id_]
            for m in metrics:
                if m.curve_name not in curve_series:  # pyre-ignore [16]
                    logger.debug(
                        f"{m.curve_name} not yet present in curves from {id_}. "
                        "Returning without this metric."
                    )
                    continue
                cs = curve_series[m.curve_name].rename("mean")  # pyre-ignore [6]
                # Name the index after the map key *before* resetting it, so
                # the progression column gets the right name even when the
                # backend returns a Series whose index already has a name.
                dfi = cs.rename_axis(cls.MAP_KEY.key).reset_index()  # pyre-ignore [16]
                dfi["trial_index"] = trial.index
                dfi["arm_name"] = trial.arm.name
                dfi["metric_name"] = m.name
                dfi["sem"] = float("nan")
                dfs.append(dfi.drop_duplicates())
        df = pd.concat(dfs, axis=0, ignore_index=True) if dfs else None
        return MapData(df=df, map_key_infos=[cls.MAP_KEY])

    @classmethod
    @abstractmethod
    def get_ids_from_trials(
        cls, trials: Iterable[BaseTrial]
    ) -> List[Optional[Union[int, str]]]:
        """Get backend run ids associated with trials.

        Args:
            trials: The trials for which to retrieve the associated
                ids that can be used to identify the corresponding
                runs on the backend.

        Returns:
            A list of identifiers (ints or strings) corresponding to
            the backend runs associated with the trials, in the same
            order as the `trials` input. An entry of None marks a trial
            for which no id is available; such trials are ignored when
            fetching data.
        """
        ...  # pragma: nocover

    @classmethod
    @abstractmethod
    def get_curves_from_ids(
        cls, ids: Iterable[Union[int, str]]
    ) -> Dict[Union[int, str], Dict[str, pd.Series]]:
        """Get partial result curves from backend ids.

        Args:
            ids: The ids of the backend runs for which to fetch the
                partial result curves.

        Returns:
            A dictionary mapping the backend id to the partial result
            curves, each of which is represented as a mapping from
            the metric name to a pandas Series indexed by the progression
            (which will be mapped to the `MAP_KEY` of the metric class).
            E.g. if `curve_name=loss` and `MAP_KEY=training_rows`, then a
            Series should look like:

                 training_rows (index) | loss
                -----------------------|------
                                   100 | 0.5
                                   200 | 0.2
        """
        ...  # pragma: nocover
def _filter_trials_with_ids(
trials: Iterable[BaseTrial], ids: List[Optional[Union[int, str]]]
) -> Tuple[Iterable[BaseTrial], List[Union[int, str]]]:
trials_filtered, ids_filtered = [], []
for trial, id_ in zip(trials, ids):
if id_ is None:
logger.debug(f"Could not get id for Trial {trial.index}. Ignoring.")
else:
trials_filtered.append(trial)
ids_filtered.append(id_)
return trials_filtered, ids_filtered