#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import plotly.graph_objs as go
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.color import COLORS, DISCRETE_COLOR_SCALE, rgba
# Type aliases.
# A list of trace specifications given as plain dicts (presumably plotly trace
# kwargs -- NOTE(review): not referenced in the visible part of this file).
Traces = List[Dict[str, Any]]
def mean_trace_scatter(
    y: np.ndarray,
    trace_color: Tuple[int, ...] = COLORS.STEELBLUE.value,
    legend_label: str = "mean",
) -> go.Scatter:
    """Creates a graph object for trace of the mean of the given series across
    runs.

    The trace uses ``fill="tonexty"``, so plotly shades the area between it and
    the trace that precedes it in the figure's data list. The plotting
    functions in this module place the lower SEM bound directly before it, so
    the band below the mean gets filled.

    Args:
        y: (r x t) array with results from r runs and t trials.
        trace_color: tuple of 3 int values representing an RGB color.
            Defaults to blue.
        legend_label: label for this trace; also used as its legend group.

    Returns:
        go.Scatter: plotly graph object
    """
    return go.Scatter(  # pyre-ignore[16]: `plotly.graph_objs` has no attr. `Scatter`
        name=legend_label,
        legendgroup=legend_label,
        # Trials are numbered starting from 1 on the x axis.
        x=np.arange(1, y.shape[1] + 1),
        # Average over runs (axis 0) to get one value per trial.
        y=np.mean(y, axis=0),
        mode="lines",
        line={"color": rgba(trace_color)},
        fillcolor=rgba(trace_color, 0.3),
        fill="tonexty",
    )
def sem_range_scatter(
    y: np.ndarray,
    trace_color: Tuple[int, ...] = COLORS.STEELBLUE.value,
    legend_label: str = "",
) -> Tuple[go.Scatter, go.Scatter]:
    """Creates graph objects for the band of mean +/- 2 SEMs of y, across runs.

    The SEM is computed as ``np.std(y, axis=0) / sqrt(r)``, i.e. with numpy's
    default population standard deviation (``ddof=0``).

    Both bounds are drawn as zero-width lines that are hidden from the legend
    and from hover. Only the upper bound has ``fill="tonexty"``, which shades
    the area between it and the trace preceding it in the figure's data list.

    Args:
        y: (r x t) array with results from r runs and t trials.
        trace_color: tuple of 3 int values representing an RGB color.
            Defaults to blue.
        legend_label: Label for the legend group.

    Returns:
        Tuple[go.Scatter, go.Scatter]: plotly graph objects for the lower and
        upper bounds, in that order.
    """
    mean = np.mean(y, axis=0)
    sem = np.std(y, axis=0) / np.sqrt(y.shape[0])
    # Trials are numbered starting from 1; shared by both bounds.
    x = np.arange(1, y.shape[1] + 1)
    lower = go.Scatter(  # pyre-ignore[16]: `plotly.graph_objs` has no attr. `Scatter`
        x=x,
        y=mean - 2 * sem,
        legendgroup=legend_label,
        mode="lines",
        line={"width": 0},
        showlegend=False,
        hoverinfo="none",
    )
    upper = go.Scatter(  # pyre-ignore[16]: `plotly.graph_objs` has no attr. `Scatter`
        x=x,
        y=mean + 2 * sem,
        legendgroup=legend_label,
        mode="lines",
        line={"width": 0},
        fillcolor=rgba(trace_color, 0.3),
        fill="tonexty",
        showlegend=False,
        hoverinfo="none",
    )
    return lower, upper
def optimum_objective_scatter(
    optimum: float,
    num_iterations: int,
    optimum_color: Tuple[int, ...] = COLORS.ORANGE.value,
) -> go.Scatter:
    """Creates a graph object for the line representing optimal objective.

    The line is a dashed horizontal segment at ``optimum`` that spans x from 1
    to ``num_iterations``.

    Args:
        optimum: value of the optimal objective
        num_iterations: how many trials were in the optimization (used to
            determine the width of the plot)
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to orange.

    Returns:
        go.Scatter: plotly graph object for the optimal objective line
    """
    return go.Scatter(  # pyre-ignore[16]: `plotly.graph_objs` has no attr. `Scatter`
        x=[1, num_iterations],
        y=[optimum, optimum],
        mode="lines",
        line={"dash": "dash", "color": rgba(optimum_color)},
        name="Optimum",
    )
def generator_changes_scatter(
    generator_changes: List[int],
    y_range: List[float],
    generator_change_color: Tuple[int, ...] = COLORS.TEAL.value,
) -> List[go.Scatter]:
    """Creates graph objects for the line(s) representing generator changes.

    Each change is drawn as a dashed vertical line at x equal to that
    iteration, spanning the given y range.

    Args:
        generator_changes: iterations, before which generators
            changed
        y_range: lower and upper values of the y-range of the lines
        generator_change_color: tuple of 3 int values representing
            an RGB color. Defaults to teal.

    Returns:
        List[go.Scatter]: one plotly graph object per generator change (empty
        if there are no changes).

    Raises:
        ValueError: if ``y_range`` does not contain exactly two values.
    """
    if len(y_range) != 2:  # pragma: no cover
        raise ValueError("y_range should have two values, lower and upper.")
    return [
        go.Scatter(  # pyre-ignore[16]: `plotly.graph_objs` has no attr. `Scatter`
            x=[change, change],
            y=y_range,
            mode="lines",
            line={"dash": "dash", "color": rgba(generator_change_color)},
            name="Generator change",
        )
        for change in generator_changes
    ]
def optimization_trace_single_method(
    y: np.ndarray,
    optimum: Optional[float] = None,
    generator_changes: Optional[List[int]] = None,
    title: str = "",
    ylabel: str = "",
    trace_color: Tuple[int, ...] = COLORS.STEELBLUE.value,
    optimum_color: Tuple[int, ...] = COLORS.ORANGE.value,
    generator_change_color: Tuple[int, ...] = COLORS.TEAL.value,
) -> AxPlotConfig:
    """Plots an optimization trace with its mean and a band of +/- 2 SEMs.

    Args:
        y: (r x t) array; result to plot, with r runs and t trials
        optimum: value of the optimal objective. If given, a dashed horizontal
            line is drawn at this value.
        generator_changes: iterations, before which generators
            changed. If given, a dashed vertical line is drawn at each one.
        title: title of this plot
        ylabel: Label for y axis
        trace_color: tuple of 3 int values representing an RGB color.
            Defaults to blue.
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to orange.
        generator_change_color: tuple of 3 int values representing
            an RGB color. Defaults to teal.

    Returns:
        AxPlotConfig: plot of the optimization trace with a 2-SEM band
    """
    trace = mean_trace_scatter(y=y, trace_color=trace_color)
    # pyre-fixme[23]: Unable to unpack single value, 2 were expected.
    lower, upper = sem_range_scatter(y=y, trace_color=trace_color)
    layout = go.Layout(  # pyre-ignore[16]: ...graph_objs` has no attr. `Layout`
        title=title,
        showlegend=True,
        yaxis={"title": ylabel},
        xaxis={"title": "Iteration"},
    )
    # Order matters: the mean trace and the upper bound both use
    # fill="tonexty", so each one is shaded down to the trace before it.
    data = [lower, trace, upper]
    if optimum is not None:
        data.append(
            optimum_objective_scatter(
                optimum=optimum, num_iterations=y.shape[1], optimum_color=optimum_color
            )
        )
    if generator_changes is not None:  # pragma: no cover
        # Vertical extent of the generator-change lines: from the lowest
        # per-trial 25th percentile to the highest per-trial 75th percentile
        # across runs, widened to include the optimum when one is given.
        y_lower = np.min(np.percentile(y, 25, axis=0))
        y_upper = np.max(np.percentile(y, 75, axis=0))
        if optimum is not None:
            y_lower = min(y_lower, optimum)
            y_upper = max(y_upper, optimum)
        data.extend(
            generator_changes_scatter(
                generator_changes=generator_changes,
                y_range=[y_lower, y_upper],
                generator_change_color=generator_change_color,
            )
        )
    return AxPlotConfig(
        # pyre-ignore[16]: ...graph_objs` has no attr. `Figure`
        data=go.Figure(layout=layout, data=data),
        plot_type=AxPlotTypes.GENERIC,
    )
def optimization_trace_all_methods(
    y_dict: Dict[str, np.ndarray],
    optimum: Optional[float] = None,
    title: str = "",
    ylabel: str = "",
    trace_colors: List[Tuple[int, ...]] = DISCRETE_COLOR_SCALE,
    optimum_color: Tuple[int, ...] = COLORS.ORANGE.value,
) -> AxPlotConfig:
    """Plots a comparison of optimization traces with 2-SEM bands for multiple
    methods on the same problem.

    Args:
        y_dict: a mapping of method names to (r x t) arrays, where r is the
            number of runs in the test, and t is the number of trials.
        optimum: value of the optimal objective. If given, a dashed horizontal
            line spanning the longest trace is drawn at this value.
        title: Title for this plot.
        ylabel: Label for y axis
        trace_colors: tuples of 3 int values representing
            RGB colors to use for different methods shown in the combination
            plot. Colors are reused cyclically if there are more methods than
            colors; must be non-empty when ``y_dict`` is non-empty.
            Defaults to Ax discrete color scale.
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to orange.

    Returns:
        AxPlotConfig: plot of the comparison of optimization traces with 2-SEM
        bands

    Raises:
        ZeroDivisionError: if ``trace_colors`` is empty and ``y_dict`` is not.
        ValueError: if ``y_dict`` is empty and ``optimum`` is given (the width
            of the optimum line cannot be determined).
    """
    data: List[go.Scatter] = []
    for i, (method, y) in enumerate(y_dict.items()):
        # If there are more traces than colors, start reusing colors.
        color = trace_colors[i % len(trace_colors)]
        # The mean trace and both SEM bounds share the legend group `method`.
        trace = mean_trace_scatter(y=y, trace_color=color, legend_label=method)
        # pyre-fixme[23]: Unable to unpack single value, 2 were expected.
        lower, upper = sem_range_scatter(y=y, trace_color=color, legend_label=method)
        # Order matters: each fill="tonexty" trace is shaded to the one before it.
        data.extend([lower, trace, upper])
    if optimum is not None:
        # The optimum line spans the longest trace among all methods.
        num_iterations = max(arr.shape[1] for arr in y_dict.values())
        data.append(
            optimum_objective_scatter(
                optimum=optimum,
                num_iterations=num_iterations,
                optimum_color=optimum_color,
            )
        )
    layout = go.Layout(  # pyre-ignore[16]: ...graph_objs` has no attr. `Layout`
        title=title,
        showlegend=True,
        yaxis={"title": ylabel},
        xaxis={"title": "Iteration"},
    )
    return AxPlotConfig(
        # pyre-ignore[16]: ...graph_objs` has no attr. `Figure`
        data=go.Figure(layout=layout, data=data),
        plot_type=AxPlotTypes.GENERIC,
    )
def _mean_and_two_sems(values: Union[List[float], np.ndarray]) -> Tuple[float, float]:
    """Returns the mean of ``values`` and twice its standard error of the mean.

    The standard error is ``np.std(values) / sqrt(len(values))``, i.e. it uses
    numpy's default population standard deviation (``ddof=0``).
    """
    arr = np.asarray(values)
    return np.mean(arr), 2 * np.std(arr) / np.sqrt(len(arr))


def optimization_times(
    fit_times: Dict[str, List[float]],
    gen_times: Dict[str, List[float]],
    title: str = "",
) -> AxPlotConfig:
    """Plots wall times for each method as a bar chart.

    Three bar series are drawn: "Fitting", "Generation" and "Total". "Total" is
    the element-wise sum of fit and gen times. Each series shows the mean per
    method, with error bars of +/- 2 SEMs.

    Args:
        fit_times: A map from method name to a list of the model fitting times.
            Its keys determine which methods are plotted, and in what order.
        gen_times: A map from method name to a list of the gen times. It is
            expected to contain every method in ``fit_times``, with lists of the
            same length as the corresponding fitting times.
        title: Title for this plot.

    Returns:
        AxPlotConfig with the plot
    """
    methods = list(fit_times.keys())
    # (series name, [(mean, 2 * SEM) per method]) for each bar series.
    series = [
        ("Fitting", [_mean_and_two_sems(fit_times[m]) for m in methods]),
        ("Generation", [_mean_and_two_sems(gen_times[m]) for m in methods]),
        (
            "Total",
            [
                _mean_and_two_sems(np.array(fit_times[m]) + np.array(gen_times[m]))
                for m in methods
            ],
        ),
    ]
    # Construct plot
    data: List[go.Bar] = []
    for i, (name, stats) in enumerate(series):
        means = [mean for mean, _ in stats]
        two_sems = [two_sem for _, two_sem in stats]
        data.append(
            go.Bar(  # pyre-ignore[16]: ...graph_objs` has no attr. `Bar`
                x=methods,
                y=means,
                text=name,
                textposition="auto",
                error_y={"type": "data", "array": two_sems, "visible": True},
                marker={
                    # One color per series, taken from the discrete color scale.
                    "color": rgba(DISCRETE_COLOR_SCALE[i]),
                    "line": {"color": "rgb(0,0,0)", "width": 1.0},
                },
                opacity=0.6,
                name=name,
            )
        )
    layout = go.Layout(  # pyre-ignore[16]: ...graph_objs` has no attr. `Layout`
        title=title,
        showlegend=False,
        yaxis={"title": "Time"},
        xaxis={"title": "Method"},
    )
    return AxPlotConfig(
        # pyre-ignore[16]: ...graph_objs` has no attr. `Figure`
        data=go.Figure(layout=layout, data=data),
        plot_type=AxPlotTypes.GENERIC,
    )