#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import numbers
import warnings
from collections import OrderedDict
from collections.abc import Callable, Iterable, Sequence
from logging import Logger
from typing import Any
import numpy as np
import plotly.graph_objs as go
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.registry import Models
from ax.plot.base import (
AxPlotConfig,
AxPlotTypes,
CI_OPACITY,
DECIMALS,
PlotInSampleArm,
PlotMetric,
PlotOutOfSampleArm,
Z,
)
from ax.plot.color import BLUE_SCALE, COLORS, DISCRETE_COLOR_SCALE, rgba
from ax.plot.helper import (
_format_CI,
_format_dict,
_wrap_metric,
arm_name_to_sort_key,
arm_name_to_tuple,
get_plot_data,
infer_is_relative,
resize_subtitles,
TNullableGeneratorRunsDict,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_optional
from ax.utils.stats.statstools import relativize
from plotly import subplots
# Module-level logger obtained from `get_logger` for this module's name.
logger: Logger = get_logger(__name__)
# type aliases
# `Traces`: a list of trace dictionaries with string keys and arbitrary values.
Traces = list[dict[str, Any]]
def _error_scatter_data(
arms: Iterable[PlotInSampleArm | PlotOutOfSampleArm],
y_axis_var: PlotMetric,
x_axis_var: PlotMetric | None = None,
status_quo_arm: PlotInSampleArm | None = None,
) -> tuple[list[float], list[float] | None, list[float], list[float]]:
y_metric_key = "y_hat" if y_axis_var.pred else "y"
y_sd_key = "se_hat" if y_axis_var.pred else "se"
arm_names = [a.name for a in arms]
y = [getattr(a, y_metric_key).get(y_axis_var.metric, np.nan) for a in arms]
y_se = [getattr(a, y_sd_key).get(y_axis_var.metric, np.nan) for a in arms]
# Delta method if relative to status quo arm
if y_axis_var.rel:
if status_quo_arm is None:
raise ValueError("`status_quo_arm` cannot be None for relative effects.")
y_rel, y_se_rel = relativize(
means_t=y,
sems_t=y_se,
mean_c=getattr(status_quo_arm, y_metric_key).get(y_axis_var.metric),
sem_c=getattr(status_quo_arm, y_sd_key).get(y_axis_var.metric),
as_percent=True,
)
y = y_rel.tolist()
y_se = y_se_rel.tolist()
# x can be metric for a metric or arm names
if x_axis_var is None:
x = arm_names
x_se = None
else:
x_metric_key = "y_hat" if x_axis_var.pred else "y"
x_sd_key = "se_hat" if x_axis_var.pred else "se"
x = [getattr(a, x_metric_key).get(x_axis_var.metric, np.nan) for a in arms]
x_se = [getattr(a, x_sd_key).get(x_axis_var.metric, np.nan) for a in arms]
if x_axis_var.rel:
# Delta method if relative to status quo arm
x_rel, x_se_rel = relativize(
means_t=x,
sems_t=x_se,
mean_c=getattr(status_quo_arm, x_metric_key).get(x_axis_var.metric),
sem_c=getattr(status_quo_arm, x_sd_key).get(x_axis_var.metric),
as_percent=True,
)
x = x_rel.tolist()
x_se = x_se_rel.tolist()
return x, x_se, y, y_se
def _error_scatter_trace(
    arms: Sequence[PlotInSampleArm | PlotOutOfSampleArm],
    y_axis_var: PlotMetric,
    x_axis_var: PlotMetric | None = None,
    y_axis_label: str | None = None,
    x_axis_label: str | None = None,
    status_quo_arm: PlotInSampleArm | None = None,
    show_CI: bool = True,
    name: str = "In-sample",
    color: tuple[int] = COLORS.STEELBLUE.value,
    visible: bool = True,
    legendgroup: str | None = None,
    showlegend: bool = True,
    hoverinfo: str = "text",
    show_arm_details_on_hover: bool = True,
    show_context: bool = False,
    arm_noun: str = "arm",
    color_parameter: str | None = None,
    color_metric: str | None = None,
) -> dict[str, Any]:
    """Plot scatterplot with error bars.

    Args:
        arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]):
            a list of in-sample or out-of-sample arms.
            In-sample arms have observed data, while out-of-sample arms
            just have predicted data. As a result,
            when passing out-of-sample arms, pred must be True.
        y_axis_var: name of metric for y-axis, along with whether
            it is observed or predicted.
        x_axis_var: name of metric for x-axis,
            along with whether it is observed or predicted. If None, arm names
            are automatically used.
        y_axis_label: custom label to use for y axis.
            If None, use metric name from `y_axis_var`.
        x_axis_label: custom label to use for x axis.
            If None, use metric name from `x_axis_var` if that is not None.
        status_quo_arm: the status quo
            arm. Necessary for relative metrics.
        show_CI: if True, plot confidence intervals.
        name: name of trace. Default is "In-sample".
        color: color as rgb tuple. Default is
            (128, 177, 211), which corresponds to COLORS.STEELBLUE.
        visible: if True, trace is visible (default).
        legendgroup: group for legends.
        showlegend: if True, legend is rendered.
        hoverinfo: information to show on hover. Default is
            custom text.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True and show_arm_details_on_hover,
            context will be included in the hover.
        arm_noun: noun to use instead of "arm" (e.g. group)
        color_parameter: color points according to the specified parameter,
            cannot be used together with color_metric.
        color_metric: color points according to the specified metric,
            cannot be used together with color_parameter.

    Returns:
        The scatter trace. Note that a plotly ``Scatter`` object is returned
        even though the annotation says ``dict``.

    Raises:
        RuntimeError: if both ``color_metric`` and ``color_parameter`` are
            given, or if color coding is requested while some arm is not a
            ``PlotInSampleArm``.
    """
    if color_metric and color_parameter:
        raise RuntimeError(
            "color_metric and color_parameter cannot be used at the same time!"
        )

    if (color_metric or color_parameter) and not all(
        isinstance(arm, PlotInSampleArm) for arm in arms
    ):
        raise RuntimeError("Color coding currently only works with in-sample arms!")

    # Opportunistically sort if arm names are in {trial}_{arm} format
    arms = sorted(arms, key=lambda a: arm_name_to_sort_key(a.name), reverse=True)

    x, x_se, y, y_se = _error_scatter_data(
        arms=arms,
        y_axis_var=y_axis_var,
        x_axis_var=x_axis_var,
        status_quo_arm=status_quo_arm,
    )
    labels = []
    colors = []

    # No relativization if no x variable.
    rel_x = x_axis_var.rel if x_axis_var else False
    rel_y = y_axis_var.rel

    # Axis names shown in the hover text do not depend on the arm, so resolve
    # them once outside the loop.
    x_name = None
    if x_axis_var is not None:
        x_name = x_axis_var.metric if x_axis_label is None else x_axis_label
    y_name = y_axis_var.metric if y_axis_label is None else y_axis_label

    for i, arm in enumerate(arms):
        heading = f"<b>{arm_noun.title()} {arm.name}</b><br>"
        if x_axis_var is None:
            # x holds arm names in this case; the heading already shows them.
            x_lab = ""
        else:
            x_lab = "{name}: {estimate}{perc} {ci}<br>".format(
                name=x_name,
                estimate=(
                    round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i]
                ),
                ci="" if x_se is None else _format_CI(x[i], x_se[i], rel_x),
                perc="%" if rel_x else "",
            )
        y_lab = "{name}: {estimate}{perc} {ci}<br>".format(
            name=y_name,
            estimate=(
                round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i]
            ),
            ci="" if y_se is None else _format_CI(y[i], y_se[i], rel_y),
            perc="%" if rel_y else "",
        )

        parameterization = (
            _format_dict(arm.parameters, "Parameterization")
            if show_arm_details_on_hover
            else ""
        )

        if color_parameter:
            colors.append(arm.parameters[color_parameter])
        elif color_metric:
            # Must be PlotInSampleArm here if no error raised previously
            # pyre-ignore[16]: `PlotOutOfSampleArm` has no attribute `y`
            colors.append(arm.y[color_metric])

        context = (
            # Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous
            # parameter to call `ax.plot.helper._format_dict` but got
            # `Optional[Dict[str, Union[float, str]]]`.
            # pyre-fixme[6]:
            _format_dict(arm.context_stratum, "Context")
            if show_arm_details_on_hover
            and show_context  # noqa W503
            and arm.context_stratum  # noqa W503
            else ""
        )

        labels.append(
            "{arm_name}<br>{xlab}{ylab}{param_blob}{context}".format(
                arm_name=heading,
                xlab=x_lab,
                ylab=y_lab,
                param_blob=parameterization,
                context=context,
            )
        )

    if color_metric or color_parameter:
        # Per-point colors mapped through the blue scale, with a colorbar
        # titled after whichever coloring variable is in use.
        rgba_blue_scale = [rgba(c) for c in BLUE_SCALE]
        marker = {
            "color": colors,
            "colorscale": rgba_blue_scale,
            "colorbar": {"title": color_metric or color_parameter},
            "showscale": True,
        }
    else:
        marker = {"color": rgba(color)}

    trace = go.Scatter(
        x=x,
        y=y,
        marker=marker,
        mode="markers",
        name=name,
        text=labels,
        hoverinfo=hoverinfo,
    )

    if show_CI:
        # Error-bar half-widths are the standard errors scaled by Z.
        if x_se is not None:
            trace.update(
                error_x={
                    "type": "data",
                    "array": np.multiply(x_se, Z),
                    "color": rgba(color, CI_OPACITY),
                }
            )
        if y_se is not None:
            trace.update(
                error_y={
                    "type": "data",
                    "array": np.multiply(y_se, Z),
                    "color": rgba(color, CI_OPACITY),
                }
            )
    if visible is not None:
        trace.update(visible=visible)
    if legendgroup is not None:
        trace.update(legendgroup=legendgroup)
    if showlegend is not None:
        trace.update(showlegend=showlegend)
    # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
    return trace
def _multiple_metric_traces(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel_x: bool,
    rel_y: bool,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    color_parameter: str | None = None,
    color_metric: str | None = None,
) -> Traces:
    """Build the scatter traces for a two-metric tradeoff plot.

    The result holds, in order: the observed in-sample trace (hidden), the
    predicted in-sample trace (visible), and one predicted trace per
    out-of-sample generator run.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        color_parameter: color points according to the specified parameter,
            cannot be used together with color_metric.
        color_metric: color points according to the specified metric,
            cannot be used together with color_parameter.

    Returns:
        The list of traces described above.
    """
    requested_metrics = {metric_x, metric_y}
    if color_metric is not None:
        requested_metrics.add(color_metric)

    runs = {} if generator_runs_dict is None else generator_runs_dict
    plot_data, _, _ = get_plot_data(
        model,
        runs,
        requested_metrics,
        fixed_features=fixed_features,
        data_selector=data_selector,
    )

    if plot_data.status_quo_name is None:
        status_quo_arm = None
    else:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)

    def _axes(pred: bool) -> dict[str, PlotMetric]:
        # Axis variables shared by every trace; only `pred` varies.
        return {
            "x_axis_var": PlotMetric(metric_x, pred=pred, rel=rel_x),
            "y_axis_var": PlotMetric(metric_y, pred=pred, rel=rel_y),
        }

    # Observed (hidden) first, then predicted (visible): visibility tracks
    # the `pred` flag for the two in-sample traces.
    traces = [
        _error_scatter_trace(
            list(plot_data.in_sample.values()),
            status_quo_arm=status_quo_arm,
            visible=use_pred,
            color_parameter=color_parameter,
            color_metric=color_metric,
            **_axes(use_pred),
        )
        for use_pred in (False, True)
    ]

    # TODO: Figure out if there's a better way to color code out-of-sample points
    candidates = plot_data.out_of_sample or {}
    traces.extend(
        _error_scatter_trace(
            list(cand_arms.values()),
            status_quo_arm=status_quo_arm,
            name=generator_run_name,
            color=DISCRETE_COLOR_SCALE[i],
            **_axes(True),
        )
        for i, (generator_run_name, cand_arms) in enumerate(
            candidates.items(), start=1
        )
    )
    return traces
[docs]
def plot_multiple_metrics(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel_x: bool = True,
    rel_y: bool = True,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    color_parameter: str | None = None,
    color_metric: str | None = None,
    **kwargs: Any,
) -> AxPlotConfig:
    """Plot raw values or predictions of two metrics for arms.

    All arms used in the model are included in the plot. Additional
    arms can be passed through the `generator_runs_dict` argument.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        color_parameter: color points according to the specified parameter,
            cannot be used together with color_metric.
        color_metric: color points according to the specified metric,
            cannot be used together with color_parameter.
        **kwargs: only the deprecated ``rel`` key is read; when given it
            overrides both ``rel_x`` and ``rel_y`` and a
            ``DeprecationWarning`` is emitted.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.
    """
    # Shift the controls and legend right to leave room for the colorbar.
    layout_offset_x = 0.15 if (color_parameter or color_metric) else 0

    rel = checked_cast_optional(bool, kwargs.get("rel"))
    if rel is not None:
        warnings.warn(
            "Use `rel_x` and `rel_y` instead of `rel`.",
            DeprecationWarning,
            stacklevel=2,
        )
        rel_x = rel
        rel_y = rel
    traces = _multiple_metric_traces(
        model,
        metric_x,
        metric_y,
        generator_runs_dict,
        rel_x=rel_x,
        rel_y=rel_y,
        fixed_features=fixed_features,
        data_selector=data_selector,
        color_parameter=color_parameter,
        color_metric=color_metric,
    )
    num_cand_traces = len(generator_runs_dict) if generator_runs_dict is not None else 0

    def _ci_style(width: int, thickness: int) -> dict[str, int]:
        # Restyle payload applied to both error-bar axes of every trace.
        return {
            "error_x.width": width,
            "error_x.thickness": thickness,
            "error_y.width": width,
            "error_y.thickness": thickness,
        }

    layout = go.Layout(
        title="Objective Tradeoffs",
        hovermode="closest",
        updatemenus=[
            {
                "x": 1.25 + layout_offset_x,
                "y": 0.67,
                "buttons": [
                    {
                        "args": [_ci_style(4, 2)],
                        "label": "Yes",
                        "method": "restyle",
                    },
                    {
                        "args": [_ci_style(0, 0)],
                        "label": "No",
                        "method": "restyle",
                    },
                ],
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 1.25 + layout_offset_x,
                "y": 0.57,
                "buttons": [
                    {
                        # Trace order is [observed, predicted, *candidates].
                        "args": [
                            {"visible": ([False, True] + [True] * num_cand_traces)}
                        ],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [
                            {"visible": ([True, False] + [False] * num_cand_traces)}
                        ],
                        "label": "Observed",
                        "method": "restyle",
                    },
                ],
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
        annotations=[
            {
                "x": 1.18 + layout_offset_x,
                "y": 0.7,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
            },
            {
                "x": 1.18 + layout_offset_x,
                "y": 0.6,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
            },
        ],
        xaxis={
            # Use the effective per-axis flag: the deprecated `rel` kwarg is
            # normally None, which previously dropped the percent suffix even
            # though the plotted values were relativized.
            "title": metric_x + (" (%)" if rel_x else ""),
            "zeroline": True,
            "zerolinecolor": "red",
        },
        yaxis={
            "title": metric_y + (" (%)" if rel_y else ""),
            "zeroline": True,
            "zerolinecolor": "red",
        },
        width=800,
        height=600,
        font={"size": 10},
        legend={"x": 1 + layout_offset_x},
    )

    fig = go.Figure(data=traces, layout=layout)
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]
def plot_objective_vs_constraints(
    model: ModelBridge,
    objective: str,
    subset_metrics: list[str] | None = None,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    infer_relative_constraints: bool | None = False,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    color_parameter: str | None = None,
    color_metric: str | None = None,
    label_dict: dict[str, str] | None = None,
) -> AxPlotConfig:
    """Plot the tradeoff between an objective and all other metrics in a model.

    All arms used in the model are included in the plot. Additional
    arms can be passed through via the `generator_runs_dict` argument.

    Fixed features input can be used to override fields of the insample arms
    when making model predictions.

    Args:
        model: model to draw predictions from.
        objective: metric to optimize. Plotted on the x-axis.
        subset_metrics: list of metrics to plot on the y-axes
            if need a subset of all metrics in the model.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        infer_relative_constraints: if True, read relative spec from model's
            optimization config. Absolute constraints will not be relativized;
            relative ones will be.
            Objectives will respect the `rel` parameter.
            Metrics that are not constraints will be relativized.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        color_parameter: color points according to the specified parameter,
            cannot be used together with color_metric.
        color_metric: color points according to the specified metric,
            cannot be used together with color_parameter.
        label_dict: A dictionary that maps the label to
            an alias to be used in the plot.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.

    Raises:
        ValueError: if there is no metric to plot against the objective.
    """
    # Shift the controls and legend right to leave room for the colorbar.
    layout_offset_x = 0.15 if (color_parameter or color_metric) else 0

    if subset_metrics is not None:
        metrics = subset_metrics
    else:
        metrics = [m for m in model.metric_names if m != objective]

    # The first metric seeds the initial plot and y-axis title, so fail with
    # an explicit message rather than an IndexError further down.
    if not metrics:
        raise ValueError(
            "No metrics to plot against the objective; the model must have at "
            "least one metric other than the objective, or `subset_metrics` "
            "must be non-empty."
        )

    if not label_dict:
        _check_label_lengths(metrics + [objective])

    if infer_relative_constraints:
        rels = infer_is_relative(model, metrics, non_constraint_rel=rel)
    else:
        rels = {metric: bool(rel) for metric in metrics}
    # The objective always follows `rel`, regardless of inference.
    rels[objective] = bool(rel)

    metric_dropdown = []
    plot_data = []
    for idx, metric in enumerate(metrics):
        otraces = _multiple_metric_traces(
            model,
            objective,
            metric,
            generator_runs_dict,
            rel_x=rels[objective],
            rel_y=rels[metric],
            fixed_features=fixed_features,
            data_selector=data_selector,
            color_parameter=color_parameter,
            color_metric=color_metric,
        )
        if idx == 0:
            # set plotted data to the first outcome; reuse these traces rather
            # than asking the model for the same predictions a second time.
            plot_data = otraces

        # Current version of Plotly does not allow updating the yaxis label
        # on dropdown (via relayout) simultaneously with restyle
        metric_dropdown.append(
            {
                "args": [
                    {
                        "y": [t["y"] for t in otraces],
                        "error_y.array": [t["error_y"]["array"] for t in otraces],
                        # NOTE(review): `t["text"]` is the per-point hover text
                        # of the trace rather than a single label, and
                        # `_replace_str` only does a whole-key lookup, so this
                        # presumably leaves the text unchanged -- confirm intent.
                        "text": [_replace_str(t["text"], label_dict) for t in otraces],
                    },
                    {
                        "yaxis.title": _replace_str(metric, label_dict)
                        + (" (%)" if rels[metric] else ""),
                    },
                ],
                "label": _replace_str(metric, label_dict),
                "method": "update",
            }
        )

    num_cand_traces = len(generator_runs_dict) if generator_runs_dict is not None else 0

    def _ci_style(width: int, thickness: int) -> dict[str, int]:
        # Restyle payload applied to both error-bar axes of every trace.
        return {
            "error_x.width": width,
            "error_x.thickness": thickness,
            "error_y.width": width,
            "error_y.thickness": thickness,
        }

    layout = go.Layout(
        title="Objective Tradeoffs",
        hovermode="closest",
        updatemenus=[
            {
                "x": 1.25 + layout_offset_x,
                "y": 0.62,
                "buttons": [
                    {
                        "args": [_ci_style(4, 2)],
                        "label": "Yes",
                        "method": "restyle",
                    },
                    {
                        "args": [_ci_style(0, 0)],
                        "label": "No",
                        "method": "restyle",
                    },
                ],
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 1.25 + layout_offset_x,
                "y": 0.52,
                "buttons": [
                    {
                        # Trace order is [observed, predicted, *candidates].
                        "args": [
                            {"visible": ([False, True] + [True] * num_cand_traces)}
                        ],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [
                            {"visible": ([True, False] + [False] * num_cand_traces)}
                        ],
                        "label": "Observed",
                        "method": "restyle",
                    },
                ],
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 1.25 + layout_offset_x,
                "y": 0.72,
                "yanchor": "middle",
                "xanchor": "left",
                "buttons": metric_dropdown,
            },
        ],
        annotations=[
            {
                "x": 1.18 + layout_offset_x,
                "y": 0.72,
                "xref": "paper",
                "yref": "paper",
                "text": "Y-Axis",
                "showarrow": False,
                "yanchor": "middle",
            },
            {
                "x": 1.18 + layout_offset_x,
                "y": 0.62,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
            },
            {
                "x": 1.18 + layout_offset_x,
                "y": 0.52,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
            },
        ],
        xaxis={
            "title": _replace_str(objective, label_dict)
            + (" (%)" if rels[objective] else ""),
            "zeroline": True,
            "zerolinecolor": "red",
        },
        yaxis={
            "title": _replace_str(metrics[0], label_dict)
            + (" (%)" if rels[metrics[0]] else ""),
            "zeroline": True,
            "zerolinecolor": "red",
        },
        width=900,
        height=600,
        font={"size": 10},
        legend={"x": 1 + layout_offset_x},
    )

    fig = go.Figure(data=plot_data, layout=layout)
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def _replace_str(input_str: str, str_dict: dict[str, str] | None = None) -> str:
"""Utility function to replace a string based on a mapping dictionary.
Args:
input_str: Input string to map.
str_dict: Mapping dictionary.
"""
return str_dict[input_str] if (str_dict and input_str in str_dict) else input_str
def _check_label_lengths(labels: list[str]) -> None:
"""Utility function to check label length and provide a warning for long
labels pointing to a mapping that can be used to override them.
Args:
labels: List of labels to check.
"""
max_len = 30
long_labels = [label for label in labels if len(label) > max_len]
if long_labels:
logger.info(
"This plot may be malformed due to long labels. You"
" can override long labels by passing a label_dict dictionary"
" to plotting functions that support it.\nHere's a list of labels"
f" longer than {max_len} characters:\n" + "\n".join(long_labels)
)
[docs]
def lattice_multiple_metrics(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
    data_selector: Callable[[Observation], bool] | None = None,
) -> AxPlotConfig:
    """Plot raw values or predictions of combinations of two metrics for arms.

    Builds a square grid of subplots with one row and one column per metric in
    ``model.metric_names``. Off-diagonal cells hold scatter traces of one
    metric against another; diagonal cells hold box plots of a single metric.

    Args:
        model: model to draw predictions from.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.
        data_selector: Function for selecting observations for plotting.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.
    """
    metrics = model.metric_names
    fig = subplots.make_subplots(
        rows=len(metrics),
        cols=len(metrics),
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict if generator_runs_dict is not None else {},
        metrics,
        data_selector=data_selector,
    )
    status_quo_arm = (
        None
        if plot_data.status_quo_name is None
        else plot_data.in_sample.get(plot_data.status_quo_name)
    )

    # iterate over all combinations of metrics and generate scatter traces
    # NOTE: every cell appends its traces in the same order -- observed
    # in-sample, predicted in-sample, then one per out-of-sample generator
    # run -- which is what the repeated "visible" masks further down rely on.
    for i, o1 in enumerate(metrics, start=1):
        for j, o2 in enumerate(metrics, start=1):
            if o1 != o2:
                # in-sample observed and predicted
                # (the legend entry is only shown once, for the cell i=1, j=2)
                obs_insample_trace = _error_scatter_trace(
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=False, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=False, rel=rel),
                    status_quo_arm=status_quo_arm,
                    showlegend=(i == 1 and j == 2),
                    legendgroup="In-sample",
                    visible=False,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                predicted_insample_trace = _error_scatter_trace(
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                    status_quo_arm=status_quo_arm,
                    legendgroup="In-sample",
                    showlegend=(i == 1 and j == 2),
                    visible=True,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                # o1 is on the x-axis and o2 on the y-axis; the cell is
                # placed at row j, column i.
                fig.append_trace(obs_insample_trace, j, i)
                fig.append_trace(predicted_insample_trace, j, i)

                # iterate over models here
                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1
                ):
                    fig.append_trace(
                        _error_scatter_trace(
                            list(cand_arms.values()),
                            x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                            y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                            status_quo_arm=status_quo_arm,
                            name=generator_run_name,
                            color=DISCRETE_COLOR_SCALE[k],
                            showlegend=(i == 1 and j == 2),
                            legendgroup=generator_run_name,
                            show_arm_details_on_hover=show_arm_details_on_hover,
                        ),
                        j,
                        i,
                    )
            else:
                # diagonal cell (same metric on both axes): add box plots of
                # the observed values (hidden by default), the predicted
                # values, and the predictions of each out-of-sample run.
                # These values are not relativized, whatever `rel` is.
                fig.append_trace(
                    go.Box(
                        y=[arm.y[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        visible=False,
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )
                fig.append_trace(
                    go.Box(
                        y=[arm.y_hat[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )

                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1
                ):
                    fig.append_trace(
                        go.Box(
                            y=[arm.y_hat[o1] for arm in cand_arms.values()],
                            name=None,
                            marker={"color": rgba(DISCRETE_COLOR_SCALE[k])},
                            showlegend=False,
                            legendgroup=generator_run_name,
                            hoverinfo="none",
                        ),
                        j,
                        i,
                    )

    fig["layout"].update(
        height=800,
        width=960,
        font={"size": 10},
        hovermode="closest",
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1.05,
            "xanchor": "left",
            "yanchor": "middle",
        },
        updatemenus=[
            # "Show CI" toggle; "No" is listed first, matching the error bars
            # being cleared on all scatter traces at the end of this function.
            {
                "x": 0.35,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [
                            {
                                "error_x.width": 0,
                                "error_x.thickness": 0,
                                "error_y.width": 0,
                                "error_y.thickness": 0,
                            }
                        ],
                        "label": "No",
                        "method": "restyle",
                    },
                    {
                        "args": [
                            {
                                "error_x.width": 4,
                                "error_x.thickness": 2,
                                "error_y.width": 4,
                                "error_y.thickness": 2,
                            }
                        ],
                        "label": "Yes",
                        "method": "restyle",
                    },
                ],
            },
            # "Type" toggle; the per-cell visibility pattern is repeated once
            # for each of the len(metrics) ** 2 cells.
            {
                "x": 0.1,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [
                            {
                                "visible": (
                                    (
                                        [False, True]
                                        + [True] * len(plot_data.out_of_sample or {})
                                    )
                                    * (len(metrics) ** 2)
                                )
                            }
                        ],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [
                            {
                                "visible": (
                                    (
                                        [True, False]
                                        + [False] * len(plot_data.out_of_sample or {})
                                    )
                                    * (len(metrics) ** 2)
                                )
                            }
                        ],
                        "label": "In-sample",
                        "method": "restyle",
                    },
                ],
            },
        ],
        annotations=[
            {
                "x": 0.02,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 0.30,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
    )

    # add metric names to axes - add to each subplot if boxplots on the
    # diagonal and axes are not shared; else, add to the leftmost y-axes
    # and bottom x-axes.
    # NOTE(review): pos_x / pos_y presumably pick the bottom-row x-axis of
    # column i and the left-column y-axis of row i, assuming make_subplots
    # numbers axes row by row -- confirm against the plotly version in use.
    for i, o in enumerate(metrics):
        pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1
        pos_y = 1 + (len(metrics) * i)
        fig["layout"][f"xaxis{pos_x}"].update(
            title=_wrap_metric(o), titlefont={"size": 10}
        )
        fig["layout"][f"yaxis{pos_y}"].update(
            title=_wrap_metric(o), titlefont={"size": 10}
        )

    # do not put x-axis ticks for boxplots
    boxplot_xaxes = []
    for trace in fig["data"]:
        if trace["type"] == "box":
            # stores the xaxes which correspond to boxplot subplots
            # since we use xaxis1, xaxis2, etc, in plotly.py
            boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
        else:
            # clear all error bars since default is no CI
            trace["error_x"].update(width=0, thickness=0)
            trace["error_y"].update(width=0, thickness=0)
    for xaxis in boxplot_xaxes:
        fig["layout"][xaxis]["showticklabels"] = False

    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
# Single metric fitted values
def _single_metric_traces(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel: bool,
    show_arm_details_on_hover: bool = True,
    showlegend: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    scalarized_metric_config: list[dict[str, Any]] | None = None,
) -> Traces:
    """Build predicted-value scatter traces for one metric, with arms on the x-axis.

    The first trace covers the in-sample arms; one further trace is added per
    out-of-sample generator run.

    Args:
        model: model to draw predictions from.
        metric: name of metric to plot.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, plot relative predictions.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        showlegend: if True, show legend for trace.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        scalarized_metric_config: An optional list of dicts specifying how to aggregate
            multiple metrics into a single scalarized metric. For each dict, the key is
            the name of the new scalarized metric, and the value is a dictionary mapping
            each metric to its weight. e.g.
            {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.

    Returns:
        The list of traces described above.
    """
    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict or {},
        {metric},
        fixed_features=fixed_features,
        data_selector=data_selector,
        scalarized_metric_config=scalarized_metric_config,
    )

    if plot_data.status_quo_name is None:
        status_quo_arm = None
    else:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)

    def _trace(arm_group: Any, **overrides: Any) -> dict[str, Any]:
        # Every trace shares these settings; callers only vary name, color
        # and legend group.
        return _error_scatter_trace(
            list(arm_group.values()),
            x_axis_var=None,
            y_axis_var=PlotMetric(metric, pred=True, rel=rel),
            status_quo_arm=status_quo_arm,
            showlegend=showlegend,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            **overrides,
        )

    traces = [_trace(plot_data.in_sample, legendgroup="In-sample")]

    # Candidates
    candidates = plot_data.out_of_sample or {}
    for i, (generator_run_name, cand_arms) in enumerate(candidates.items(), start=1):
        traces.append(
            _trace(
                cand_arms,
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[i],
                legendgroup=generator_run_name,
            )
        )
    return traces
[docs]
def plot_fitted(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    custom_arm_order: list[str] | None = None,
    custom_arm_order_name: str = "Custom",
    show_CI: bool = True,
    data_selector: Callable[[Observation], bool] | None = None,
    scalarized_metric_config: list[dict[str, Any]] | None = None,
) -> AxPlotConfig:
    """Plot the model's predictions for one metric, with one point per arm.

    A "Sort By" dropdown lets the viewer order the arms on the x-axis by name,
    by effect size, or by a caller-supplied custom order.

    Args:
        model: model to use for predictions.
        metric: metric to plot predictions for.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        custom_arm_order: a list of arm names in the
            order corresponding to how they should be plotted on the x-axis.
            If not None, this is the default ordering.
        custom_arm_order_name: name for custom ordering to
            show in the ordering dropdown. Default is 'Custom'.
        show_CI: if True, render confidence intervals.
        data_selector: Function for selecting observations for plotting.
        scalarized_metric_config: An optional list of dicts specifying how to aggregate
            multiple metrics into a single scalarized metric. For each dict, the key is
            the name of the new scalarized metric, and the value is a dictionary mapping
            each metric to its weight. e.g.
            {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.
    """
    traces = _single_metric_traces(
        model,
        metric,
        generator_runs_dict,
        rel,
        show_CI=show_CI,
        data_selector=data_selector,
        scalarized_metric_config=scalarized_metric_config,
    )

    # Arm names and effect sizes pooled across all traces.
    pooled_names = np.concatenate([d["x"] for d in traces])
    pooled_effects = np.concatenate([d["y"] for d in traces])

    # order arm name sorting arm numbers within batch
    names_by_arm = sorted(np.unique(pooled_names), key=arm_name_to_tuple)

    # get arm names sorted by effect size (duplicates dropped, first kept)
    effect_rank = np.argsort(pooled_effects.flatten())
    names_by_effect = list(
        OrderedDict.fromkeys(pooled_names.flatten().take(effect_rank))
    )

    def _order_button(label: str, ordering: Any) -> dict[str, Any]:
        # One dropdown entry that re-orders the x-axis categories.
        return {
            "args": [
                {"xaxis.categoryorder": "array", "xaxis.categoryarray": ordering}
            ],
            "label": label,
            "method": "relayout",
        }

    # options for ordering arms (x-axis)
    order_options = [
        _order_button("Name", names_by_arm),
        _order_button("Effect Size", names_by_effect),
    ]
    xaxis_categoryorder = "array"
    if custom_arm_order is None:
        xaxis_categoryarray = names_by_arm
    else:
        # if a custom order has been passed, default to that
        xaxis_categoryarray = custom_arm_order
        order_options.insert(
            0, _order_button(custom_arm_order_name, custom_arm_order)
        )

    layout = go.Layout(
        title="Predicted Outcomes",
        hovermode="closest",
        updatemenus=[
            {
                "x": 1.25,
                "y": 0.67,
                "buttons": list(order_options),
                "yanchor": "middle",
                "xanchor": "left",
            }
        ],
        yaxis={
            "zerolinecolor": "red",
            "title": f"{metric}{' (%)' if rel else ''}",
        },
        xaxis={
            "tickangle": 45,
            "categoryorder": xaxis_categoryorder,
            "categoryarray": xaxis_categoryarray,
        },
        annotations=[
            {
                "x": 1.18,
                "y": 0.72,
                "xref": "paper",
                "yref": "paper",
                "text": "Sort By",
                "showarrow": False,
                "yanchor": "middle",
            }
        ],
        font={"size": 10},
    )

    fig = go.Figure(data=traces, layout=layout)
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]
def tile_fitted(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
    show_CI: bool = True,
    arm_noun: str = "arm",
    metrics: list[str] | None = None,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    scalarized_metric_config: list[dict[str, Any]] | None = None,
    label_dict: dict[str, str] | None = None,
) -> AxPlotConfig:
    """Tile version of fitted outcome plots.

    Draws one subplot per metric, at most two per row, together with a
    "Sort By" dropdown that reorders the arms on every x-axis either by arm
    name (the initial ordering) or by ascending effect size.

    Args:
        model: model to use for predictions.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        metrics: List of metric names to restrict to when plotting. When
            omitted or empty, all of ``model.metric_names`` are plotted.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        scalarized_metric_config: An optional list of dicts specifying how to
            aggregate multiple metrics into a single scalarized metric. For each
            dict, the key is the name of the new scalarized metric, and the value
            is a dictionary mapping each metric to its weight. e.g.
            {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.
        label_dict: A dictionary that maps a metric name to the alias used as
            its subplot title; metrics without an entry keep their own name.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the tiled figure.

    Raises:
        ValueError: If there is no metric to plot.
    """
    metrics = metrics or list(model.metric_names)
    if not metrics:
        # Fail early with a clear message instead of the ZeroDivisionError that
        # the `0.30 / nrows` spacing computation below would otherwise raise.
        raise ValueError("`tile_fitted` requires at least one metric to plot.")
    nrows = int(np.ceil(len(metrics) / 2))
    ncols = min(len(metrics), 2)

    # make subplots (up to two plots per row)
    if label_dict is None:
        subplot_titles = metrics
    else:
        subplot_titles = [label_dict.get(metric, metric) for metric in metrics]

    fig = subplots.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
        subplot_titles=tuple(subplot_titles),
        horizontal_spacing=0.05,
        vertical_spacing=0.30 / nrows,
    )

    name_order_args: dict[str, Any] = {}
    name_order_axes: dict[str, dict[str, Any]] = {}
    effect_order_args: dict[str, Any] = {}

    for i, metric in enumerate(metrics):
        data = _single_metric_traces(
            model,
            metric,
            generator_runs_dict,
            rel,
            showlegend=i == 0,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            fixed_features=fixed_features,
            data_selector=data_selector,
            scalarized_metric_config=scalarized_metric_config,
        )

        # Arm names and effects pooled over all traces of this metric.
        arm_names = np.concatenate([d["x"] for d in data]).flatten()
        effects = np.concatenate([d["y"] for d in data]).flatten()

        # order arm name sorting arm numbers within batch
        names_by_arm = sorted(np.unique(arm_names), key=arm_name_to_tuple)

        # get arm names sorted by effect size (ascending); `OrderedDict.fromkeys`
        # drops repeated names while keeping their first position
        names_by_effect = list(
            OrderedDict.fromkeys(arm_names.take(np.argsort(effects)))
        )

        # options for ordering arms (x-axis)
        # Note that xaxes need to be references as xaxis, xaxis2, xaxis3, etc.
        # for the purposes of updatemenus argument (dropdown) in layout.
        # However, when setting the initial ordering layout, the keys should be
        # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial
        # axis.
        label = "" if i == 0 else i + 1
        name_order_args[f"xaxis{label}.categoryorder"] = "array"
        name_order_args[f"xaxis{label}.categoryarray"] = names_by_arm
        effect_order_args[f"xaxis{label}.categoryorder"] = "array"
        effect_order_args[f"xaxis{label}.categoryarray"] = names_by_effect
        name_order_axes[f"xaxis{i + 1}"] = {
            "categoryorder": "array",
            "categoryarray": names_by_arm,
            "type": "category",
        }
        name_order_axes[f"yaxis{i + 1}"] = {
            "ticksuffix": "%" if rel else "",
            "zerolinecolor": "red",
        }

        # Subplots are filled row by row; plotly rows/cols are 1-indexed.
        row, col = divmod(i, ncols)
        for d in data:
            fig.append_trace(d, row + 1, col + 1)

    order_options = [
        {"args": [name_order_args], "label": "Name", "method": "relayout"},
        {"args": [effect_order_args], "label": "Effect Size", "method": "relayout"},
    ]

    # `subplots.make_subplots` generates axes for every grid cell, so when the
    # grid has more cells than metrics the last, blank subplot has to be removed
    # manually. Comparing against the cell count (rather than testing for an odd
    # number of metrics) leaves the 1x1 grid of a single metric untouched.
    if nrows * ncols > len(metrics):
        fig["layout"].pop(f"xaxis{nrows * ncols}")
        fig["layout"].pop(f"yaxis{nrows * ncols}")

    # allocate 300 px of height per row of plots
    fig["layout"].update(
        margin={"t": 0},
        hovermode="closest",
        updatemenus=[
            {
                "x": 0.15,
                "y": 1 + 0.40 / nrows,
                "buttons": order_options,
                "xanchor": "left",
                "yanchor": "middle",
            }
        ],
        font={"size": 10},
        width=650 if ncols == 1 else 950,
        height=300 * nrows,
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1 + 0.20 / nrows,
            "xanchor": "left",
            "yanchor": "middle",
        },
        **name_order_axes,
    )

    # append the plot title and the dropdown label to the subplot-title
    # annotations created by `make_subplots`
    fig["layout"]["annotations"] = fig["layout"]["annotations"] + (
        {
            "x": 0.5,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "font": {"size": 14},
            "text": "Predicted Outcomes",
            "showarrow": False,
            "xanchor": "center",
            "yanchor": "middle",
        },
        {
            "x": 0.05,
            "y": 1 + 0.40 / nrows,
            "xref": "paper",
            "yref": "paper",
            "text": "Sort By",
            "showarrow": False,
            "xanchor": "left",
            "yanchor": "middle",
        },
    )

    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    fig = resize_subtitles(figure=fig, size=10)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]
def interact_fitted_plotly(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    metrics: list[str] | None = None,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    label_dict: dict[str, str] | None = None,
    scalarized_metric_config: list[dict[str, Any]] | None = None,
) -> go.Figure:
    """Interactive fitted outcome plots for each arm used in fitting the model.

    Choose the outcome to plot using a dropdown. The traces of every metric
    are added to a single figure; only those of the first metric start out
    visible, and each dropdown entry restyles the ``visible`` flag of all
    traces so that exactly one metric is shown.

    Args:
        model: model to use for predictions.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        metrics: List of metric names to restrict to when plotting.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        label_dict: A dictionary that maps the label to
            an alias to be used in the plot.
        scalarized_metric_config: An optional list of dicts specifying how to
            aggregate multiple metrics into a single scalarized metric. For each
            dict, the key is the name of the new scalarized metric, and the value
            is a dictionary mapping each metric to its weight. e.g.
            {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.

    Returns:
        The plotly figure holding the traces of all metrics and the dropdown.
    """
    # One trace per metric, plus one more for every generator run.
    traces_per_metric = 1
    if generator_runs_dict is not None:
        traces_per_metric += len(generator_runs_dict)

    sorted_metrics = sorted(metrics or model.metric_names)
    if not label_dict:
        _check_label_lengths(sorted_metrics)

    # Scalarized metrics are listed after the regular ones.
    plotted_metrics = sorted_metrics
    if scalarized_metric_config is not None:
        plotted_metrics = sorted_metrics + [
            agg["name"] for agg in scalarized_metric_config
        ]
    total_traces = len(plotted_metrics) * traces_per_metric

    all_traces = []
    buttons = []
    for idx, metric_name in enumerate(plotted_metrics):
        first_metric = idx == 0
        metric_traces = _single_metric_traces(
            model,
            metric_name,
            generator_runs_dict,
            rel,
            showlegend=first_metric,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
            fixed_features=fixed_features,
            data_selector=data_selector,
            scalarized_metric_config=scalarized_metric_config,
        )
        # initially, only the traces of the first outcome in the dropdown
        # are shown
        for trace in metric_traces:
            trace["visible"] = first_metric
            all_traces.append(trace)

        # visibility mask selecting the block of traces of this metric
        block_start = traces_per_metric * idx
        block_stop = traces_per_metric * (idx + 1)
        visibility = [
            block_start <= position < block_stop for position in range(total_traces)
        ]

        # on dropdown change, restyle
        buttons.append(
            {
                "args": ["visible", visibility],
                "label": _replace_str(metric_name, label_dict),
                "method": "restyle",
            }
        )

    annotations = [
        {
            "font": {"size": 12},
            "showarrow": False,
            "text": "Metric",
            "x": 0.05,
            "xanchor": "right",
            "xref": "paper",
            "y": 1.1,
            "yanchor": "middle",
            "yref": "paper",
        }
    ]
    if traces_per_metric > 1:
        # label the legend only when there are several arm sources
        annotations.append(
            {
                "font": {"size": 12},
                "showarrow": False,
                "text": "Arm Source",
                "x": 0.05,
                "xanchor": "right",
                "xref": "paper",
                "y": 1.2,
                "yanchor": "middle",
                "yref": "paper",
            }
        )

    x_axis = {"title": arm_noun.title(), "zeroline": False, "type": "category"}
    y_axis = {
        "ticksuffix": "%" if rel else "",
        "title": ("Relative " if rel else "") + "Effect",
        "zeroline": True,
        "zerolinecolor": "red",
    }
    metric_menu = {
        "buttons": buttons,
        "x": 0.075,
        "xanchor": "left",
        "y": 1.1,
        "yanchor": "middle",
    }
    legend = {
        "orientation": "h",
        "x": 0.065,
        "xanchor": "left",
        "y": 1.2,
        "yanchor": "middle",
    }
    layout = go.Layout(
        xaxis=x_axis,
        yaxis=y_axis,
        hovermode="closest",
        updatemenus=[metric_menu],
        annotations=annotations,
        legend=legend,
        height=500,
    )
    return go.Figure(data=all_traces, layout=layout)
[docs]
def interact_fitted(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    metrics: list[str] | None = None,
    fixed_features: ObservationFeatures | None = None,
    data_selector: Callable[[Observation], bool] | None = None,
    label_dict: dict[str, str] | None = None,
    scalarized_metric_config: list[dict[str, Any]] | None = None,
) -> AxPlotConfig:
    """Interactive fitted outcome plots for each arm used in fitting the model.

    Choose the outcome to plot using a dropdown. This forwards every argument
    to ``interact_fitted_plotly`` and wraps the resulting figure in an
    ``AxPlotConfig``.

    Args:
        model: model to use for predictions.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        metrics: List of metric names to restrict to when plotting.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.
        label_dict: A dictionary that maps the label to
            an alias to be used in the plot.
        scalarized_metric_config: An optional list of dicts specifying how to
            aggregate multiple metrics into a single scalarized metric. For each
            dict, the key is the name of the new scalarized metric, and the value
            is a dictionary mapping each metric to its weight. e.g.
            {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.
    """
    figure = interact_fitted_plotly(
        model=model,
        generator_runs_dict=generator_runs_dict,
        rel=rel,
        show_arm_details_on_hover=show_arm_details_on_hover,
        show_CI=show_CI,
        arm_noun=arm_noun,
        metrics=metrics,
        fixed_features=fixed_features,
        data_selector=data_selector,
        label_dict=label_dict,
        scalarized_metric_config=scalarized_metric_config,
    )
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
[docs]
def tile_observations(
    experiment: Experiment,
    data: Data | None = None,
    rel: bool = True,
    metrics: list[str] | None = None,
    arm_names: list[str] | None = None,
    arm_noun: str = "arm",
    label_dict: dict[str, str] | None = None,
) -> AxPlotConfig:
    """
    Tiled plot with all observed outcomes.

    Will plot all observed arms. If data is provided will use that, otherwise
    will fetch data from experiment. Will plot all metrics in data unless a
    list is provided in metrics. If arm_names is provided will limit the plot
    to only arms in that list. The plot itself is produced by ``tile_fitted``
    on a ``Models.THOMPSON`` model built from the (possibly filtered) data.

    Args:
        experiment: Experiment
        data: Data to use, otherwise will fetch data from experiment.
        rel: Plot relative values, if experiment has status quo.
        metrics: Limit results to this set of metrics.
        arm_names: Limit results to this set of arms.
        arm_noun: Noun to use instead of "arm".
        label_dict: A dictionary that maps the label to an alias to be used in the
            plot.

    Returns: Plot config for the plot.
    """
    observed = experiment.fetch_data() if data is None else data

    if arm_names is not None:
        # keep only the rows that belong to the requested arms
        frame = observed.df
        keep_rows = frame["arm_name"].isin(arm_names)
        observed = Data(frame[keep_rows])

    thompson_model = Models.THOMPSON(
        data=observed,
        search_space=experiment.search_space,
        experiment=experiment,
    )

    # relative effects are only plotted when the experiment has a status quo
    plot_relative = rel and (experiment.status_quo is not None)
    return tile_fitted(
        model=thompson_model,
        rel=plot_relative,
        metrics=metrics,
        arm_noun=arm_noun,
        label_dict=label_dict,
    )