Source code for ax.plot.parallel_coordinates

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
from typing import Optional

import pandas as pd
from ax.core.experiment import Experiment
from ax.plot.base import AxPlotConfig
from ax.plot.base import AxPlotTypes
from ax.service.utils.report_utils import _get_shortest_unique_suffix_dict
from ax.service.utils.report_utils import exp_to_df
from plotly import express as px, graph_objs as go


def prepare_experiment_for_plotting(
    experiment: Experiment,
    ignored_names: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Strip variables not desired in the final plot and truncate names for readability.

    Args:
        experiment: Experiment containing trials to plot.
        ignored_names: Columns of the experiment data frame (as returned by
            ``exp_to_df``) to exclude from the final plot. Every name passed
            explicitly must be present, otherwise pandas raises ``KeyError``.
            When ``None`` (the default), the columns
            ["generation_method", "trial_status", "arm_name"] are dropped;
            any of those default columns that is absent is simply skipped.

    Returns:
        pd.DataFrame: data frame ready for ingestion by plotly, whose column
        names have been shortened via ``_get_shortest_unique_suffix_dict``.
    """
    df = exp_to_df(experiment)
    if ignored_names is None:
        # The default bookkeeping columns are dropped leniently so that a
        # data frame lacking one of them can still be plotted instead of
        # aborting with a KeyError.
        dropped = df.drop(
            columns=["generation_method", "trial_status", "arm_name"],
            errors="ignore",
        )
    else:
        # Explicitly requested names must exist, so typos surface as KeyError.
        dropped = df.drop(columns=ignored_names)
    renamed = dropped.rename(
        # pyre-fixme[6] Expected `typing.Union[
        # typing.Callable[[Optional[typing.Hashable]], Optional[typing.Hashable]],
        # None, typing.Mapping[Optional[typing.Hashable], typing.Any]]` for 1st
        # parameter `columns` to call `pd.core.frame.DataFrame.rename` but got
        # `typing.Dict[str, str]`.
        columns=_get_shortest_unique_suffix_dict([str(c) for c in dropped.columns])
    )
    return renamed
def plot_parallel_coordinates_plotly(
    experiment: Experiment, ignored_names: Optional[List[str]] = None
) -> go.Figure:
    """Build a plotly parallel coordinates figure from an experiment's trials.

    Args:
        experiment: Experiment whose trials are plotted.
        ignored_names: Columns of the experiment data to leave out of the
            plot; forwarded unchanged to ``prepare_experiment_for_plotting``.
            ``None`` selects that function's default exclusions
            (["generation_method", "trial_status", "arm_name"]).

    Returns:
        go.Figure: parallel coordinates plot of all experiment trials, with
        lines colored by the first column of the prepared data frame.
    """
    plot_df = prepare_experiment_for_plotting(
        experiment=experiment, ignored_names=ignored_names
    )
    # Lines are colored by whichever column comes first after preparation.
    first_column = plot_df.columns[0]
    return px.parallel_coordinates(plot_df, color=first_column)
def plot_parallel_coordinates(
    experiment: Experiment, ignored_names: Optional[List[str]] = None
) -> AxPlotConfig:
    """Plot trials as a parallel coordinates graph wrapped in an ``AxPlotConfig``.

    Args:
        experiment: Experiment whose trials are plotted.
        ignored_names: Columns of the experiment data to leave out of the
            plot; forwarded unchanged to ``plot_parallel_coordinates_plotly``.
            ``None`` selects the default exclusions
            (["generation_method", "trial_status", "arm_name"]).

    Returns:
        AxPlotConfig: parallel coordinates plot of all experiment trials,
        tagged with the ``AxPlotTypes.GENERIC`` plot type.
    """
    figure = plot_parallel_coordinates_plotly(
        experiment=experiment, ignored_names=ignored_names
    )
    config = AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
    return config