#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import numpy as np
import plotly.graph_objs as go
from ax.plot.base import CI_OPACITY, DECIMALS, AxPlotConfig, AxPlotTypes
from ax.plot.helper import _format_CI, _format_dict
from ax.plot.pareto_utils import COLORS, ParetoFrontierResults, rgba
from scipy.stats import norm
# Confidence level used as the default ``CI_level`` argument of the plotting
# functions in this module (0.9 == 90%).
DEFAULT_CI_LEVEL: float = 0.9
def _make_label(
    mean: float, sem: float, name: str, is_relative: bool, Z: Optional[float]
) -> str:
    """Build one HTML hover-label line for a single metric value.

    Args:
        mean: Point estimate of the metric; shown rounded to ``DECIMALS``.
        sem: Standard error of the estimate, forwarded to ``_format_CI``.
        name: Metric name placed at the start of the line.
        is_relative: If True, a "%" suffix is appended to the estimate and
            ``_format_CI`` is asked for a relative interval.
        Z: Multiplier forwarded to ``_format_CI`` as ``zval``. If None, the
            interval portion of the label is left empty.

    Returns:
        A string of the form ``"<name>: <estimate><%> <ci><br>"``.
    """
    if Z is None:
        interval_text = ""
    else:
        interval_text = _format_CI(
            estimate=mean, sd=sem, relative=is_relative, zval=Z
        )

    if is_relative:
        percent_suffix = "%"
    else:
        percent_suffix = ""

    rounded_mean = round(mean, DECIMALS)
    return "{name}: {estimate}{perc} {ci}<br>".format(
        name=name,
        estimate=rounded_mean,
        perc=percent_suffix,
        ci=interval_text,
    )
def plot_pareto_frontier(
    frontier: ParetoFrontierResults,
    CI_level: Optional[float] = DEFAULT_CI_LEVEL,
    show_parameterization_on_hover: bool = True,
) -> AxPlotConfig:
    """Plot a Pareto frontier from a ParetoFrontierResults object.

    The secondary metric is drawn on the x-axis and the primary metric on the
    y-axis. Metrics that are not listed in ``frontier.absolute_metrics`` are
    displayed with a "%" suffix.

    Args:
        frontier (ParetoFrontierResults): The results of the Pareto frontier
            computation.
        CI_level (float, optional): The confidence level, e.g. 0.95 (95%).
            If None, no confidence intervals are shown: error bars are
            omitted from the plot and intervals are left out of the hover
            labels.
        show_parameterization_on_hover (bool, optional): If True, show the
            parameterization of the points on the frontier on hover.

    Returns:
        AxPlotConfig: The resulting Plotly plot definition.
    """
    primary_metric = frontier.primary_metric
    secondary_metric = frontier.secondary_metric
    primary_means = frontier.means[primary_metric]
    primary_sems = frontier.sems[primary_metric]
    secondary_means = frontier.means[secondary_metric]
    secondary_sems = frontier.sems[secondary_metric]
    absolute_metrics = frontier.absolute_metrics

    # Two-sided normal quantile: mean +/- Z * sem spans the CI_level interval.
    # Both the hover labels and the Plotly error bars use Z * sem as the
    # half-width, so the quantile must not be scaled down any further.
    Z = None if CI_level is None else norm.ppf(1 - (1 - CI_level) / 2)

    rel_x = secondary_metric not in absolute_metrics
    rel_y = primary_metric not in absolute_metrics

    labels = []
    for i, param_dict in enumerate(frontier.param_dicts):
        heading = "<b>Parameterization {}</b><br>".format(i + 1)
        x_lab = _make_label(
            mean=secondary_means[i],
            sem=secondary_sems[i],
            name=secondary_metric,
            is_relative=rel_x,
            Z=Z,
        )
        y_lab = _make_label(
            mean=primary_means[i],
            sem=primary_sems[i],
            name=primary_metric,
            is_relative=rel_y,
            Z=Z,
        )
        parameterization = (
            _format_dict(param_dict, "Parameterization")
            if show_parameterization_on_hover
            else ""
        )
        labels.append(
            "{heading}<br>{xlab}{ylab}{param_blob}".format(
                heading=heading, xlab=x_lab, ylab=y_lab, param_blob=parameterization
            )
        )

    # Error bars only make sense when a confidence level was requested;
    # multiplying the SEMs by a None Z would raise a TypeError.
    error_bars = {}
    if Z is not None:
        ci_color = rgba(COLORS.STEELBLUE.value, CI_OPACITY)
        error_bars["error_x"] = {
            "type": "data",
            "array": Z * np.array(secondary_sems),
            "thickness": 2,
            "color": ci_color,
        }
        error_bars["error_y"] = {
            "type": "data",
            "array": Z * np.array(primary_sems),
            "thickness": 2,
            "color": ci_color,
        }

    traces = [
        go.Scatter(
            x=secondary_means,
            y=primary_means,
            mode="markers",
            text=labels,
            hoverinfo="text",
            **error_bars,
        )
    ]
    layout = go.Layout(
        title="Pareto Frontier",
        xaxis={
            "title": secondary_metric,
            "ticksuffix": "%" if rel_x else "",
            "zeroline": True,
        },
        yaxis={
            "title": primary_metric,
            "ticksuffix": "%" if rel_y else "",
            "zeroline": True,
        },
        hovermode="closest",
        legend={"orientation": "h"},
        width=750,
        height=500,
        margin=go.layout.Margin(pad=4, l=225, b=75, t=75),  # noqa E741
    )
    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_pareto_frontier(
    frontier_list: List[ParetoFrontierResults],
    CI_level: float = DEFAULT_CI_LEVEL,
    show_parameterization_on_hover: bool = True,
) -> AxPlotConfig:
    """Plot several Pareto frontiers in one figure, switchable via a dropdown.

    One scatter trace is produced per frontier by calling
    ``plot_pareto_frontier``. Only the first frontier's trace is visible
    initially; each dropdown entry shows exactly one trace and updates the
    axis titles and tick suffixes to match that frontier's metrics.

    Args:
        frontier_list: The Pareto frontiers to plot. Must be non-empty.
        CI_level: The confidence level forwarded to ``plot_pareto_frontier``.
        show_parameterization_on_hover: Forwarded to
            ``plot_pareto_frontier``.

    Returns:
        AxPlotConfig: The resulting Plotly plot definition.

    Raises:
        ValueError: If ``frontier_list`` is empty.
    """
    if not frontier_list:
        raise ValueError("Must receive a non-empty list of pareto frontiers to plot.")

    num_frontiers = len(frontier_list)

    # Build one trace per frontier; only the first one starts out visible.
    traces = []
    for idx, frontier in enumerate(frontier_list):
        single_plot = plot_pareto_frontier(
            frontier=frontier,
            CI_level=CI_level,
            show_parameterization_on_hover=show_parameterization_on_hover,
        )
        trace = single_plot.data["data"][0]
        trace["visible"] = idx == 0
        traces.append(trace)

    # TODO (jej): replace dropdown with two dropdowns, one for x one for y.
    dropdown = []
    for idx, frontier in enumerate(frontier_list):
        y_metric = frontier.primary_metric
        x_metric = frontier.secondary_metric
        y_is_relative = y_metric not in frontier.absolute_metrics
        x_is_relative = x_metric not in frontier.absolute_metrics
        # Each frontier contributes exactly one trace, so the visibility mask
        # is True only at this frontier's position.
        visibility_mask = [pos == idx for pos in range(num_frontiers)]
        trace_update = {"visible": visibility_mask, "method": "restyle"}
        layout_update = {
            "yaxis.title": y_metric,
            "xaxis.title": x_metric,
            "yaxis.ticksuffix": "%" if y_is_relative else "",
            "xaxis.ticksuffix": "%" if x_is_relative else "",
        }
        dropdown.append(
            {
                "method": "update",
                "args": [trace_update, layout_update],
                "label": f"{y_metric} vs {x_metric}",
            }
        )

    # The initial axes describe the first frontier, whose trace is visible.
    first = frontier_list[0]
    first_x_metric = first.secondary_metric
    first_y_metric = first.primary_metric
    x_suffix = "%" if first_x_metric not in first.absolute_metrics else ""
    y_suffix = "%" if first_y_metric not in first.absolute_metrics else ""

    x_axis = {
        "title": first_x_metric,
        "ticksuffix": x_suffix,
        "zeroline": True,
    }
    y_axis = {
        "title": first_y_metric,
        "ticksuffix": y_suffix,
        "zeroline": True,
    }
    menu = {
        "buttons": dropdown,
        "x": 0.075,
        "xanchor": "left",
        "y": 1.1,
        "yanchor": "middle",
    }
    layout = go.Layout(
        title="Pareto Frontier",
        xaxis=x_axis,
        yaxis=y_axis,
        updatemenus=[menu],
        hovermode="closest",
        legend={"orientation": "h"},
        width=750,
        height=500,
        margin=go.layout.Margin(pad=4, l=225, b=75, t=75),  # noqa E741
    )
    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)