#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objs as go
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.observation import Observation
from ax.modelbridge.cross_validation import CVResult
from ax.modelbridge.transforms.convert_metric_names import convert_mt_observations
from ax.plot.base import (
AxPlotConfig,
AxPlotTypes,
PlotData,
PlotInSampleArm,
PlotMetric,
Z,
)
from ax.plot.helper import compose_annotation
from ax.plot.scatter import _error_scatter_data, _error_scatter_trace
from ax.utils.common.typeutils import not_none
from plotly import subplots
# Type alias for a plain list of floats, used below for point estimates and
# their standard deviations.
FloatList = List[float]
# Helper functions for plotting model fits
def _get_min_max_with_errors(
    x: FloatList, y: FloatList, sd_x: FloatList, sd_y: FloatList
) -> Tuple[float, float]:
    """Compute one shared lower and upper bound for two series with error bars.

    Every point is widened by ``Z`` times its standard deviation in both
    directions; the extremes are then taken over both series together.

    Args:
        x: point estimates of the x variable.
        y: point estimates of the y variable.
        sd_x: standard deviations of the x variable (combined elementwise
            with ``x``).
        sd_y: standard deviations of the y variable (combined elementwise
            with ``y``).

    Returns:
        A ``(min_, max_)`` tuple: the smallest lower bound and the largest
        upper bound across both variables, uncertainty included.
    """
    x_center, y_center = np.array(x), np.array(y)
    x_halfwidth = np.multiply(sd_x, Z)
    y_halfwidth = np.multiply(sd_y, Z)

    lowest_x = min(x_center - x_halfwidth)
    lowest_y = min(y_center - y_halfwidth)
    highest_x = max(x_center + x_halfwidth)
    highest_y = max(y_center + y_halfwidth)
    return min(lowest_x, lowest_y), max(highest_x, highest_y)
def _diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]:
    """Build the dotted reference line running from (min_, min_) to (max_, max_).

    Args:
        min_: coordinate used for both x and y at the start of the line.
        max_: coordinate used for both x and y at the end of the line.
        visible: if True, the trace is set to visible.

    Returns:
        A ``go.Scatter`` line trace with hover info and legend entry disabled.
        NOTE: the signature is annotated ``Dict[str, Any]``, but the object
        actually returned is the ``go.Scatter`` instance.
    """
    dotted_black = {"color": "black", "width": 2, "dash": "dot"}
    trace = go.Scatter(
        x=[min_, max_],
        y=[min_, max_],
        line=dotted_black,
        mode="lines",
        hoverinfo="none",
        visible=visible,
        showlegend=False,
    )
    return trace
def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> go.Figure:
    """Build an observed-vs-predicted scatter figure with a metric dropdown.

    For every metric in ``data.metrics`` two traces are added: a dotted
    diagonal reference line and an error-bar scatter of the in-sample arms.
    Only the pair belonging to the first metric starts out visible; the
    dropdown switches visibility to the pair of the selected metric. A second
    menu ("Show CI") sets the error-bar width/thickness to non-zero or zero.

    Args:
        data: a named tuple storing observed and predicted data from a model.
        rel: if True, plot metrics relative to the status quo.
        show_context: show context on hover.
        xlabel: label for the x-axis.
        ylabel: label for the y-axis.

    Returns:
        The assembled ``go.Figure``.

    Raises:
        ValueError: if relativization (``rel`` with a status quo name set on
            ``data``) and ``show_context`` are requested together.
    """
    # Relativization needs the status quo arm; it cannot be combined with context.
    status_quo_arm = None
    if rel and data.status_quo_name is not None:
        if show_context:
            raise ValueError(
                "This plot does not support both context and relativization at "
                "the same time."
            )
        status_quo_arm = data.in_sample[data.status_quo_name]

    n_traces = 2 * len(data.metrics)
    traces = []
    metric_buttons = []
    for pos, metric in enumerate(data.metrics):
        shown_initially = pos == 0
        observed_var = PlotMetric(metric, pred=False, rel=rel)
        predicted_var = PlotMetric(metric, pred=True, rel=rel)

        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            list(data.in_sample.values()),
            y_axis_var=predicted_var,
            x_axis_var=observed_var,
            status_quo_arm=status_quo_arm,
        )
        # Missing or NaN observed standard errors count as zero when sizing
        # the diagonal.
        if se_raw is None:
            se_raw = [0.0] * len(y_raw)
        else:
            se_raw = [0.0 if np.isnan(se) else se for se in se_raw]
        lower, upper = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)

        traces.append(_diagonal_trace(lower, upper, visible=shown_initially))
        traces.append(
            _error_scatter_trace(
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=shown_initially,
                x_axis_label=xlabel,
                x_axis_var=observed_var,
                y_axis_label=ylabel,
                y_axis_var=predicted_var,
            )
        )

        # Each metric owns two consecutive traces (diagonal + scatter); on
        # dropdown change, restyle so that only this metric's pair is visible.
        is_visible = [trace_idx // 2 == pos for trace_idx in range(n_traces)]
        metric_buttons.append(
            {"args": ["visible", is_visible], "label": metric, "method": "restyle"}
        )

    # "Yes" restores the error-bar caps and lines, "No" shrinks them to nothing.
    ci_buttons = [
        {
            "args": [
                {
                    "error_x.width": width,
                    "error_x.thickness": thickness,
                    "error_y.width": width,
                    "error_y.thickness": thickness,
                }
            ],
            "label": label,
            "method": "restyle",
        }
        for label, width, thickness in (("Yes", 4, 2), ("No", 0, 0))
    ]
    metric_menu = {
        "x": 0,
        "y": 1.125,
        "yanchor": "top",
        "xanchor": "left",
        "buttons": metric_buttons,
    }
    ci_menu = {
        "buttons": ci_buttons,
        "x": 1.125,
        "xanchor": "left",
        "y": 0.8,
        "yanchor": "middle",
    }
    # Caption placed above the CI toggle menu.
    ci_caption = {
        "showarrow": False,
        "text": "Show CI",
        "x": 1.125,
        "xanchor": "left",
        "xref": "paper",
        "y": 0.9,
        "yanchor": "middle",
        "yref": "paper",
    }
    axis_style = {
        "zeroline": False,
        "mirror": True,
        "linecolor": "black",
        "linewidth": 0.5,
    }

    layout = go.Layout(
        annotations=[ci_caption],
        xaxis={"title": xlabel, **axis_style},
        yaxis={"title": ylabel, **axis_style},
        showlegend=False,
        hovermode="closest",
        updatemenus=[metric_menu, ci_menu],
        width=530,
        height=500,
    )
    return go.Figure(data=traces, layout=layout)
def _get_batch_comparison_plot_data(
    observations: List[Observation],
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> PlotData:
    """Compute PlotData for comparing repeated arms across two trials.

    Only arms observed in both trials are kept. For each such arm the
    ``batch_x`` observation fills the "observed" slots (``y``/``se``) and the
    ``batch_y`` observation fills the "predicted" slots (``y_hat``/``se_hat``),
    so the generic observed-vs-predicted plot can be reused. Standard errors
    are the square roots of the diagonal of each observation's covariance.

    Args:
        observations: List of observations.
        batch_x: Batch for x-axis.
        batch_y: Batch for y-axis.
        rel: Whether to relativize data against status_quo arm.
        status_quo_name: Name of the status_quo arm.

    Returns:
        PlotData: a plot data object.

    Raises:
        ValueError: if ``rel`` is True but no ``status_quo_name`` is given.
    """
    if rel and status_quo_name is None:
        raise ValueError("Experiment status quo must be set for rel=True")

    # Index the observations of each of the two trials by arm name.
    x_by_arm: Dict[Any, Observation] = {}
    y_by_arm: Dict[Any, Observation] = {}
    for obs in observations:
        trial_index = obs.features.trial_index
        if trial_index == batch_x:
            x_by_arm[obs.arm_name] = obs
        if trial_index == batch_y:
            y_by_arm[obs.arm_name] = obs

    # Assume input is well formed and metric_names are consistent across observations
    metric_names = observations[0].data.metric_names

    insample_data: Dict[str, PlotInSampleArm] = {}
    # Restrict to arms present in both trials
    shared_arm_names = [name for name in x_by_arm if name in y_by_arm]
    for arm_name in shared_arm_names:
        x_obs = x_by_arm[arm_name]
        x_data = x_obs.data
        y_data = y_by_arm[arm_name].data
        arm_data: Dict[str, Any] = {
            "name": arm_name,
            "y": {m: x_data.means[i] for i, m in enumerate(x_data.metric_names)},
            "se": {
                m: np.sqrt(x_data.covariance[i][i])
                for i, m in enumerate(x_data.metric_names)
            },
            "parameters": x_obs.features.parameters,
            "y_hat": {m: y_data.means[i] for i, m in enumerate(y_data.metric_names)},
            "se_hat": {
                m: np.sqrt(y_data.covariance[i][i])
                for i, m in enumerate(y_data.metric_names)
            },
            "context_stratum": None,
        }
        # Expected `str` for 2nd anonymous parameter to call `dict.__setitem__` but got
        # `Optional[str]`.
        # pyre-fixme[6]:
        insample_data[arm_name] = PlotInSampleArm(**arm_data)

    return PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=status_quo_name,
    )
def _get_cv_plot_data(cv_results: List[CVResult]) -> PlotData:
    """Convert cross-validation results into PlotData.

    Each CV result becomes one in-sample arm whose ``y``/``se`` hold the
    observed means and standard errors and whose ``y_hat``/``se_hat`` hold the
    predicted ones. Standard errors are the square roots of the diagonal of
    the respective covariance matrices.

    Args:
        cv_results: cross-validation results.

    Returns:
        PlotData with the sorted union of all predicted metric names and one
        in-sample arm per CV result, keyed by ``"<arm_name>_<index>"`` so that
        several results for the same arm do not overwrite one another.
        Out-of-sample data and the status quo name are left unset. An empty
        ``cv_results`` yields an empty PlotData.
    """
    if not cv_results:
        return PlotData(
            metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None
        )

    # Union of all metric_names seen in predictions. Sorted because iteration
    # order of a set of strings varies between interpreter runs, which would
    # otherwise reshuffle the dropdown / subplot order from run to run.
    metric_names = sorted(
        set().union(*(cv_result.predicted.metric_names for cv_result in cv_results))
    )

    # "<arm_name>_<index>" -> Arm data
    insample_data: Dict[str, PlotInSampleArm] = {}
    for rid, cv_result in enumerate(cv_results):
        observed = cv_result.observed
        predicted = cv_result.predicted
        arm_name = observed.arm_name
        arm_data: Dict[str, Any] = {
            "name": arm_name,
            "y": {
                mname: observed.data.means[i]
                for i, mname in enumerate(observed.data.metric_names)
            },
            "se": {
                mname: np.sqrt(observed.data.covariance[i][i])
                for i, mname in enumerate(observed.data.metric_names)
            },
            "parameters": observed.features.parameters,
            "y_hat": {
                mname: predicted.means[i]
                for i, mname in enumerate(predicted.metric_names)
            },
            "se_hat": {
                mname: np.sqrt(predicted.covariance[i][i])
                for i, mname in enumerate(predicted.metric_names)
            },
            "context_stratum": None,
        }
        insample_data[f"{arm_name}_{rid}"] = PlotInSampleArm(**arm_data)

    return PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=None,
    )
def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Compare the model predictions for the batch arms against observed data.

    Relies on the model predictions stored on the generator_runs of batch.
    Generator runs without stored predictions are skipped.

    Args:
        batch: Batch on which to perform analysis.
        data: Observed data for the batch.

    Returns:
        AxPlotConfig for the plot.

    Raises:
        ValueError: if no arm with model predictions was found on the batch.
    """
    df = data.df
    metric_names = list(df["metric_name"].unique())

    insample_data: Dict[str, PlotInSampleArm] = {}
    for struct in batch.generator_run_structs:
        generator_run = struct.generator_run
        predictions = generator_run.model_predictions
        if predictions is None:
            continue
        # predictions[0]: metric -> per-arm means;
        # predictions[1]: metric -> metric -> per-arm covariances.
        predicted_means, predicted_covs = predictions[0], predictions[1]

        for i, arm in enumerate(generator_run.arms):
            arm_name = arm.name_or_short_signature
            observed_mean: Dict[str, Any] = {}
            observed_sem: Dict[str, Any] = {}
            predicted_mean: Dict[str, Any] = {}
            predicted_sem: Dict[str, Any] = {}
            for _, row in df[df["arm_name"] == arm_name].iterrows():
                metric_name = row["metric_name"]
                observed_mean[metric_name] = row["mean"]
                observed_sem[metric_name] = row["sem"]
                predicted_mean[metric_name] = predicted_means[metric_name][i]
                # The diagonal covariance entry is a variance; the plot expects
                # a standard error, as the other PlotData builders in this
                # module provide (sqrt of the covariance diagonal).
                predicted_sem[metric_name] = np.sqrt(
                    predicted_covs[metric_name][metric_name][i]
                )
            arm_data: Dict[str, Any] = {
                "name": arm_name,
                "y": observed_mean,
                "se": observed_sem,
                "parameters": arm.parameters,
                "y_hat": predicted_mean,
                "se_hat": predicted_sem,
                "context_stratum": None,
            }
            insample_data[arm_name] = PlotInSampleArm(**arm_data)

    if not insample_data:
        raise ValueError("No model predictions present on the batch.")

    plot_data = PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=None,
    )
    fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_cross_validation_plotly(
    cv_results: List[CVResult], show_context: bool = True, caption: str = ""
) -> go.Figure:
    """Interactive cross-validation (CV) plotting; select metric via dropdown.

    Note: uses the Plotly version of dropdown (which means that all data is
    stored within the notebook).

    Args:
        cv_results: cross-validation results.
        show_context: if True, show context on hover.
        caption: text passed to ``compose_annotation`` and appended to the
            figure's annotations; when non-empty, the bottom margin and the
            figure height are both enlarged by 100 to make room for it.

    Returns a plotly.graph_objects.Figure
    """
    plot_data = _get_cv_plot_data(cv_results)
    fig = _obs_vs_pred_dropdown_plot(
        data=plot_data, rel=False, show_context=show_context
    )

    layout = fig["layout"]
    # Extra vertical room is reserved only when there is a caption to show.
    extra_height = 100 if len(caption) > 0 else 0
    # Fall back to a bottom margin of 90 when none is set on the layout.
    bottom_margin = layout["margin"].b or 90
    layout["margin"].b = bottom_margin + extra_height
    layout["height"] += extra_height
    layout["annotations"] += tuple(compose_annotation(caption))
    layout["title"] = "Cross-validation"
    return fig
def interact_cross_validation(
    cv_results: List[CVResult], show_context: bool = True
) -> AxPlotConfig:
    """Interactive cross-validation (CV) plotting; select metric via dropdown.

    Wraps the figure produced by ``interact_cross_validation_plotly`` in an
    AxPlotConfig.

    Note: uses the Plotly version of dropdown (which means that all data is
    stored within the notebook).

    Args:
        cv_results: cross-validation results.
        show_context: if True, show context on hover.

    Returns an AxPlotConfig
    """
    figure = interact_cross_validation_plotly(
        cv_results=cv_results, show_context=show_context
    )
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def tile_cross_validation(
    cv_results: List[CVResult],
    show_arm_details_on_hover: bool = True,
    show_context: bool = True,
) -> AxPlotConfig:
    """Tile version of CV plots: one observed-vs-predicted subplot per metric.

    Subplots are laid out two per row, in the order of the metrics returned by
    ``_get_cv_plot_data``; no goodness-of-fit sorting is applied.

    Args:
        cv_results: cross-validation results.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True (default), display context on
            hover.

    Returns an AxPlotConfig wrapping the plotly figure.

    Raises:
        ValueError: if ``cv_results`` yields no metrics to plot.
    """
    data = _get_cv_plot_data(cv_results)
    metrics = data.metrics
    if not metrics:
        # Without this guard the spacing computation below divides by zero.
        raise ValueError("No cross-validation results to plot.")

    # make subplots (2 plots per row)
    n_metrics = len(metrics)
    nrows = (n_metrics + 1) // 2
    ncols = min(n_metrics, 2)
    fig = subplots.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.15,
        vertical_spacing=0.30 / nrows,
    )

    arms = list(data.in_sample.values())
    for i, metric in enumerate(metrics):
        row, col = i // 2 + 1, i % 2 + 1
        y_raw = [arm.y[metric] for arm in arms]
        y_hat = [arm.y_hat[metric] for arm in arms]
        se_hat = [arm.se_hat[metric] for arm in arms]
        # NaN observed standard errors count as zero when sizing the diagonal
        # (same treatment as in `_obs_vs_pred_dropdown_plot`); otherwise the
        # NaN can leak into the line's end points.
        se_raw = [
            0.0 if np.isnan(arm.se[metric]) else arm.se[metric] for arm in arms
        ]
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)
        fig.append_trace(_diagonal_trace(min_, max_), row, col)
        fig.append_trace(
            _error_scatter_trace(
                list(arms),
                y_axis_var=PlotMetric(metric, pred=True, rel=False),
                x_axis_var=PlotMetric(metric, pred=False, rel=False),
                y_axis_label="Predicted",
                x_axis_label="Actual",
                hoverinfo="text",
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_context=show_context,
            ),
            row,
            col,
        )

    # If the grid has more cells than metrics (odd count on a two-column
    # grid), manually remove the last blank subplot generated by
    # `subplots.make_subplots`. A single metric uses a 1x1 grid with no blank
    # cell, so nothing must be removed there.
    n_cells = nrows * ncols
    if n_cells > n_metrics:
        fig["layout"].pop(f"xaxis{n_cells}")
        fig["layout"].pop(f"yaxis{n_cells}")

    # allocate 400 px per plot (equal aspect ratio)
    fig["layout"].update(
        title="Cross-Validation",
        hovermode="closest",
        width=800,
        height=400 * nrows,
        font={"size": 10},
        showlegend=False,
    )

    # update subplot title size and the axis labels
    for i, ant in enumerate(fig["layout"]["annotations"]):
        ant["font"].update(size=12)
        fig["layout"][f"xaxis{i + 1}"].update(
            title="Actual Outcome", mirror=True, linecolor="black", linewidth=0.5
        )
        fig["layout"][f"yaxis{i + 1}"].update(
            title="Predicted Outcome", mirror=True, linecolor="black", linewidth=0.5
        )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_batch_comparison(
    observations: List[Observation],
    experiment: Experiment,
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> AxPlotConfig:
    """Compare repeated arms from two trials; select metric via dropdown.

    Args:
        observations: List of observations to compute comparison.
        experiment: Experiment the observations belong to. Observations of a
            MultiTypeExperiment are first passed through
            ``convert_mt_observations``; the experiment's status quo supplies
            ``status_quo_name`` when none is given.
        batch_x: Index of batch for x-axis.
        batch_y: Index of batch for y-axis.
        rel: Whether to relativize data against status_quo arm.
        status_quo_name: Name of the status_quo arm.

    Returns:
        AxPlotConfig for the plot.
    """
    if isinstance(experiment, MultiTypeExperiment):
        observations = convert_mt_observations(observations, experiment)

    # Fall back to the experiment's status quo when no name was supplied.
    if not status_quo_name:
        experiment_status_quo = experiment.status_quo
        if experiment_status_quo:
            status_quo_name = not_none(experiment_status_quo).name

    comparison_data = _get_batch_comparison_plot_data(
        observations, batch_x, batch_y, rel=rel, status_quo_name=status_quo_name
    )
    fig = _obs_vs_pred_dropdown_plot(
        data=comparison_data,
        rel=rel,
        xlabel=f"Batch {batch_x}",
        ylabel=f"Batch {batch_y}",
    )
    fig["layout"]["title"] = "Repeated arms across trials"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)