#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import re
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
import numpy as np
import plotly.graph_objs as go
from ax.core.observation import ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData
from ax.plot.color import BLUE_SCALE, GREEN_PINK_SCALE, GREEN_SCALE
from ax.plot.helper import (
TNullableGeneratorRunsDict,
axis_range,
contour_config_to_trace,
get_fixed_values,
get_grid_for_parameter,
get_plot_data,
get_range_parameter,
get_range_parameters,
relativize_data,
rgb,
)
# type aliases
# Return type of _get_contour_predictions, in order: arm plot data, predicted
# means, predicted standard deviations, x-grid, y-grid, and a mapping from
# "x"/"y" to whether that axis parameter is log-scaled.
ContourPredictions = Tuple[
    PlotData,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    Dict[str, bool],
]
MAX_PARAM_LENGTH = 40


def short_name(param_name: str) -> str:
    """Abbreviate a parameter name so it fits in a plot axis label.

    Names of at most ``MAX_PARAM_LENGTH`` characters are returned unchanged.
    Longer names are shortened to exactly ``MAX_PARAM_LENGTH`` characters:
    a prefix (the text before the first whitespace, underscore or colon,
    capped at 10 characters), then ``"..."``, then as much of the end of the
    name as still fits.

    Args:
        param_name: Full parameter name.

    Returns:
        ``param_name`` itself, or its abbreviated form.
    """
    # `<=`: a name that already fits must not be rewritten into an equally
    # long but less informative abbreviation.
    if len(param_name) <= MAX_PARAM_LENGTH:
        return param_name
    # Try to find a canonical prefix (the leading "word" of the name).
    prefix = re.split(r"\s|_|:", param_name)[0]
    if len(prefix) > 10:
        prefix = param_name[:10]
    # Fill the remaining budget (after prefix and ellipsis) with the tail.
    # prefix is at most 10 chars, so suffix_length is always positive.
    suffix_length = MAX_PARAM_LENGTH - len(prefix) - len("...")
    return prefix + "..." + param_name[-suffix_length:]
def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> ContourPredictions:
    """Predict ``metric`` on a 2-d grid spanned by two range parameters.

    Args:
        model: ModelBridge used for predictions.
        x_param_name: Name of the range parameter varied along the x-axis.
        y_param_name: Name of the range parameter varied along the y-axis.
        metric: Name of the metric to predict.
        generator_runs_dict: Optional {name: generator run} mapping whose arms
            are included in the returned plot data.
        density: Number of grid points requested along each axis.
        slice_values: A dictionary {param_name: value} for the parameters that
            are being sliced on. Ignored when ``fixed_features`` is given; its
            ``parameters`` are used instead.
        fixed_features: Optional ObservationFeatures copied into every grid
            prediction and passed on to ``get_plot_data``.

    Returns:
        ``(plot_data, f, sd, grid_x, grid_y, scales)``. ``f`` and ``sd`` are
        the predicted mean and the square root of the predicted variance at
        each grid point, ordered as the row-major flattening of
        ``np.meshgrid(grid_x, grid_y)`` (x varies fastest). ``scales`` maps
        "x"/"y" to whether that parameter is log-scaled.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    # Must use the caller's fixed_features (possibly None) before it is
    # replaced with an empty default below.
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric}, fixed_features=fixed_features
    )

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)

    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values)

    # One ObservationFeatures per grid point. Iterating the flattened mesh
    # directly covers exactly the grid that was built, without assuming it
    # has density ** 2 points.
    param_grid_obsf = []
    for x_value, y_value in zip(mesh_x.ravel(), mesh_y.ravel()):
        # Independent copy of fixed_features (including any non-parameter
        # fields) for each point.
        predf = deepcopy(fixed_features)
        predf.parameters = fixed_values.copy()
        predf.parameters[x_param_name] = x_value
        predf.parameters[y_param_name] = y_value
        param_grid_obsf.append(predf)

    mu, cov = model.predict(param_grid_obsf)
    f_plt = mu[metric]
    sd_plt = np.sqrt(cov[metric][metric])
    return plot_data, f_plt, sd_plt, grid_x, grid_y, scales
def plot_contour(
    model: ModelBridge,
    param_x: str,
    param_y: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Plot predictions for a 2-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_x: Name of parameter that will be sliced on x-axis
        param_y: Name of parameter that will be sliced on y-axis
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored when
            ``fixed_features`` is given.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        AxPlotConfig wrapping a plotly figure with the predicted mean (left
        panel) and standard error (right panel) over the (param_x, param_y)
        grid.

    Raises:
        ValueError: If ``param_x`` and ``param_y`` are the same parameter.
    """
    if param_x == param_y:
        raise ValueError("Please select different parameters for x- and y-dimensions.")

    data, f_plt, sd_plt, grid_x, grid_y, scales = _get_contour_predictions(
        model=model,
        x_param_name=param_x,
        y_param_name=param_y,
        metric=metric_name,
        generator_runs_dict=generator_runs_dict,
        density=density,
        slice_values=slice_values,
        # Forwarded so the documented fixed_features argument takes effect.
        fixed_features=fixed_features,
    )
    config = {
        "arm_data": data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f": f_plt,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_x": grid_x,
        "grid_y": grid_y,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd": sd_plt,
        "xvar": param_x,
        "yvar": param_y,
        "x_is_log": scales["x"],
        "y_is_log": scales["y"],
    }
    # Values are read back from AxPlotConfig(...).data rather than the local
    # dict; presumably AxPlotConfig normalizes them (e.g. numpy arrays) into
    # the form contour_config_to_trace expects -- confirm in ax.plot.base.
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data
    traces = contour_config_to_trace(config)

    grid_x = config["grid_x"]
    grid_y = config["grid_y"]
    xvar = config["xvar"]
    yvar = config["yvar"]
    x_is_log = config["x_is_log"]
    y_is_log = config["y_is_log"]

    xrange = axis_range(grid_x, x_is_log)
    yrange = axis_range(grid_y, y_is_log)
    xtype = "log" if x_is_log else "linear"
    ytype = "log" if y_is_log else "linear"

    layout = {
        "annotations": [
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Mean",
                "x": 0.25,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Standard Error",
                "x": 0.8,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
        ],
        "autosize": False,
        "height": 450,
        "hovermode": "closest",
        "legend": {"orientation": "h", "x": 0, "y": -0.25},
        "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35},
        "width": 950,
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "domain": [0.05, 0.45],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            # Abbreviated like the axis titles in interact_contour.
            "title": short_name(xvar),
            "type": xtype,
        },
        "xaxis2": {
            "anchor": "y2",
            "autorange": False,
            "domain": [0.6, 1],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(yvar),
            "type": ytype,
        },
        "yaxis2": {
            "anchor": "x2",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "type": ytype,
        },
    }

    fig = go.Figure(data=traces, layout=layout)  # pyre-ignore[16]
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_contour(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 2-d slice of the parameter
    space.

    Predictions are computed for every ordered pair of range parameters, and
    "x-param"/"y-param" dropdown menus choose which pair is displayed. The
    left panel shows the predicted mean, the right panel the standard error,
    and in-sample and candidate (out-of-sample) arms are overlaid as markers.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        AxPlotConfig wrapping the interactive plotly figure.
    """
    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]
    is_log_dict: Dict[str, bool] = {}
    grid_dict: Dict[str, np.ndarray] = {}
    for parameter in range_parameters:
        is_log_dict[parameter.name] = parameter.log_scale
        grid_dict[parameter.name] = get_grid_for_parameter(parameter, density)

    # Mean / sd predictions for every ordered (x, y) pair: f_dict[x][y].
    f_dict: Dict[str, Dict[str, np.ndarray]] = {name: {} for name in param_names}
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {name: {} for name in param_names}
    for param1 in param_names:
        for param2 in param_names:
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=param1,
                y_param_name=param2,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_dict[param1][param2] = f_plt
            sd_dict[param1][param2] = sd_plt

    config = {
        "arm_data": plot_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f_dict": f_dict,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_dict": grid_dict,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd_dict": sd_dict,
        "is_log_dict": is_log_dict,
        "param_names": param_names,
    }
    # Values are read back from AxPlotConfig(...).data; presumably this
    # normalizes them (e.g. numpy arrays) into plain serializable form --
    # confirm in ax.plot.base.
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    density = config["density"]
    grid_dict = config["grid_dict"]
    f_dict = config["f_dict"]
    lower_is_better = config["lower_is_better"]
    metric = config["metric"]
    rel = config["rel"]
    sd_dict = config["sd_dict"]
    is_log_dict = config["is_log_dict"]
    param_names = config["param_names"]
    green_scale = config["green_scale"]
    green_pink_scale = config["green_pink_scale"]
    blue_scale = config["blue_scale"]

    contour_config = {
        "autocolorscale": False,
        "autocontour": True,
        "contours": {"coloring": "heatmap"},
        "hoverinfo": "x+y+z",
        "ncontours": int(density / 2),
        "type": "contour",
    }

    if rel:
        # reversed() returns an iterator, which has no len(); materialize it
        # because the colorscale below needs the scale's length.
        f_scale = (
            list(reversed(green_pink_scale)) if lower_is_better else green_pink_scale
        )
    else:
        f_scale = green_scale

    f_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 0.45,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)],
        "xaxis": "x",
        "yaxis": "y",
        # zmax and zmin are ignored if zauto is true
        "zauto": not rel,
    }
    sd_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 1,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [
            (i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale)
        ],
        "xaxis": "x2",
        "yaxis": "y2",
    }
    f_contour_trace_base.update(contour_config)
    sd_contour_trace_base.update(contour_config)

    in_sample = arm_data["in_sample"]
    out_of_sample = arm_data["out_of_sample"]

    # In-sample arm parameter values: insample_param_values[param] -> list.
    insample_param_values = {
        param_name: [arm["parameters"][param_name] for arm in in_sample.values()]
        for param_name in param_names
    }
    insample_arm_text = list(in_sample.keys())

    # Candidate arm parameter values:
    # out_of_sample_param_values[param][generator_run_name] -> list.
    out_of_sample_param_values = {
        param_name: {
            generator_run_name: [
                arm["parameters"][param_name] for arm in arms.values()
            ]
            for generator_run_name, arms in out_of_sample.items()
        }
        for param_name in param_names
    }
    out_of_sample_arm_text = {
        generator_run_name: [
            "<em>Candidate " + arm_name + "</em>" for arm_name in arms.keys()
        ]
        for generator_run_name, arms in out_of_sample.items()
    }

    # Number of traces for each pair of parameters: mean contour, sd contour,
    # in-sample markers on each panel, plus two marker traces (one per panel)
    # for every out-of-sample generator run.
    trace_cnt = 4 + (len(out_of_sample) * 2)

    xbuttons = []
    ybuttons = []

    # Each x button replaces the data of every trace (all y blocks) with the
    # data for the chosen x parameter.
    for xvar in param_names:
        xbutton_data_args = {"x": [], "y": [], "z": []}
        for yvar in param_names:
            res = relativize_data(
                f_dict[xvar][yvar], sd_dict[xvar][yvar], rel, arm_data, metric
            )
            f_final = res[0]
            sd_final = res[1]
            # Split the flat predictions into rows of `density` values for the
            # contour's 2-d z input.
            f_plt = [
                f_final[ind : ind + density] for ind in range(0, len(f_final), density)
            ]
            sd_plt = [
                sd_final[ind : ind + density]
                for ind in range(0, len(sd_final), density)
            ]

            # grid + in-sample
            xbutton_data_args["x"] += [
                grid_dict[xvar],
                grid_dict[xvar],
                insample_param_values[xvar],
                insample_param_values[xvar],
            ]
            xbutton_data_args["y"] += [
                grid_dict[yvar],
                grid_dict[yvar],
                insample_param_values[yvar],
                insample_param_values[yvar],
            ]
            xbutton_data_args["z"] += [f_plt, sd_plt, [], []]

            # out-of-sample: one marker trace per panel for each generator run
            for generator_run_name in out_of_sample:
                xbutton_data_args["x"] += [
                    out_of_sample_param_values[xvar][generator_run_name]
                ] * 2
                xbutton_data_args["y"] += [
                    out_of_sample_param_values[yvar][generator_run_name]
                ] * 2
                xbutton_data_args["z"] += [[]] * 2

        x_range = axis_range(grid_dict[xvar], is_log_dict[xvar])
        x_type = "log" if is_log_dict[xvar] else "linear"
        xbuttons.append(
            {
                "args": [
                    xbutton_data_args,
                    {
                        "xaxis.title": short_name(xvar),
                        "xaxis2.title": short_name(xvar),
                        "xaxis.range": x_range,
                        "xaxis2.range": x_range,
                        # The range is computed for this parameter's scale, so
                        # the axis type must switch with it.
                        "xaxis.type": x_type,
                        "xaxis2.type": x_type,
                    },
                ],
                "label": xvar,
                "method": "update",
            }
        )

    # No y button for first param so initial value is sane
    for y_idx in range(1, len(param_names)):
        y_param = param_names[y_idx]
        # Show only the block of traces belonging to y_param.
        visible = [False] * (len(param_names) * trace_cnt)
        for trace_idx in range(y_idx * trace_cnt, (y_idx + 1) * trace_cnt):
            visible[trace_idx] = True
        y_range = axis_range(grid_dict[y_param], is_log_dict[y_param])
        y_type = "log" if is_log_dict[y_param] else "linear"
        ybuttons.append(
            {
                "args": [
                    {"visible": visible},
                    {
                        "yaxis.title": short_name(y_param),
                        "yaxis.range": y_range,
                        "yaxis2.range": y_range,
                        "yaxis.type": y_type,
                        "yaxis2.type": y_type,
                    },
                ],
                "label": y_param,
                "method": "update",
            }
        )

    # TODO(T37079623) Make this work for relative outcomes
    # (calculate max of abs(outcome), used for colorscale).
    traces = []
    xvar = param_names[0]
    # The traces initially shown are the y_idx == 1 block (see cur_visible),
    # so the initial y-axis must describe that parameter.
    initial_yvar = param_names[1] if len(param_names) > 1 else xvar

    for yvar_idx, yvar in enumerate(param_names):
        cur_visible = yvar_idx == 1
        f_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx]
        sd_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx + 1]

        # create traces
        f_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": f_start,
            "visible": cur_visible,
            **f_contour_trace_base,
        }
        sd_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": sd_start,
            "visible": cur_visible,
            **sd_contour_trace_base,
        }

        base_in_sample_arm_config = {
            "hoverinfo": "text",
            "legendgroup": "In-sample",
            "marker": {"color": "black", "symbol": 1, "opacity": 0.5},
            "mode": "markers",
            "name": "In-sample",
            "text": insample_arm_text,
            "type": "scatter",
            "visible": cur_visible,
            "x": insample_param_values[xvar],
            "y": insample_param_values[yvar],
        }
        f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y", **base_in_sample_arm_config}
        sd_in_sample_arm_trace = {
            "showlegend": False,
            "xaxis": "x2",
            "yaxis": "y2",
            **base_in_sample_arm_config,
        }
        traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace]

        # iterate over out-of-sample arms; symbols start at 2 (1 is used for
        # in-sample) and depend only on the generator run, so a run keeps its
        # marker whichever y-parameter is selected.
        for gr_idx, generator_run_name in enumerate(out_of_sample):
            symbol = 2 + gr_idx
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "xaxis": "x",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "yaxis": "y",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "visible": cur_visible,
                }
            )
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "showlegend": False,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "xaxis": "x2",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "yaxis": "y2",
                    "visible": cur_visible,
                }
            )

    xrange = axis_range(grid_dict[xvar], is_log_dict[xvar])
    yrange = axis_range(grid_dict[initial_yvar], is_log_dict[initial_yvar])
    xtype = "log" if is_log_dict[xvar] else "linear"
    ytype = "log" if is_log_dict[initial_yvar] else "linear"

    layout = {
        "annotations": [
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Mean",
                "x": 0.25,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Standard Error",
                "x": 0.8,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "x": 0.26,
                "y": -0.26,
                "xref": "paper",
                "yref": "paper",
                "text": "x-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
            {
                "x": 0.26,
                "y": -0.4,
                "xref": "paper",
                "yref": "paper",
                "text": "y-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
        ],
        "updatemenus": [
            {
                "x": 0.35,
                "y": -0.29,
                "buttons": xbuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
            {
                "x": 0.35,
                "y": -0.43,
                "buttons": ybuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
        ],
        "autosize": False,
        "height": 450,
        "hovermode": "closest",
        "legend": {"orientation": "v", "x": 0, "y": -0.2, "yanchor": "top"},
        "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35},
        "width": 950,
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "domain": [0.05, 0.45],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "xaxis2": {
            "anchor": "y2",
            "autorange": False,
            "domain": [0.6, 1],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(initial_yvar),
            "type": ytype,
        },
        "yaxis2": {
            "anchor": "x2",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "type": ytype,
        },
    }

    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)