# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
import pandas as pd
from ax.core.data import Data
from ax.core.types import TMapTrialEvaluation
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import SortableBase
from ax.utils.common.equality import dataframe_equals
from ax.utils.common.logger import get_logger
logger = get_logger(__name__)
T = TypeVar("T")
[docs]class MapKeyInfo(Generic[T], SortableBase):
"""Helper class storing map keys and auxilary info for use in MapData"""
def __init__(
self,
key: str,
default_value: T,
) -> None:
self._key = key
self._default_value = default_value
def __str__(self) -> str:
return f"MapKeyInfo({self.key}, {self.default_value})"
def __hash__(self) -> int:
return hash((self.key, self.default_value))
def _unique_id(self) -> str:
return str(self.__hash__())
@property
def key(self) -> str:
return self._key
@property
def default_value(self) -> T:
return self._default_value
@property
def value_type(self) -> Type:
return type(self._default_value)
[docs]class MapData(Data):
"""Class storing mapping-like results for an experiment.
Data is stored in a dataframe, and axilary information ((key name,
default value) pairs) are stored in a collection of MapKeyInfo objects.
Mapping-like results occur whenever a metric is reported as a collection
of results, each element corresponding to a tuple of values.
The simplest case is a sequence. For instance a time series is
a mapping from the 1-tuple `(timestamp)` to (mean, sem) results.
Another example: MultiFidelity results. This is a mapping from
`(fidelity_feature_1, ..., fidelity_feature_n)` to (mean, sem) results.
The dataframe is retrieved via the `map_df` property. The data can be stored
to an external store for future use by attaching it to an experiment using
`experiment.attach_data()` (this requires a description to be set.)
"""
DEDUPLICATE_BY_COLUMNS = ["arm_name", "metric_name"]
def __init__(
self,
df: Optional[pd.DataFrame] = None,
map_key_infos: Optional[Iterable[MapKeyInfo]] = None,
description: Optional[str] = None,
) -> None:
if map_key_infos is None and df is not None:
raise ValueError("map_key_infos may be `None` iff `df` is None.")
self._map_key_infos = map_key_infos or []
if df is None: # If df is None create an empty dataframe with appropriate cols
self._map_df = pd.DataFrame(
columns=self.required_columns().union(self.map_keys)
)
else:
columns = set(df.columns)
missing_columns = self.required_columns() - columns
if missing_columns:
raise UnsupportedError(
f"Dataframe must contain required columns {missing_columns}."
)
extra_columns = columns - self.supported_columns(
extra_column_names=self.map_keys
)
if extra_columns:
raise UnsupportedError(
f"Columns {[mki.key for mki in extra_columns]} are not supported."
)
df = df.dropna(axis=0, how="all").reset_index(drop=True)
df = self._safecast_df(df=df, extra_column_types=self.map_key_to_type)
col_order = [
c
for c in self.column_data_types(self.map_key_to_type)
if c in df.columns
]
self._map_df = df[col_order]
self.description = description
self._memo_df = None
def __eq__(self, o: MapData) -> bool:
mkis_match = set(self.map_key_infos) == set(o.map_key_infos)
dfs_match = dataframe_equals(self.map_df, o.map_df)
return mkis_match and dfs_match
@property
def true_df(self):
return self.map_df
@property
def map_key_infos(self) -> Iterable[MapKeyInfo]:
return self._map_key_infos
@property
def map_keys(self) -> Iterable[str]:
return [mki.key for mki in self.map_key_infos]
@property
def map_key_to_type(self) -> Dict[str, Type]:
return {mki.key: mki.value_type for mki in self.map_key_infos}
[docs] @staticmethod
def from_multiple_map_data(
data: Iterable[MapData],
subset_metrics: Optional[Iterable[str]] = None,
) -> MapData:
unique_map_key_infos = []
for mki in (mki for datum in data for mki in datum.map_key_infos):
if any(
mki.key == unique.key and mki.default_value != unique.default_value
for unique in unique_map_key_infos
):
logger.warning(f"MapKeyInfo conflict for {mki.key}, eliding {mki}.")
else:
if not any(mki.key == unique.key for unique in unique_map_key_infos):
# If there is a key conflict but the mkis are equal, silently do
# not add the duplicate.
unique_map_key_infos.append(mki)
df = pd.concat(
[pd.DataFrame(columns=[mki.key for mki in unique_map_key_infos])]
+ [datum.map_df for datum in data]
).fillna(value={mki.key: mki.default_value for mki in unique_map_key_infos})
subset_metrics_mask = df["metric_name"].isin(
subset_metrics if subset_metrics else df["metric_name"]
)
return MapData(df=df[subset_metrics_mask], map_key_infos=unique_map_key_infos)
[docs] @staticmethod
def from_map_evaluations(
evaluations: Dict[str, TMapTrialEvaluation],
trial_index: int,
map_key_infos: Optional[Iterable[MapKeyInfo]] = None,
) -> MapData:
records = [
{
"arm_name": name,
"metric_name": metric_name,
"mean": value[0] if isinstance(value, tuple) else value,
"sem": value[1] if isinstance(value, tuple) else None,
"trial_index": trial_index,
**map_dict,
}
for name, map_dict_and_metrics_list in evaluations.items()
for map_dict, evaluation in map_dict_and_metrics_list
for metric_name, value in evaluation.items()
]
map_keys = {
key
for name, map_dict_and_metrics_list in evaluations.items()
for map_dict, evaluation in map_dict_and_metrics_list
for key in map_dict.keys()
}
map_key_infos = map_key_infos or [
MapKeyInfo(key=key, default_value=0.0) for key in map_keys
]
if {mki.key for mki in map_key_infos} != map_keys:
raise ValueError("Inconsistent map_key sets in evaluations.")
return MapData(df=pd.DataFrame(records), map_key_infos=map_key_infos)
@property
def map_df(self) -> pd.DataFrame:
return self._map_df
@map_df.setter
def map_df(self, df: pd.DataFrame):
raise UnsupportedError(
"MapData's underlying DataFrame is immutable; create a new"
+ " MapData via `__init__` or `from_multiple_data`."
)
[docs] @staticmethod
def from_multiple_data(
data: Iterable[Data],
subset_metrics: Optional[Iterable[str]] = None,
) -> MapData:
"""Downcast instances of Data into instances of MapData with empty
map_key_infos if necessary then combine as usual (filling in empty cells with
default values).
"""
map_datas = [
MapData(df=datum.df, map_key_infos=[])
if not isinstance(datum, MapData)
else datum
for datum in data
]
return MapData.from_multiple_map_data(
data=map_datas, subset_metrics=subset_metrics
)
@property
def df(self) -> pd.DataFrame:
"""Returns a Data shaped DataFrame"""
# If map_keys is empty just return the df
if self._memo_df is not None:
return self._memo_df
if not any(True for _ in self.map_keys):
return self.map_df
self._memo_df = (
self.map_df.sort_values(list(self.map_keys))
.drop_duplicates(MapData.DEDUPLICATE_BY_COLUMNS, keep="last")
.loc[:, ~self.map_df.columns.isin(self.map_keys)]
)
return self._memo_df
[docs] @classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
"""Given a dictionary, extract the properties needed to initialize the metric.
Used for storage.
"""
return {
"map_key_infos": [
MapKeyInfo(d["key"], d["default_value"]) for d in args["map_key_infos"]
]
}