#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
import pandas as pd
from ax.core.abstract_data import AbstractDataFrameData
from ax.core.types import TMapTrialEvaluation
[docs]class MapData(AbstractDataFrameData):
"""Class storing mapping-like results for an experiment.
Mapping-like results occur whenever a metric is reported as a collection
of results, each element corresponding to a tuple of values.
The simplest case is a sequence. For instance a time series is
a mapping from the 1-tuple `(timestamp)` to (mean, sem) results.
Another example: MultiFidelity results. This is a mapping from
`(fidelity_feature_1, ..., fidelity_feature_n)` to (mean, sem) results.
The dataframe is retrieved via the `df` property. The data can be stored
to an external store for future use by attaching it to an experiment using
`experiment.attach_data()` (this requires a description to be set.)
"""
# Note: Although the SEM (standard error of the mean) is a required column in data,
# downstream models can infer missing SEMs. Simply specify NaN as the SEM value,
# either in your Metric class or in Data explicitly.
REQUIRED_COLUMNS = {"arm_name", "metric_name", "mean", "sem"}
DEDUPLICATE_BY_COLUMNS = ["arm_name", "metric_name"]
def __init__(
self,
df: Optional[pd.DataFrame] = None,
map_keys: Optional[List[str]] = None,
description: Optional[str] = None,
) -> None:
"""Init `MapData`.
Args:
df: DataFrame with underlying data, and required columns.
map_keys: List of all elements of the Tuple that makes up the
key in MapData.
description: Human-readable description of data.
"""
if map_keys is None and df is not None:
raise ValueError(
"map_keys may only be `None` when `df` is also None "
"(an empty `MapData`)."
)
self._map_keys = map_keys or []
# Represent MapData internally as a flat `DataFrame`
# Make an empty `DataFrame with map_keys if available`
if df is None:
self._df = pd.DataFrame(
columns=self.required_columns().union(self.map_keys)
)
else:
columns = set(df.columns)
missing_columns = self.required_columns() - columns
if missing_columns:
raise ValueError(
f"Dataframe must contain required columns {list(missing_columns)}."
)
extra_columns = columns - self.supported_columns(
extra_column_names=self.map_keys
)
if extra_columns:
raise ValueError(f"Columns {list(extra_columns)} are not supported.")
df = df.dropna(axis=0, how="all").reset_index(drop=True)
df = self._safecast_df(df=df, extra_column_types=self.map_key_types)
# Reorder the columns for easier viewing
col_order = [
c for c in self.column_data_types(self.map_key_types) if c in df.columns
]
self._df = df[col_order]
self.description = description
[docs] @staticmethod
# pyre-ignore [14]: `Iterable[Data]` not a supertype of overridden parameter.
def from_multiple_data(
data: Iterable[MapData], subset_metrics: Optional[Iterable[str]] = None
) -> MapData:
"""Combines multiple data objects into one (with the concatenated
underlying dataframe).
NOTE: if one or more data objects in the iterable is of a custom
subclass of `MapData`, object of that class will be returned. If
the iterable contains multiple types of `Data`, an error will be
raised.
Args:
data: Iterable of Ax `MapData` objects to combine.
subset_metrics: If specified, combined `MapData` will only contain
metrics, names of which appear in this iterable,
in the underlying dataframe.
"""
# Filter out empty dataframes because they may not have correct map_keys.
data = [datum for datum in data if not datum.df.empty]
dfs = [datum.df for datum in data if not datum.df.empty]
if len(dfs) == 0:
return MapData()
if subset_metrics:
dfs = [df.loc[df["metric_name"].isin(subset_metrics)] for df in dfs]
# cast to list
data = list(data)
if not all((type(datum) is MapData) for datum in data):
# check if all types in iterable match the first type
raise ValueError("Non-MapData in inputs.")
# obtain map_keys of first elt in iterable (we know it's not empty)
map_keys = data[0].map_keys
if not all((set(datum.map_keys) == set(map_keys)) for datum in data):
raise ValueError("Inconsistent map_keys found in data iterable.")
else:
# if all validation is passed return concatenated data.
return MapData(df=pd.concat(dfs, axis=0, sort=True), map_keys=map_keys)
@property
def map_keys(self):
"""Return the names of fields that together make a map key.
E.g. ["timestamp"] for a timeseries, ["fidelity_param_1", "fidelity_param_2"]
for a multi-fidelity set of results.
"""
return self._map_keys
@property
def map_key_types(self):
return {map_key: Any for map_key in self.map_keys}
[docs] def update(self, new_data: MapData) -> None:
if not new_data.map_keys == self.map_keys:
raise ValueError("Inconsistent map_keys found in new data.")
self._df = self.df.append(new_data.df)
[docs] @staticmethod
def from_map_evaluations(
evaluations: Dict[str, TMapTrialEvaluation],
trial_index: int,
map_keys: Optional[List[str]] = None,
) -> MapData:
"""
Convert dict of mapped evaluations to an Ax MapData object
Args:
evaluations: Map from arm name to metric outcomes (itself a mapping
of metric names to tuples of mean and optionally a SEM).
trial_index: Trial index to which this data belongs.
map_keys: List of all elements of the Tuple that makes up the
key in MapData.
Returns:
Ax MapData object.
"""
records = [
{
"arm_name": name,
"metric_name": metric_name,
"mean": value[0] if isinstance(value, tuple) else value,
"sem": value[1] if isinstance(value, tuple) else 0.0,
"trial_index": trial_index,
**map_dict,
}
for name, map_dict_and_metrics_list in evaluations.items()
for map_dict, evaluation in map_dict_and_metrics_list
for metric_name, value in evaluation.items()
]
map_keys_list = [
list(map_dict.keys())
for name, map_dict_and_metrics_list in evaluations.items()
for map_dict, evaluation in map_dict_and_metrics_list
]
map_keys = map_keys or map_keys_list[0]
if not all((set(mk) == set(map_keys)) for mk in map_keys_list):
raise ValueError("Inconsistent map_key sets in evaluations.")
return MapData(df=pd.DataFrame(records), map_keys=map_keys)
[docs] def deduplicate_data(self, keep: str = "last") -> MapData:
"""
Deduplicate by arm_name and metric_name, and then
drop the map_keys columns entirely.
Args:
keep: Determines which duplicates (rows that differ by map key) to keep.
- first: Drop duplicates except for the first occurrence.
- last: Drop duplicates except for the last occurrence.
- False: Drop all duplicates.
Returns:
Deduplicated MapData object.
"""
if keep not in {"last", "first", False}:
raise ValueError(
"Invalid value for `keep`: must be one of {'last', 'first', False}."
)
df = self.df
map_keys = self.map_keys
if len(map_keys) > 0:
df = df.sort_values(map_keys).drop_duplicates(
MapData.DEDUPLICATE_BY_COLUMNS, keep=keep # pyre-ignore
)
return MapData(df=df, map_keys=map_keys)