Source code for ax.core.abstract_data

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type

import numpy as np
import pandas as pd
from ax.utils.common.base import Base
from ax.utils.common.serialization import serialize_init_args, extract_init_args

if TYPE_CHECKING:
    from ax.core.observation import ObservationData

class AbstractData(ABC, Base):
    """Abstract base class for the data gathered for an experiment.

    Concrete subclasses choose the storage format; this base only keeps an
    optional human-readable ``description``.
    """

    def __init__(
        self,
        description: Optional[str] = None,
    ) -> None:
        """Initialize the data container.

        Args:
            description: Optional human-readable description of the data.
        """
        self.description = description

    @staticmethod
    @abstractmethod
    def from_multiple_data(
        data: Iterable[AbstractData], subset_metrics: Optional[Iterable[str]] = None
    ) -> AbstractData:
        """Merge several data objects into a single one whose underlying
        dataframe is the concatenation of theirs.

        NOTE: when any object in ``data`` is an instance of a custom subclass
        of this Data type, the result is an instance of that subclass. Mixing
        incompatible ``Data`` types in ``data`` is an error.

        Args:
            data: Data objects to merge.
            subset_metrics: Optional metric names; when provided, only these
                metrics are kept in the merged result.
        """
        ...  # pragma: no cover

    @property
    @abstractmethod
    def metric_names(self) -> Set[str]:
        """Names of the metrics present in this data."""
        ...  # pragma: no cover

    @abstractmethod
    def to_observation_data(self) -> List["ObservationData"]:
        """Return this data as a list of ``ObservationData``."""
        ...  # pragma: no cover
class AbstractDataFrameData(AbstractData, Base):
    """Abstract Base Class for storing `DataFrame`-backed Data for an experiment.

    Attributes:
        df: DataFrame with underlying data, and required columns.
        description: Human-readable description of data.
    """

    # Columns every dataframe must contain; empty on this abstract base.
    # This must be a *set* (``{}`` would be an empty dict): ``supported_columns``
    # unions it with other column names and ``required_columns`` returns it as
    # a ``Set[str]``.
    REQUIRED_COLUMNS: Set[str] = set()
    # Target dtype for each known column, used by ``column_data_types`` and
    # ``_safecast_df``.
    COLUMN_DATA_TYPES = {
        "arm_name": str,
        "metric_name": str,
        "mean": np.float64,
        "sem": np.float64,
        "trial_index": np.int64,
    }

    def __init__(
        self,
        df: Optional[pd.DataFrame] = None,
        description: Optional[str] = None,
    ) -> None:
        """Init Data.

        Args:
            df: DataFrame with the underlying data. It is not stored by this
                base class; subclasses are expected to handle it and to define
                the ``_df`` attribute that the ``df`` property reads.
            description: Human-readable description of data.
        """
        super().__init__(description=description)

    @classmethod
    def _safecast_df(
        cls, df: pd.DataFrame, extra_column_types: Optional[Dict[str, Type]] = None
    ) -> pd.DataFrame:
        """Function for safely casting df to standard data types.

        Needed because numpy does not support NaNs in integer arrays.

        Allows `Any` to be specified as a type, and will skip casting for that column.

        Args:
            df: DataFrame to safe-cast.
            extra_column_types: types of columns only specified at instantiation-time.

        Returns:
            safe_df: DataFrame cast to standard dtypes.
        """
        extra_column_types = extra_column_types or {}
        dtype = {
            # Pandas timestamp handling is weird; cast via the numpy dtype name.
            col: "datetime64[ns]" if coltype is pd.Timestamp else coltype
            for col, coltype in cls.column_data_types(
                extra_column_types=extra_column_types
            ).items()
            # Only cast columns that are actually present.
            if col in df.columns.values
            # numpy integer arrays cannot hold NaN, so int64 columns that
            # contain missing values are left uncast.
            and not (coltype is np.int64 and df.loc[:, col].isnull().any())
            # `Any` means "do not cast this column".
            and not (coltype is Any)
        }
        # pyre-fixme[7]: Expected `DataFrame` but got
        #  `Union[pd.core.frame.DataFrame, pd.core.series.Series]`.
        return df.astype(dtype=dtype)

    @classmethod
    def required_columns(cls) -> Set[str]:
        """Names of required columns.

        Returns a new set, so callers may modify it without altering the
        class-level ``REQUIRED_COLUMNS``.
        """
        return set(cls.REQUIRED_COLUMNS)

    @classmethod
    def supported_columns(
        cls, extra_column_names: Optional[Iterable[str]] = None
    ) -> Set[str]:
        """Names of columns supported (but not necessarily required) by this class.

        Args:
            extra_column_names: Additional column names to treat as supported
                (e.g. columns only specified at instantiation-time).

        Returns:
            The union of ``REQUIRED_COLUMNS``, the keys of
            ``column_data_types`` and ``extra_column_names``.
        """
        # Extra columns have no declared dtype, so they are typed as `Any`.
        extra_column_types: Dict[str, Any] = {
            name: Any for name in set(extra_column_names or [])
        }
        # Iterating the dict returned by column_data_types yields its keys.
        return set(cls.REQUIRED_COLUMNS).union(
            cls.column_data_types(extra_column_types=extra_column_types)
        )

    @classmethod
    def column_data_types(
        cls, extra_column_types: Optional[Dict[str, Type]] = None
    ) -> Dict[str, Type]:
        """Type specification for all supported columns.

        Args:
            extra_column_types: Additional column types; these override
                entries of ``COLUMN_DATA_TYPES`` with the same name.

        Returns:
            A new dict mapping column name to type.
        """
        extra_column_types = extra_column_types or {}
        return {**cls.COLUMN_DATA_TYPES, **extra_column_types}

    @property
    def df(self) -> pd.DataFrame:
        """Return a flattened `DataFrame` representation of this data's metrics."""
        # pyre-ignore [16]: Undefined attribute. _df will be defined in subclasses.
        return self._df

    @property
    def df_hash(self) -> str:
        """Compute hash of pandas DataFrame.

        This first serializes the DataFrame to JSON and computes the md5 hash
        of the resulting string. Note that this may cause performance issues
        for very large DataFrames.

        Returns:
            str: The hex digest of the hash of this object's DataFrame.
        """
        # pyre-fixme[16]: `Optional` has no attribute `encode`.
        return md5(self.df.to_json().encode("utf-8")).hexdigest()

    @property
    def metric_names(self) -> Set[str]:
        """Set of metric names that appear in the underlying dataframe of
        this object.
        """
        return set() if self.df.empty else set(self.df["metric_name"].values)

    def get_filtered_results(self, **filters: Any) -> pd.DataFrame:
        """Return filtered subset of data.

        Args:
            filters: Column names and the values they must equal. All filters
                are applied together (logical AND).

        Returns:
            df: The filtered DataFrame, built from a copy of the underlying data.

        Raises:
            ValueError: If a filter names a column that is not in the data.
        """
        df = self.df.copy()
        columns = df.columns
        for colname, value in filters.items():
            if colname not in columns:
                raise ValueError(
                    f"{colname} not in the set of columns: {columns} "
                    f"in this data object of type: {str(type(self))}."
                )
            df = df[df[colname] == value]
        return df

    def to_observation_data(self) -> List["ObservationData"]:
        """Convert to ObservationData; not implemented on this base class."""
        raise NotImplementedError()  # pragma: no cover

    @classmethod
    def serialize_init_args(cls, data: AbstractDataFrameData) -> Dict[str, Any]:
        """Serialize the class-dependent properties needed to initialize this Data.

        Used for storage and to help construct new similar Data. All kwargs
        other than "df" and "description" are considered structural.
        """
        return serialize_init_args(object=data, exclude_fields=["df", "description"])

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Given a dictionary, extract the properties needed to initialize
        this Data class.

        Used for storage.
        """
        return extract_init_args(args=args, class_=cls)

    def copy_structure_with_df(self, df: pd.DataFrame) -> AbstractDataFrameData:
        """Create a new instance of this object's class with the same
        structural properties and ``df`` as its data.

        The structural properties are the init args returned by
        ``serialize_init_args``, i.e. all kwargs other than "df" and
        "description". The description is therefore not copied.
        """
        cls = type(self)
        # pyre-ignore[45]: Cannot insantiate abstract class
        return cls(df=df, **cls.serialize_init_args(self))