#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objs as go
from ax.plot.base import CI_OPACITY, DECIMALS, AxPlotConfig, AxPlotTypes
from ax.plot.color import COLORS, rgba, DISCRETE_COLOR_SCALE
from ax.plot.helper import extend_range, _format_CI, _format_dict
from ax.plot.pareto_utils import ParetoFrontierResults
from scipy.stats import norm
# Default confidence level for the Pareto frontier plots in this module; used
# as the default value of their ``CI_level`` argument.
DEFAULT_CI_LEVEL: float = 0.9
def _make_label(
    mean: float, sem: float, name: str, is_relative: bool, Z: Optional[float]
) -> str:
    """Build one HTML hover-text line describing a single metric.

    Args:
        mean: Point estimate of the metric; shown rounded to ``DECIMALS``.
        sem: Standard error of the estimate. If it is NaN, no confidence
            interval is shown.
        name: Metric name, used as the prefix of the line.
        is_relative: If True, a "%" is appended to the estimate; the flag is
            also forwarded to ``_format_CI``.
        Z: z-value forwarded to ``_format_CI``. If None, no confidence
            interval is shown.

    Returns:
        A string of the form ``"<name>: <estimate>[%] <interval><br>"``.
    """
    rounded = str(round(mean, DECIMALS))
    suffix = "%" if is_relative else ""
    if Z is not None and not np.isnan(sem):
        interval = _format_CI(
            estimate=mean, sd=sem, relative=is_relative, zval=Z  # pyre-ignore
        )
    else:
        # Without a z-value or a usable SEM there is nothing to display.
        interval = ""
    return f"{name}: {rounded}{suffix} {interval}<br>"
def _filter_outliers(Y: np.ndarray, m: float = 2.0) -> np.ndarray:
std_filter = abs(Y - np.median(Y, axis=0)) < m * np.std(Y, axis=0)
return Y[np.all(abs(std_filter), axis=1)]
def scatter_plot_with_pareto_frontier_plotly(
    Y: np.ndarray,
    Y_pareto: np.ndarray,
    metric_x: str,
    metric_y: str,
    reference_point: Tuple[float, float],
    minimize: bool = True,
) -> go.Figure:
    """Plot all points in ``Y`` together with a reference point and the
    Pareto frontier given by ``Y_pareto``.

    The first column of ``Y`` is drawn on the x-axis (``metric_x``) and the
    second on the y-axis (``metric_y``). Points are colored along the "magma"
    color scale in row order; the colorbar is labelled "iteration". The
    reference point is drawn as a star and the frontier as a line that is
    connected to the reference point through two projection lines.

    NOTE: Both metrics should have the same minimization setting, passed as
    ``minimize``.

    Args:
        Y: Array of outcomes, of which the first two columns are plotted.
        Y_pareto: Array of Pareto-optimal points, of which the first two
            columns are plotted.
        metric_x: Name of the first outcome in ``Y``.
        metric_y: Name of the second outcome in ``Y``.
        reference_point: Reference point for ``metric_x`` and ``metric_y``.
        minimize: Whether the two metrics in the plot are being minimized
            (True) or maximized (False).

    Returns:
        The assembled Plotly figure.
    """
    Xs = Y[:, 0]
    Ys = Y[:, 1]
    ref_x = reference_point[0]
    ref_y = reference_point[1]
    line_color = rgba(COLORS.STEELBLUE.value)

    experimental_points_scatter = go.Scatter(
        x=Xs,
        y=Ys,
        mode="markers",
        marker={
            # NOTE(review): the color array has int(1.05 * len(Xs)) entries,
            # i.e. more than there are points -- presumably so the last points
            # do not reach the very end of the color scale; confirm.
            "color": np.linspace(0, 100, int(len(Xs) * 1.05)),
            "colorscale": "magma",
            "colorbar": {
                "tickvals": [0, 50, 100],
                "ticktext": [
                    1,
                    "iteration",
                    len(Xs),
                ],
            },
        },
        name="Experimental points",
    )
    reference_point_star = go.Scatter(
        x=[ref_x],
        y=[ref_y],
        mode="markers",
        marker={"color": line_color, "size": 25, "symbol": "star"},
    )

    # The frontier is extended to its best value along each axis so that it
    # can be joined to the reference point with axis-parallel lines.
    best = min if minimize else max
    extra_point_x = best(Y_pareto[:, 0])
    extra_point_y = best(Y_pareto[:, 1])
    reference_point_line_1 = go.Scatter(
        x=[extra_point_x, ref_x],
        y=[ref_y, ref_y],
        mode="lines",
        marker={"color": line_color},
    )
    reference_point_line_2 = go.Scatter(
        x=[ref_x, ref_x],
        y=[extra_point_y, ref_y],
        mode="lines",
        marker={"color": line_color},
    )
    Y_pareto_with_extra = np.concatenate(
        (
            [[extra_point_x, ref_y]],
            Y_pareto,
            [[ref_x, extra_point_y]],
        ),
        axis=0,
    )
    pareto_step = go.Scatter(
        x=Y_pareto_with_extra[:, 0],
        y=Y_pareto_with_extra[:, 1],
        mode="lines",
        marker={"color": line_color},
    )

    # Axis ranges span from the reference point to the most extreme
    # non-outlier observation in the direction of improvement.
    Y_no_outliers = _filter_outliers(Y=Y)
    if minimize:
        range_x = extend_range(lower=min(Y_no_outliers[:, 0]), upper=ref_x)
        range_y = extend_range(lower=min(Y_no_outliers[:, 1]), upper=ref_y)
    else:
        range_x = extend_range(lower=ref_x, upper=max(Y_no_outliers[:, 0]))
        range_y = extend_range(lower=ref_y, upper=max(Y_no_outliers[:, 1]))

    layout = go.Layout(
        title="Observed points with Pareto frontier",
        showlegend=False,
        xaxis={"title": metric_x, "range": range_x},
        yaxis={"title": metric_y, "range": range_y},
    )
    return go.Figure(
        layout=layout,
        data=[
            pareto_step,
            reference_point_line_1,
            reference_point_line_2,
            experimental_points_scatter,
            reference_point_star,
        ],
    )
def scatter_plot_with_pareto_frontier(
    Y: np.ndarray,
    Y_pareto: np.ndarray,
    metric_x: str,
    metric_y: str,
    reference_point: Tuple[float, float],
    minimize: bool = True,
) -> AxPlotConfig:
    """Wrap ``scatter_plot_with_pareto_frontier_plotly`` in an ``AxPlotConfig``.

    Args:
        Y: Array of outcomes, of which the first two columns are plotted.
        Y_pareto: Array of Pareto-optimal points, of which the first two
            columns are plotted.
        metric_x: Name of the first outcome in ``Y``.
        metric_y: Name of the second outcome in ``Y``.
        reference_point: Reference point for ``metric_x`` and ``metric_y``.
        minimize: Whether the two metrics in the plot are being minimized
            (True) or maximized (False).

    Returns:
        An ``AxPlotConfig`` of type ``AxPlotTypes.GENERIC`` holding the figure.
    """
    figure = scatter_plot_with_pareto_frontier_plotly(
        Y=Y,
        Y_pareto=Y_pareto,
        metric_x=metric_x,
        metric_y=metric_y,
        reference_point=reference_point,
        # Forward the caller's setting; without this the plot would always be
        # drawn as a minimization problem.
        minimize=minimize,
    )
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def _get_single_pareto_trace(
    frontier: ParetoFrontierResults,
    CI_level: Optional[float],
    legend_label: str = "mean",
    trace_color: Tuple[int, ...] = COLORS.STEELBLUE.value,
    show_parameterization_on_hover: bool = True,
) -> go.Scatter:
    """Build the scatter trace for a single Pareto frontier.

    The secondary metric is plotted on the x-axis and the primary metric on
    the y-axis. Each point gets a hover label listing every metric in
    ``frontier.means`` and, optionally, its parameterization.

    Args:
        frontier: Results of the Pareto frontier computation.
        CI_level: Confidence level used for the error bars and the hover
            intervals. If None, no error bars and no intervals are shown.
        legend_label: Name and legend group of the trace.
        trace_color: Color of the markers; the error bars use the same color
            with ``CI_OPACITY``.
        show_parameterization_on_hover: If True, append the parameterization
            of each point to its hover label.

    Returns:
        The Plotly scatter trace.
    """
    primary_means = frontier.means[frontier.primary_metric]
    primary_sems = frontier.sems[frontier.primary_metric]
    secondary_means = frontier.means[frontier.secondary_metric]
    secondary_sems = frontier.sems[frontier.secondary_metric]
    absolute_metrics = frontier.absolute_metrics
    all_metrics = frontier.means.keys()

    if frontier.arm_names is None:
        arm_names = [
            f"Parameterization {i}" for i in range(len(frontier.param_dicts))
        ]
    else:
        arm_names = [f"Arm {name}" for name in frontier.arm_names]

    if CI_level is not None:
        # NOTE(review): the 0.5 factor halves the usual two-sided z-value;
        # kept as in the original -- confirm this is intended.
        Z = 0.5 * norm.ppf(1 - (1 - CI_level) / 2)
    else:
        Z = None

    labels = []
    for i, param_dict in enumerate(frontier.param_dicts):
        parts = [f"<b>{arm_names[i]}</b><br>"]
        parts.extend(
            _make_label(
                mean=frontier.means[metric][i],
                sem=frontier.sems[metric][i],
                name=metric,
                is_relative=metric not in absolute_metrics,
                Z=Z,
            )
            for metric in all_metrics
        )
        if show_parameterization_on_hover:
            parts.append(_format_dict(param_dict, "Parameterization"))
        labels.append("".join(parts))

    # Error bars can only be computed when a z-value is available; with
    # ``CI_level=None`` they are left out instead of failing on ``None * array``.
    error_bars = {}
    if Z is not None:
        error_bars["error_x"] = {
            "type": "data",
            "array": Z * np.array(secondary_sems),
            "thickness": 2,
            "color": rgba(trace_color, CI_OPACITY),
        }
        error_bars["error_y"] = {
            "type": "data",
            "array": Z * np.array(primary_sems),
            "thickness": 2,
            "color": rgba(trace_color, CI_OPACITY),
        }

    return go.Scatter(
        name=legend_label,
        legendgroup=legend_label,
        x=secondary_means,
        y=primary_means,
        mode="markers",
        text=labels,
        hoverinfo="text",
        marker={"color": rgba(trace_color)},
        **error_bars,
    )
def plot_pareto_frontier(
    frontier: ParetoFrontierResults,
    CI_level: float = DEFAULT_CI_LEVEL,
    show_parameterization_on_hover: bool = True,
) -> AxPlotConfig:
    """Plot a Pareto frontier from a ParetoFrontierResults object.

    The secondary metric is shown on the x-axis and the primary metric on the
    y-axis. If ``frontier.objective_thresholds`` contains a threshold for
    either metric, it is drawn as a line across the plot.

    Args:
        frontier (ParetoFrontierResults): The results of the Pareto frontier
            computation.
        CI_level (float, optional): The confidence level, e.g. 0.95 (95%)
        show_parameterization_on_hover (bool, optional): If True, show the
            parameterization of the points on the frontier on hover.

    Returns:
        AxPlotConfig: The resulting Plotly plot definition.
    """
    trace = _get_single_pareto_trace(
        frontier=frontier,
        CI_level=CI_level,
        show_parameterization_on_hover=show_parameterization_on_hover,
    )

    primary_threshold = None
    secondary_threshold = None
    thresholds = frontier.objective_thresholds
    if thresholds is not None:
        primary_threshold = thresholds.get(frontier.primary_metric)
        secondary_threshold = thresholds.get(frontier.secondary_metric)

    # Metrics not listed as absolute are relative and get a "%" tick suffix.
    absolute_metrics = frontier.absolute_metrics
    rel_x = frontier.secondary_metric not in absolute_metrics
    rel_y = frontier.primary_metric not in absolute_metrics

    shapes = []
    if primary_threshold is not None:
        # Horizontal line: spans the full plot width at the threshold's y.
        shapes.append(
            {
                "type": "line",
                "xref": "paper",
                "x0": 0.0,
                "x1": 1.0,
                "yref": "y",
                "y0": primary_threshold,
                "y1": primary_threshold,
                "line": {"color": rgba(COLORS.CORAL.value), "width": 3},
            }
        )
    if secondary_threshold is not None:
        # Vertical line: spans the full plot height at the threshold's x.
        shapes.append(
            {
                "type": "line",
                "yref": "paper",
                "y0": 0.0,
                "y1": 1.0,
                "xref": "x",
                "x0": secondary_threshold,
                "x1": secondary_threshold,
                "line": {"color": rgba(COLORS.CORAL.value), "width": 3},
            }
        )

    layout = go.Layout(
        title="Pareto Frontier",
        xaxis={
            "title": frontier.secondary_metric,
            "ticksuffix": "%" if rel_x else "",
            "zeroline": True,
        },
        yaxis={
            "title": frontier.primary_metric,
            "ticksuffix": "%" if rel_y else "",
            "zeroline": True,
        },
        hovermode="closest",
        legend={"orientation": "h"},
        width=750,
        height=500,
        margin=go.layout.Margin(pad=4, l=225, b=75, t=75),  # noqa E741
        shapes=shapes,
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_multiple_pareto_frontiers(
    frontiers: Dict[str, ParetoFrontierResults],
    CI_level: float = DEFAULT_CI_LEVEL,
    show_parameterization_on_hover: bool = True,
) -> AxPlotConfig:
    """Plot several Pareto frontiers, one trace per entry, in a single figure.

    Each frontier is drawn with its own color from ``DISCRETE_COLOR_SCALE``
    (colors repeat when there are more frontiers than colors) and labelled
    with its dictionary key. Axis titles, tick suffixes and threshold lines
    are taken from the last frontier in iteration order.

    Args:
        frontiers (Dict[str, ParetoFrontierResults]): The results of the
            Pareto frontier computations, keyed by the label to show in the
            legend.
        CI_level (float, optional): The confidence level, e.g. 0.95 (95%)
        show_parameterization_on_hover (bool, optional): If True, show the
            parameterization of the points on the frontier on hover.

    Returns:
        AxPlotConfig: The resulting Plotly plot definition.

    Raises:
        ValueError: If ``frontiers`` is empty.
    """
    if not frontiers:
        # The layout below is derived from one of the frontiers, so at least
        # one is required.
        raise ValueError("Must receive a non-empty dict of pareto frontiers to plot.")

    traces = []
    for i, (method, frontier) in enumerate(frontiers.items()):
        trace = _get_single_pareto_trace(
            frontier=frontier,
            legend_label=method,
            trace_color=DISCRETE_COLOR_SCALE[i % len(DISCRETE_COLOR_SCALE)],
            CI_level=CI_level,
            show_parameterization_on_hover=show_parameterization_on_hover,
        )
        traces.append(trace)

    # The layout is taken from the last frontier in iteration order.
    frontier = list(frontiers.values())[-1]

    primary_threshold = None
    secondary_threshold = None
    thresholds = frontier.objective_thresholds
    if thresholds is not None:
        primary_threshold = thresholds.get(frontier.primary_metric)
        secondary_threshold = thresholds.get(frontier.secondary_metric)

    # Metrics not listed as absolute are relative and get a "%" tick suffix.
    absolute_metrics = frontier.absolute_metrics
    rel_x = frontier.secondary_metric not in absolute_metrics
    rel_y = frontier.primary_metric not in absolute_metrics

    shapes = []
    if primary_threshold is not None:
        # Horizontal line: spans the full plot width at the threshold's y.
        shapes.append(
            {
                "type": "line",
                "xref": "paper",
                "x0": 0.0,
                "x1": 1.0,
                "yref": "y",
                "y0": primary_threshold,
                "y1": primary_threshold,
                "line": {"color": rgba(COLORS.CORAL.value), "width": 3},
            }
        )
    if secondary_threshold is not None:
        # Vertical line: spans the full plot height at the threshold's x.
        shapes.append(
            {
                "type": "line",
                "yref": "paper",
                "y0": 0.0,
                "y1": 1.0,
                "xref": "x",
                "x0": secondary_threshold,
                "x1": secondary_threshold,
                "line": {"color": rgba(COLORS.CORAL.value), "width": 3},
            }
        )

    layout = go.Layout(
        title="Pareto Frontier",
        xaxis={
            "title": frontier.secondary_metric,
            "ticksuffix": "%" if rel_x else "",
            "zeroline": True,
        },
        yaxis={
            "title": frontier.primary_metric,
            "ticksuffix": "%" if rel_y else "",
            "zeroline": True,
        },
        hovermode="closest",
        legend={"orientation": "h"},
        width=750,
        height=500,
        margin=go.layout.Margin(pad=4, l=225, b=75, t=75),  # noqa E741
        shapes=shapes,
    )
    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_pareto_frontier(
    frontier_list: List[ParetoFrontierResults],
    CI_level: float = DEFAULT_CI_LEVEL,
    show_parameterization_on_hover: bool = True,
) -> AxPlotConfig:
    """Plot several Pareto frontiers in one figure with a dropdown to switch
    between them.

    Each frontier is rendered with ``plot_pareto_frontier``; its trace and
    threshold shapes are collected, and a dropdown entry labelled
    "<primary metric> vs <secondary metric>" makes that trace the only visible
    one and updates the axis titles, tick suffixes and shapes accordingly.
    The first frontier is the one shown initially.

    Args:
        frontier_list: The Pareto frontiers to plot.
        CI_level: The confidence level, e.g. 0.95 (95%).
        show_parameterization_on_hover: If True, show the parameterization of
            the points on the frontier on hover.

    Returns:
        AxPlotConfig: The resulting Plotly plot definition.

    Raises:
        ValueError: If ``frontier_list`` is empty.
    """
    if not frontier_list:
        raise ValueError("Must receive a non-empty list of pareto frontiers to plot.")

    traces = []
    shapes = []
    for frontier in frontier_list:
        config = plot_pareto_frontier(
            frontier=frontier,
            CI_level=CI_level,
            show_parameterization_on_hover=show_parameterization_on_hover,
        )
        traces.append(config.data["data"][0])
        shapes.append(config.data["layout"].get("shapes", []))

    # Only the first trace is visible initially.
    for i, trace in enumerate(traces):
        trace["visible"] = i == 0

    # TODO (jej): replace dropdown with two dropdowns, one for x one for y.
    dropdown = []
    num_frontiers = len(frontier_list)
    for i, frontier in enumerate(frontier_list):
        # Each frontier contributes exactly one trace, and only one trace is
        # visible at a given time.
        visible = [j == i for j in range(num_frontiers)]
        rel_y = frontier.primary_metric not in frontier.absolute_metrics
        rel_x = frontier.secondary_metric not in frontier.absolute_metrics
        primary_metric = frontier.primary_metric
        secondary_metric = frontier.secondary_metric
        dropdown.append(
            {
                "method": "update",
                "args": [
                    {"visible": visible, "method": "restyle"},
                    {
                        "yaxis.title": primary_metric,
                        "xaxis.title": secondary_metric,
                        "yaxis.ticksuffix": "%" if rel_y else "",
                        "xaxis.ticksuffix": "%" if rel_x else "",
                        "shapes": shapes[i],
                    },
                ],
                "label": f"{primary_metric} vs {secondary_metric}",
            }
        )

    # Set initial layout arguments.
    initial_frontier = frontier_list[0]
    rel_x = initial_frontier.secondary_metric not in initial_frontier.absolute_metrics
    rel_y = initial_frontier.primary_metric not in initial_frontier.absolute_metrics
    secondary_metric = initial_frontier.secondary_metric
    primary_metric = initial_frontier.primary_metric
    layout = go.Layout(
        title="Pareto Frontier",
        xaxis={
            "title": secondary_metric,
            "ticksuffix": "%" if rel_x else "",
            "zeroline": True,
        },
        yaxis={
            "title": primary_metric,
            "ticksuffix": "%" if rel_y else "",
            "zeroline": True,
        },
        updatemenus=[
            {
                "buttons": dropdown,
                "x": 0.075,
                "xanchor": "left",
                "y": 1.1,
                "yanchor": "middle",
            }
        ],
        hovermode="closest",
        legend={"orientation": "h"},
        width=750,
        height=500,
        margin=go.layout.Margin(pad=4, l=225, b=75, t=75),  # noqa E741
        shapes=shapes[0],
    )
    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)