Source code for ax.plot.exp_utils
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional
from ax.core import Experiment
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
# pyre-fixme[21]: Could not find name `DataFrame` in `pandas`.
from pandas import DataFrame
def _rename_tuples(input):
if isinstance(input, tuple):
if not input[1]:
return input[0]
else:
return input[1] # "_".join(input)
else:
return input
def _compact_column(df, column):
metrics_name_list = list(df[column].columns)
temp = df[column][metrics_name_list[0]]
del df[column]
df[column] = temp
def exp_to_df(
    exp: Experiment,
    metrics: Optional[List[Metric]] = None,
    key_components: Optional[List[str]] = None,
    **kwargs: Any,
    # pyre-fixme[11]: Annotation `DataFrame` is not defined as a type.
) -> DataFrame:
    """Transforms an experiment to a DataFrame. Only supports Experiment and
    SimpleExperiment.

    Transforms an Experiment into a dataframe with rows keyed by trial_index
    and arm_name, metrics pivoted into one row.

    Args:
        exp: An Experiment that may have pending trials.
        metrics: Override list of metrics to return. Return all metrics if None.
        key_components: fields that combine to make a unique key corresponding
            to rows, similar to the list of fields passed to a GROUP BY.
            Defaults to ['trial_index', 'arm_name']. The final join with the
            arm parameters is made on 'arm_name', so that field is expected
            to be among the key components.
        **kwargs: Custom named arguments, useful for passing complex
            objects from call-site to the `fetch_data` callback.

    Returns:
        DataFrame: A dataframe of inputs and metrics by trial and arm. When
        the fetched data has no rows, the fetched frame is returned as is.

    Raises:
        ValueError: If `exp` is a MultiTypeExperiment.
    """
    key_components = key_components or ["trial_index", "arm_name"]

    # Accept Experiment and SimpleExperiment
    if isinstance(exp, MultiTypeExperiment):
        raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.")

    results = exp.fetch_data(metrics, **kwargs).df
    if len(results.index) == 0:  # Handle empty case
        return results

    # Build one string key per row by concatenating the key components (no
    # separator between the values); it only serves as the pivot index and
    # is dropped again below.
    key_col = "-".join(key_components)
    key_vals = results[key_components[0]].astype("str")
    for key in key_components[1:]:
        key_vals = key_vals + results[key].astype("str")
    # assign() builds a new frame, so the frame handed out by fetch_data()
    # does not silently gain a helper column.
    results = results.assign(**{key_col: key_vals})

    metrics_pivot = results.pivot(
        index=key_col, columns="metric_name", values=["mean"] + key_components
    )
    # The pivot repeats every key component once per metric; keep one copy.
    for key in key_components:
        _compact_column(metrics_pivot, key)

    # One row per arm: its parameters plus its name.
    # pyre-fixme[16]: Module `pandas` has no attribute `DataFrame`.
    inputs = DataFrame(
        [
            dict(arm.parameters, arm_name=name)
            for name, arm in exp.arms_by_name.items()
        ]
    )

    metrics_pivot = metrics_pivot.reset_index(drop=True)
    # Flatten the two-level pivot labels *before* merging: recent pandas
    # refuses to merge frames whose columns have different numbers of levels.
    metrics_pivot.columns = [_rename_tuples(col) for col in metrics_pivot.columns]
    # On a name clash the first occurrence wins, so pivot columns take
    # precedence over equally named arm parameters.
    metrics_pivot = metrics_pivot.loc[:, ~metrics_pivot.columns.duplicated()]
    taken = set(metrics_pivot.columns)
    shadowed = [
        col for col in inputs.columns if col != "arm_name" and col in taken
    ]
    if shadowed:
        inputs = inputs.drop(columns=shadowed)

    return metrics_pivot.merge(inputs, on="arm_name")