# Source code for ax.analysis.plotly.scatter

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import DataRequiredError, UserInputError
from plotly import express as px, graph_objects as go


class ScatterPlot(PlotlyAnalysis):
    """
    Plotly Scatter plot for any two metrics. Each arm is represented by a single
    point, whose color indicates the arm's trial index. Optionally, the Pareto
    frontier can be shown. This plot is useful for understanding the relationship
    and/or tradeoff between two metrics.

    The DataFrame computed will contain one row per arm and the following columns:
        - trial_index: The trial index of the arm
        - arm_name: The name of the arm
        - X_METRIC_NAME: The observed mean of the metric specified
        - Y_METRIC_NAME: The observed mean of the metric specified
        - is_optimal: Whether the arm is on the Pareto frontier
    """

    def __init__(
        self, x_metric_name: str, y_metric_name: str, show_pareto_frontier: bool = False
    ) -> None:
        """
        Args:
            x_metric_name: The name of the metric to plot on the x-axis.
            y_metric_name: The name of the metric to plot on the y-axis.
            show_pareto_frontier: Whether to show the Pareto frontier for the two
                metrics. Optimization direction is inferred from the Experiment.
        """

        self.x_metric_name = x_metric_name
        self.y_metric_name = y_metric_name
        self.show_pareto_frontier = show_pareto_frontier

    def compute(
        self,
        experiment: Optional[Experiment] = None,
        generation_strategy: Optional[GenerationStrategyInterface] = None,
    ) -> PlotlyAnalysisCard:
        """
        Build the scatter plot card for the two configured metrics.

        Args:
            experiment: The experiment whose attached data is plotted. Required.
            generation_strategy: Accepted for interface compatibility; it is not
                read by this analysis.

        Returns:
            A PlotlyAnalysisCard holding the per-arm DataFrame and the figure.

        Raises:
            UserInputError: If no experiment is provided.
        """
        if experiment is None:
            raise UserInputError("ScatterPlot requires an Experiment")

        df = _prepare_data(
            experiment=experiment,
            x_metric_name=self.x_metric_name,
            y_metric_name=self.y_metric_name,
        )
        fig = _prepare_plot(
            df=df,
            x_metric_name=self.x_metric_name,
            y_metric_name=self.y_metric_name,
            show_pareto_frontier=self.show_pareto_frontier,
            # lower_is_better may be unset (None); treat that as "maximize".
            x_lower_is_better=experiment.metrics[self.x_metric_name].lower_is_better
            or False,
        )

        return self._create_plotly_analysis_card(
            # Use the shared helper so the title format is defined in one place.
            title=_prepare_title(
                x_metric_name=self.x_metric_name,
                y_metric_name=self.y_metric_name,
            ),
            subtitle="Compare arms by their observed metric values",
            level=AnalysisCardLevel.HIGH,
            df=df,
            fig=fig,
        )


def _prepare_title(x_metric_name: str, y_metric_name: str) -> str:
    """
    Prepare a title for scatter plot.

    Args:
        x_metric_name: The name of the metric to plot on the x-axis.
        y_metric_name: The name of the metric to plot on the y-axis.
    """

    return f"Observed {x_metric_name} vs. {y_metric_name}"


def _prepare_data(
    experiment: Experiment, x_metric_name: str, y_metric_name: str
) -> pd.DataFrame:
    """
    Extract the relevant data from the experiment and prepare it into a dataframe
    formatted in the way expected by _prepare_plot.

    Args:
        experiment: The experiment to extract data from.
        x_metric_name: The name of the metric to plot on the x-axis.
        y_metric_name: The name of the metric to plot on the y-axis.

    Returns:
        A DataFrame with one row per (trial_index, arm_name) and the columns
        trial_index, arm_name, one column per metric, and is_optimal.

    Raises:
        DataRequiredError: If no arm has observations for both metrics,
            including the case where one of the metrics has no data at all.
    """
    # Lookup the data that has already been fetched and attached to the experiment
    data = experiment.lookup_data().df

    # Filter for only rows with the relevant metric names
    metric_name_mask = data["metric_name"].isin([x_metric_name, y_metric_name])
    filtered = data[metric_name_mask][
        ["trial_index", "arm_name", "metric_name", "mean"]
    ]

    # Pivot the data so that each row is an arm and the columns are the metric names
    pivoted: pd.DataFrame = filtered.pivot_table(
        index=["trial_index", "arm_name"], columns="metric_name", values="mean"
    ).dropna()

    # A metric with no observations at all produces no column in the pivot, so
    # dropna() has nothing to drop and the frame can be non-empty while still
    # lacking one axis. Detect that here so it is reported as missing data
    # instead of surfacing later as a KeyError.
    has_both_metrics = all(
        name in pivoted.columns for name in (x_metric_name, y_metric_name)
    )

    pivoted.reset_index(inplace=True)
    pivoted.columns.name = None

    if pivoted.empty or not has_both_metrics:
        raise DataRequiredError(
            f"No observations have data for both {x_metric_name} and {y_metric_name}. "
            "Please ensure that the data has been fetched and attached to the "
            "experiment."
        )

    # Add a column indicating whether the arm is on the Pareto frontier. This is
    # calculated by comparing each arm to all other arms in the experiment and
    # creating a mask.
    # If directional guidance is not specified, we assume that we intend to maximize
    # the metric.
    x_lower_is_better: bool = experiment.metrics[x_metric_name].lower_is_better or False
    y_lower_is_better: bool = experiment.metrics[y_metric_name].lower_is_better or False

    def is_optimal(row: pd.Series) -> bool:
        # An arm is kept unless some other arm is strictly better on both metrics.
        x_mask = (
            (pivoted[x_metric_name] < row[x_metric_name])
            if x_lower_is_better
            else (pivoted[x_metric_name] > row[x_metric_name])
        )
        y_mask = (
            (pivoted[y_metric_name] < row[y_metric_name])
            if y_lower_is_better
            else (pivoted[y_metric_name] > row[y_metric_name])
        )
        return not (x_mask & y_mask).any()

    pivoted["is_optimal"] = pivoted.apply(
        is_optimal,
        axis=1,
    )

    return pivoted


def _prepare_plot(
    df: pd.DataFrame,
    x_metric_name: str,
    y_metric_name: str,
    show_pareto_frontier: bool,
    x_lower_is_better: bool,
) -> go.Figure:
    """
    Prepare a scatter plot for the given DataFrame.

    Args:
        df: The DataFrame to plot. Must contain the following columns:
            - trial_index: The trial index of the arm
            - arm_name: The name of the arm
            - X_METRIC_NAME: The observed mean of some metric to plot on the x-axis
            - Y_METRIC_NAME: The observed mean of the metric to plot on the y-axis
            - is_optimal: Whether the arm is on the Pareto frontier (this can be
                omitted if show_pareto_frontier=False)
        x_metric_name: The name of the metric to plot on the x-axis
        y_metric_name: The name of the metric to plot on the y-axis
        show_pareto_frontier: Whether to draw the Pareto frontier for the two metrics
        x_lower_is_better: Whether the metric on the x-axis is being minimized (only
            relevant if show_pareto_frontier=True)
    """
    fig = px.scatter(
        df,
        x=x_metric_name,
        y=y_metric_name,
        color="trial_index",
        hover_data=["trial_index", "arm_name", x_metric_name, y_metric_name],
    )

    if show_pareto_frontier:
        # Must sort to ensure we draw the line through optimal points in the correct
        # order.
        frontier_df = df[df["is_optimal"]].sort_values(by=x_metric_name)

        fig.add_trace(
            go.Scatter(
                x=frontier_df[x_metric_name],
                y=frontier_df[y_metric_name],
                mode="lines",
                # With points sorted by ascending x, the step direction that
                # traces the frontier depends on whether x is being minimized.
                line_shape="hv" if x_lower_is_better else "vh",
                showlegend=False,
            )
        )

    return fig