#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import math
from collections import Counter
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
from ax.core.generator_run import GeneratorRun
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, RangeParameter
from ax.core.types import TParameterization
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.prediction_utils import (
_compute_scalarized_outcome,
predict_at_point,
)
from ax.modelbridge.transforms.ivw import IVW
from ax.plot.base import DECIMALS, PlotData, PlotInSampleArm, PlotOutOfSampleArm, Z
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
logger: Logger = get_logger(__name__)
# Typing alias
RawData = List[Dict[str, Union[str, float]]]
TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]]
[docs]def extend_range(
lower: float, upper: float, percent: int = 10, log_scale: bool = False
) -> Tuple[float, float]:
"""Given a range of minimum and maximum values taken by values on a given axis,
extend it in both directions by a given percentage to have some margin within
the plot around its meaningful part.
"""
if upper <= lower:
raise ValueError(
f"`upper` should be greater than `lower`, got: {upper} (<= {lower})."
)
if log_scale:
raise NotImplementedError("Log scale not yet supported.")
margin = (upper - lower) * percent / 100
return lower - margin, upper + margin
def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str:
"""Format a dictionary for labels.
Args:
param_dict: Dictionary to be formatted
name: String name of the thing being formatted.
Returns: stringified blob.
"""
if len(param_dict) >= 10:
blob = "{} has too many items to render on hover ({}).".format(
name, len(param_dict)
)
else:
blob = "<br><em>{}:</em><br>{}".format(
name, "<br>".join("{}: {}".format(n, v) for n, v in param_dict.items())
)
return blob
def _wrap_metric(metric_name: str) -> str:
"""Put a newline on "::" for metric names.
Args:
metric_name: metric name.
Returns: wrapped metric name.
"""
if "::" in metric_name:
return "<br>".join(metric_name.split("::"))
else:
return metric_name
def _format_CI(estimate: float, sd: float, relative: bool, zval: float = Z) -> str:
"""Format confidence intervals given estimate and standard deviation.
Args:
estimate: point estimate.
sd: standard deviation of point estimate.
relative: if True, '%' is appended.
zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs)
Returns: formatted confidence interval.
"""
return "[{lb:.{digits}f}{perc}, {ub:.{digits}f}{perc}]".format(
lb=estimate - zval * sd,
ub=estimate + zval * sd,
digits=DECIMALS,
perc="%" if relative else "",
)
[docs]def arm_name_to_tuple(arm_name: str) -> Union[Tuple[int, int], Tuple[int]]:
tup = arm_name.split("_")
if len(tup) == 2:
try:
return (int(tup[0]), int(tup[1]))
except ValueError:
return (0,)
return (0,)
[docs]def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]:
"""Parses arm name into tuple suitable for reverse sorting by key
Example:
arm_names = ["0_0", "1_10", "1_2", "10_0", "control"]
sorted(arm_names, key=arm_name_to_sort_key, reverse=True)
["control", "0_0", "1_2", "1_10", "10_0"]
"""
try:
trial_index, arm_index = arm_name.split("_")
return ("", -int(trial_index), -int(arm_index))
except (ValueError, IndexError):
return (arm_name, 0, 0)
[docs]def resize_subtitles(figure: Dict[str, Any], size: int) -> Dict[str, Any]:
for ant in figure["layout"]["annotations"]:
ant["font"].update(size=size)
return figure
def _filter_dict(
param_dict: TParameterization, subset_keys: List[str]
) -> TParameterization:
"""Filter a dictionary to keys present in a given list."""
return {k: v for k, v in param_dict.items() if k in subset_keys}
def _get_in_sample_arms(
model: ModelBridge,
metric_names: Set[str],
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
scalarized_metric_config: Optional[List[Dict[str, Dict[str, float]]]] = None,
) -> Tuple[Dict[str, PlotInSampleArm], RawData, Dict[str, TParameterization]]:
"""Get in-sample arms from a model with observed and predicted values
for specified metrics.
Returns a PlotInSampleArm object in which repeated observations are merged
with IVW, and a RawData object in which every observation is listed.
Fixed features input can be used to override fields of the insample arms
when making model predictions.
Args:
model: An instance of the model bridge.
metric_names: Restrict predictions to these metrics. If None, uses all
metrics in the model.
fixed_features: Features that should be fixed in the arms this function
will obtain predictions for.
data_selector: Function for selecting observations for plotting.
Returns:
A tuple containing
- Map from arm name to PlotInSampleArm.
- List of the data for each observation like::
{'metric_name': 'likes', 'arm_name': '0_0', 'mean': 1., 'sem': 0.1}
- Map from arm name to parameters
"""
observations = model.get_training_data()
training_in_design = model.training_in_design
if data_selector is not None:
observations = [obs for obs in observations if data_selector(obs)]
training_in_design = [
model.training_in_design[i]
for i, obs in enumerate(observations)
if data_selector(obs)
]
trial_selector = None
if fixed_features is not None:
trial_selector = fixed_features.trial_index
# Calculate raw data
raw_data = []
arm_name_to_parameters = {}
for obs in observations:
arm_name_to_parameters[obs.arm_name] = obs.features.parameters
for j, metric_name in enumerate(obs.data.metric_names):
if metric_name in metric_names:
raw_data.append(
{
"metric_name": metric_name,
"arm_name": obs.arm_name,
"mean": obs.data.means[j],
"sem": np.sqrt(obs.data.covariance[j, j]),
}
)
# Check that we have one ObservationFeatures per arm name since we
# key by arm name and the model is not Multi-task.
# If "TrialAsTask" is present, one of the arms is chosen based on the selected
# trial index in the fixed_features.
if ("TrialAsTask" not in model.transforms.keys() or trial_selector is None) and (
len(arm_name_to_parameters) != len(observations)
):
logger.error(
"Have observations of arms with different features but same"
" name. Arbitrary one will be plotted."
)
# Merge multiple measurements within each Observation with IVW to get
# un-modeled prediction
t = IVW(None, [])
observations = t.transform_observations(observations)
# Start filling in plot data
in_sample_plot: Dict[str, PlotInSampleArm] = {}
for i, obs in enumerate(observations):
if obs.arm_name is None:
raise ValueError("Observation must have arm name for plotting.")
# Extract raw measurement
obs_y = {} # Observed metric means.
obs_se = {} # Observed metric standard errors.
for j, metric_name in enumerate(obs.data.metric_names):
if metric_name in metric_names:
obs_y[metric_name] = obs.data.means[j]
obs_se[metric_name] = np.sqrt(obs.data.covariance[j, j])
# Obtain aggregated outcomes if scalarized_metric_config is provided
if scalarized_metric_config is not None:
for agg_metric in scalarized_metric_config:
agg_metric_name = agg_metric["name"]
if agg_metric_name in metric_names:
agg_mean, agg_var = _compute_scalarized_outcome(
mean_dict=obs.data.means_dict,
cov_dict=obs.data.covariance_matrix,
agg_metric_weight_dict=agg_metric["weight"],
)
obs_y[agg_metric_name] = agg_mean
obs_se[agg_metric_name] = np.sqrt(agg_var)
if training_in_design[i]:
# Update with the input fixed features
features = obs.features
if fixed_features is not None:
features.update_features(fixed_features)
# Make a prediction.
pred_y, pred_se = predict_at_point(
model, features, metric_names, scalarized_metric_config
)
elif (trial_selector is not None) and (
obs.features.trial_index != trial_selector
):
# check whether the observation is from the right trial
# need to use raw data in the selected trial for out-of-design points
continue
else:
pred_y = obs_y
pred_se = obs_se
in_sample_plot[not_none(obs.arm_name)] = PlotInSampleArm(
name=not_none(obs.arm_name),
y=obs_y,
se=obs_se,
parameters=obs.features.parameters,
y_hat=pred_y,
se_hat=pred_se,
context_stratum=None,
)
return in_sample_plot, raw_data, arm_name_to_parameters
def _get_out_of_sample_arms(
model: ModelBridge,
generator_runs_dict: Dict[str, GeneratorRun],
metric_names: Set[str],
fixed_features: Optional[ObservationFeatures] = None,
scalarized_metric_config: Optional[List[Dict[str, Dict[str, float]]]] = None,
) -> Dict[str, Dict[str, PlotOutOfSampleArm]]:
"""Get out-of-sample predictions from a model given a dict of generator runs.
Fixed features input can be used to override fields of the candidate arms
when making model predictions.
Args:
model: The model.
generator_runs_dict: a mapping from generator run name to generator run.
metric_names: metrics to include in the plot.
Returns:
A mapping from name to a mapping from arm name to plot.
"""
out_of_sample_plot: Dict[str, Dict[str, PlotOutOfSampleArm]] = {}
for generator_run_name, generator_run in generator_runs_dict.items():
out_of_sample_plot[generator_run_name] = {}
for arm in generator_run.arms:
# This assumes context is None
obsf = ObservationFeatures.from_arm(arm)
if fixed_features is not None:
obsf.update_features(fixed_features)
# Make a prediction
try:
pred_y, pred_se = predict_at_point(
model, obsf, metric_names, scalarized_metric_config
)
except Exception:
# Check if it is an out-of-design arm.
if not model.model_space.check_membership(obsf.parameters):
# Skip this point
continue
else:
# It should have worked
raise
arm_name = arm.name_or_short_signature
out_of_sample_plot[generator_run_name][arm_name] = PlotOutOfSampleArm(
name=arm_name,
parameters=obsf.parameters,
y_hat=pred_y,
se_hat=pred_se,
context_stratum=None,
)
return out_of_sample_plot
[docs]def get_plot_data(
model: ModelBridge,
generator_runs_dict: Dict[str, GeneratorRun],
metric_names: Optional[Set[str]] = None,
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
scalarized_metric_config: Optional[List[Dict[str, Dict[str, float]]]] = None,
) -> Tuple[PlotData, RawData, Dict[str, TParameterization]]:
"""Format data object with metrics for in-sample and out-of-sample
arms.
Calculate both observed and predicted metrics for in-sample arms.
Calculate predicted metrics for out-of-sample arms passed via the
`generator_runs_dict` argument.
In PlotData, in-sample observations are merged with IVW. In RawData, they
are left un-merged and given as a list of dictionaries, one for each
observation and having keys 'arm_name', 'mean', and 'sem'.
Args:
model: The model.
generator_runs_dict: a mapping from generator run name to generator run.
metric_names: Restrict predictions to this set. If None, all metrics
in the model will be returned.
fixed_features: Fixed features to use when making model predictions.
data_selector: Function for selecting observations for plotting.
scalarized_metric_config: An optional list of dicts specifying how to aggregate
multiple metrics into a single scalarized metric. For each dict, the key is
the name of the new scalarized metric, and the value is a dictionary mapping
each metric to its weight. e.g.
{"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.
Returns:
A tuple containing
- PlotData object with in-sample and out-of-sample predictions.
- List of observations like::
{'metric_name': 'likes', 'arm_name': '0_1', 'mean': 1., 'sem': 0.1}.
- Mapping from arm name to parameters.
"""
metrics_plot = model.metric_names if metric_names is None else metric_names
in_sample_plot, raw_data, cond_name_to_parameters = _get_in_sample_arms(
model=model,
metric_names=metrics_plot,
fixed_features=fixed_features,
data_selector=data_selector,
scalarized_metric_config=scalarized_metric_config,
)
out_of_sample_plot = _get_out_of_sample_arms(
model=model,
generator_runs_dict=generator_runs_dict,
metric_names=metrics_plot,
fixed_features=fixed_features,
scalarized_metric_config=scalarized_metric_config,
)
status_quo_name = None if model.status_quo is None else model.status_quo.arm_name
plot_data = PlotData(
metrics=list(metrics_plot),
in_sample=in_sample_plot,
out_of_sample=out_of_sample_plot,
status_quo_name=status_quo_name,
)
return plot_data, raw_data, cond_name_to_parameters
[docs]def get_range_parameter(model: ModelBridge, param_name: str) -> RangeParameter:
"""
Get the range parameter with the given name from the model.
Throws if parameter doesn't exist or is not a range parameter.
Args:
model: The model.
param_name: The name of the RangeParameter to be found.
Returns: The RangeParameter named `param_name`.
"""
range_param = model.model_space.parameters.get(param_name)
if range_param is None:
raise ValueError(f"Parameter `{param_name}` does not exist.")
if not isinstance(range_param, RangeParameter):
raise ValueError(f"{param_name} is not a RangeParameter")
return range_param
[docs]def get_range_parameters_from_list(
parameters: List[Parameter], min_num_values: int = 0
) -> List[RangeParameter]:
"""
Get a list of range parameters from a model.
Args:
parameters: List of parameters
min_num_values: Minimum number of values
Returns: List of RangeParameters.
"""
return [
parameter
for parameter in parameters
if isinstance(parameter, RangeParameter)
and parameter.cardinality() >= min_num_values # float has inf cardinality
]
[docs]def get_range_parameters(
model: ModelBridge, min_num_values: int = 0
) -> List[RangeParameter]:
"""
Get a list of range parameters from a model.
Args:
model: The model.
min_num_values: Minimum number of values
Returns: List of RangeParameters.
"""
return get_range_parameters_from_list(
parameters=list(model.model_space.parameters.values()),
min_num_values=min_num_values,
)
[docs]def get_grid_for_parameter(parameter: RangeParameter, density: int) -> np.ndarray:
"""Get a grid of points along the range of the parameter.
Will be a log-scale grid if parameter is log scale.
Args:
parameter: Parameter for which to generate grid.
density: Number of points in the grid.
"""
is_log = parameter.log_scale
if is_log:
grid = np.linspace(
np.log10(parameter.lower), np.log10(parameter.upper), density
)
grid = 10**grid
else:
grid = np.linspace(parameter.lower, parameter.upper, density)
return grid
[docs]def get_fixed_values(
model: ModelBridge,
slice_values: Optional[Dict[str, Any]] = None,
trial_index: Optional[int] = None,
) -> TParameterization:
"""Get fixed values for parameters in a slice plot.
If there is an in-design status quo, those values will be used. Otherwise,
the mean of RangeParameters or the mode of ChoiceParameters is used.
Any value in slice_values will override the above.
Args:
model: ModelBridge being used for plotting
slice_values: Map from parameter name to value at which is should be
fixed.
Returns: Map from parameter name to fixed value.
"""
if trial_index is not None:
if slice_values is None:
slice_values = {}
slice_values["TRIAL_PARAM"] = str(trial_index)
# Check if status_quo is in design
if model.status_quo is not None and model.model_space.check_membership(
model.status_quo.features.parameters
):
setx = model.status_quo.features.parameters
else:
observations = model.get_training_data()
setx = {}
for p_name, parameter in model.model_space.parameters.items():
# Exclude out of design status quo (no parameters)
vals = [
obs.features.parameters[p_name]
for obs in observations
if (
len(obs.features.parameters) > 0
and parameter.validate(obs.features.parameters[p_name])
)
]
if isinstance(parameter, FixedParameter):
setx[p_name] = parameter.value
elif isinstance(parameter, ChoiceParameter):
setx[p_name] = Counter(vals).most_common(1)[0][0]
elif isinstance(parameter, RangeParameter):
setx[p_name] = parameter.cast(np.mean(vals))
if slice_values is not None:
# slice_values has type Dictionary[str, Any]
setx.update(slice_values)
return setx
# Utility methods ported from JS
# pyre-fixme[2]: Parameter must be annotated.
[docs]def contour_config_to_trace(config) -> List[Dict[str, Any]]:
# Load from config
arm_data = config["arm_data"]
density = config["density"]
grid_x = config["grid_x"]
grid_y = config["grid_y"]
f = config["f"]
lower_is_better = config["lower_is_better"]
metric = config["metric"]
rel = config["rel"]
sd = config["sd"]
xvar = config["xvar"]
yvar = config["yvar"]
green_scale = config["green_scale"]
green_pink_scale = config["green_pink_scale"]
blue_scale = config["blue_scale"]
# format data
res = relativize_data(f, sd, rel, arm_data, metric)
f_final = res[0]
sd_final = res[1]
# calculate max of abs(outcome), used for colorscale
f_absmax = max(abs(min(f_final)), max(f_final))
# transform to nested array
f_plt = []
for ind in range(0, len(f_final), density):
f_plt.append(f_final[ind : ind + density])
sd_plt = []
for ind in range(0, len(sd_final), density):
sd_plt.append(sd_final[ind : ind + density])
CONTOUR_CONFIG = {
"autocolorscale": False,
"autocontour": True,
"contours": {"coloring": "heatmap"},
"hoverinfo": "x+y+z",
"ncontours": int(density / 2),
"type": "contour",
"x": grid_x,
"y": grid_y,
}
if rel:
f_scale = reversed(green_pink_scale) if lower_is_better else green_pink_scale
else:
f_scale = green_scale
f_trace = {
"colorbar": {
"x": 0.45,
"y": 0.5,
"ticksuffix": "%" if rel else "",
"tickfont": {"size": 8},
},
"colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)],
"xaxis": "x",
"yaxis": "y",
"z": f_plt,
# zmax and zmin are ignored if zauto is true
"zauto": not rel,
"zmax": f_absmax,
"zmin": -f_absmax,
}
sd_trace = {
"colorbar": {
"x": 1,
"y": 0.5,
"ticksuffix": "%" if rel else "",
"tickfont": {"size": 8},
},
"colorscale": [
(i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale)
],
"xaxis": "x2",
"yaxis": "y2",
"z": sd_plt,
}
f_trace.update(CONTOUR_CONFIG)
sd_trace.update(CONTOUR_CONFIG)
# get in-sample arms
arm_names = list(arm_data["in_sample"].keys())
arm_x = [
arm_data["in_sample"][arm_name]["parameters"][xvar] for arm_name in arm_names
]
arm_y = [
arm_data["in_sample"][arm_name]["parameters"][yvar] for arm_name in arm_names
]
arm_text = []
for arm_name in arm_names:
atext = f"Arm {arm_name}"
params = arm_data["in_sample"][arm_name]["parameters"]
ys = arm_data["in_sample"][arm_name]["y"]
ses = arm_data["in_sample"][arm_name]["se"]
for yname in ys.keys():
sem_str = f"{ses[yname]}" if ses[yname] is None else f"{ses[yname]:.6g}"
y_str = f"{ys[yname]}" if ys[yname] is None else f"{ys[yname]:.6g}"
atext += f"<br>{yname}: {y_str} (SEM: {sem_str})"
for pname in params.keys():
pval = params[pname]
pstr = f"{pval:.6g}" if isinstance(pval, float) else f"{pval}"
atext += f"<br>{pname}: {pstr}"
arm_text.append(atext)
# configs for in-sample arms
base_in_sample_arm_config = {
"hoverinfo": "text",
"legendgroup": "In-sample",
"marker": {"color": "black", "symbol": 1, "opacity": 0.5},
"mode": "markers",
"name": "In-sample",
"text": arm_text,
"type": "scatter",
"x": arm_x,
"y": arm_y,
}
f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y"}
sd_in_sample_arm_trace = {"showlegend": False, "xaxis": "x2", "yaxis": "y2"}
# pyre-fixme[6]: For 1st param expected `SupportsKeysAndGetItem[str, str]` but
# got `Dict[str, Union[Dict[str, Union[float, str]], List[typing.Any], str]]`.
f_in_sample_arm_trace.update(base_in_sample_arm_config)
# pyre-fixme[6]: For 1st param expected `SupportsKeysAndGetItem[str, Union[bool,
# str]]` but got `Dict[str, Union[Dict[str, Union[float, str]], List[typing.Any],
# str]]`.
sd_in_sample_arm_trace.update(base_in_sample_arm_config)
traces = [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace]
# iterate over out-of-sample arms
for i, generator_run_name in enumerate(arm_data["out_of_sample"].keys()):
symbol = i + 2 # symbols starts from 2 for candidate markers
ax = []
ay = []
atext = []
for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
ax.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["parameters"][
xvar
]
)
ay.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["parameters"][
yvar
]
)
atext.append("<em>Candidate " + arm_name + "</em>")
traces.append(
{
"hoverinfo": "text",
"legendgroup": generator_run_name,
"marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
"mode": "markers",
"name": generator_run_name,
"text": atext,
"type": "scatter",
"xaxis": "x",
"x": ax,
"yaxis": "y",
"y": ay,
}
)
traces.append(
{
"hoverinfo": "text",
"legendgroup": generator_run_name,
"marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
"mode": "markers",
"name": "In-sample",
"showlegend": False,
"text": atext,
"type": "scatter",
"x": ax,
"xaxis": "x2",
"y": ay,
"yaxis": "y2",
}
)
return traces
[docs]def axis_range(grid: List[float], is_log: bool) -> List[float]:
if is_log:
return [math.log10(min(grid)), math.log10(max(grid))]
else:
return [min(grid), max(grid)]
[docs]def relativize(m_t: float, sem_t: float, m_c: float, sem_c: float) -> List[float]:
r_hat = (m_t - m_c) / abs(m_c) - sem_c**2 * m_t / abs(m_c) ** 3
variance = (sem_t**2 + (m_t / m_c * sem_c) ** 2) / m_c**2
return [r_hat, math.sqrt(variance)]
[docs]def relativize_data(
f: List[float],
sd: List[float],
rel: bool,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
arm_data: Dict[Any, Any],
metric: str,
) -> List[List[float]]:
# if relative, extract status quo & compute ratio
f_final = [] if rel else f
sd_final = [] if rel else sd
if rel:
f_sq = arm_data["in_sample"][arm_data["status_quo_name"]]["y"][metric]
sd_sq = arm_data["in_sample"][arm_data["status_quo_name"]]["se"][metric]
for i in range(len(f)):
res = relativize(f[i], sd[i], f_sq, sd_sq)
f_final.append(100 * res[0])
sd_final.append(100 * res[1])
return [f_final, sd_final]
[docs]def rgb(arr: List[int]) -> str:
return "rgb({},{},{})".format(*arr)
[docs]def infer_is_relative(
model: ModelBridge, metrics: List[str], non_constraint_rel: bool
) -> Dict[str, bool]:
"""Determine whether or not to relativize a metric.
Metrics that are constraints will get this decision from their `relative` flag.
Other metrics will use the `default_rel`.
Args:
model: model fit on metrics.
metrics: list of metric names.
non_constraint_rel: whether or not to relativize non-constraint metrics
Returns:
Dict[str, bool] containing whether or not to relativize each input metric.
"""
relative = {}
constraint_relativity = {}
if model._optimization_config:
constraints = not_none(model._optimization_config).outcome_constraints
constraint_relativity = {
constraint.metric.name: constraint.relative for constraint in constraints
}
for metric in metrics:
if metric not in constraint_relativity:
relative[metric] = non_constraint_rel
else:
relative[metric] = constraint_relativity[metric]
return relative
[docs]def slice_config_to_trace(
# pyre-fixme[2]: Parameter must be annotated.
arm_data,
# pyre-fixme[2]: Parameter must be annotated.
arm_name_to_parameters,
f: List[float],
# pyre-fixme[2]: Parameter must be annotated.
fit_data,
# pyre-fixme[2]: Parameter must be annotated.
grid,
metric: str,
# pyre-fixme[2]: Parameter must be annotated.
param,
rel: bool,
# pyre-fixme[2]: Parameter must be annotated.
setx,
sd: List[float],
# pyre-fixme[2]: Parameter must be annotated.
is_log,
# pyre-fixme[2]: Parameter must be annotated.
visible,
) -> List[Dict[str, Any]]:
# format data
res = relativize_data(f, sd, rel, arm_data, metric)
f_final = res[0]
sd_final = res[1]
# get data for standard deviation fill plot
sd_upper = []
sd_lower = []
for i in range(len(sd)):
sd_upper.append(f_final[i] + 2 * sd_final[i])
sd_lower.append(f_final[i] - 2 * sd_final[i])
grid_rev = list(reversed(grid))
sd_lower_rev = list(reversed(sd_lower))
sd_x = grid + grid_rev
sd_y = sd_upper + sd_lower_rev
# get data for observed arms and error bars
arm_x = []
arm_y = []
arm_sem = []
for row in fit_data:
parameters = arm_name_to_parameters[row["arm_name"]]
plot = True
for p in setx.keys():
if p != param and parameters[p] != setx[p]:
plot = False
if plot:
arm_x.append(parameters[param])
arm_y.append(row["mean"])
arm_sem.append(row["sem"])
arm_res = relativize_data(arm_y, arm_sem, rel, arm_data, metric)
arm_y_final = arm_res[0]
arm_sem_final = [x * 2 if x is not None else None for x in arm_res[1]]
# create traces
f_trace = {
"x": grid,
"y": f_final,
"showlegend": False,
"hoverinfo": "x+y",
"line": {"color": "rgba(128, 177, 211, 1)"},
"visible": visible,
}
arms_trace = {
"x": arm_x,
"y": arm_y_final,
"mode": "markers",
"error_y": {
"type": "data",
"array": arm_sem_final,
"visible": True,
"color": "black",
},
"line": {"color": "black"},
"showlegend": False,
"hoverinfo": "x+y",
"visible": visible,
}
sd_trace = {
"x": sd_x,
"y": sd_y,
"fill": "toself",
"fillcolor": "rgba(128, 177, 211, 0.2)",
"line": {"color": "rgba(128, 177, 211, 0.0)"},
"showlegend": False,
"hoverinfo": "none",
"visible": visible,
}
traces = [sd_trace, f_trace, arms_trace]
# iterate over out-of-sample arms
for i, generator_run_name in enumerate(arm_data["out_of_sample"].keys()):
ax = []
ay = []
asem = []
atext = []
for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
parameters = arm_data["out_of_sample"][generator_run_name][arm_name][
"parameters"
]
plot = True
for p in setx.keys():
if p != param and parameters[p] != setx[p]:
plot = False
if plot:
ax.append(parameters[param])
ay.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["y_hat"][
metric
]
)
asem.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["se_hat"][
metric
]
)
atext.append("<em>Candidate " + arm_name + "</em>")
out_of_sample_arm_res = relativize_data(ay, asem, rel, arm_data, metric)
ay_final = out_of_sample_arm_res[0]
asem_final = [x * 2 for x in out_of_sample_arm_res[1]]
traces.append(
{
"hoverinfo": "text",
"legendgroup": generator_run_name,
"marker": {"color": "black", "symbol": i + 1, "opacity": 0.5},
"mode": "markers",
"error_y": {
"type": "data",
"array": asem_final,
"visible": True,
"color": "black",
},
"name": generator_run_name,
"text": atext,
"type": "scatter",
"xaxis": "x",
"x": ax,
"yaxis": "y",
"y": ay_final,
"visible": visible,
}
)
return traces
[docs]def build_filter_trial(keep_trial_indices: List[int]) -> Callable[[Observation], bool]:
"""Creates a callable that filters observations based on trial_index"""
def trial_filter(obs: Observation) -> bool:
return obs.features.trial_index in keep_trial_indices
return trial_filter
[docs]def compose_annotation(
caption: str, x: float = 0.0, y: float = -0.15
) -> List[Dict[str, Any]]:
if not caption:
return []
return [
{
"showarrow": False,
"text": caption,
"x": x,
"xanchor": "left",
"xref": "paper",
"y": y,
"yanchor": "top",
"yref": "paper",
"align": "left",
},
]