This tutorial illustrates the core visualization utilities available in Ax.
import numpy as np
from ax import (
Arm,
ComparisonOp,
RangeParameter,
ParameterType,
SearchSpace,
SimpleExperiment,
OutcomeConstraint,
)
from ax.metrics.l2norm import L2NormMetric
from ax.modelbridge.cross_validation import cross_validate
from ax.modelbridge.registry import Models
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.diagnostic import interact_cross_validation
from ax.plot.scatter import(
interact_fitted,
plot_objective_vs_constraints,
tile_fitted,
)
from ax.plot.slice import plot_slice
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Set up in-notebook plotting: per the INFO log this call emits, it injects the
# Plotly library into the cell so figures passed to `render(...)` can display.
init_notebook_plotting()
[INFO 08-29 13:34:46] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
The visualizations require an experiment object and a model fit on the evaluated data. The routine below is a copy of the Developer API tutorial, so the explanation is omitted here. Retrieving the experiment and model objects for each API paradigm is shown in the respective tutorials.
# Standard deviation of the Gaussian noise added to every observed metric value.
noise_sd = 0.1
# Names of the six search-space parameters: x1, x2, ..., x6.
param_names = [f"x{i+1}" for i in range(6)]


def noisy_hartmann_evaluation_function(parameterization):
    """Evaluate noisy Hartmann6 and L2-norm metrics for one parameterization.

    Args:
        parameterization: Mapping from each name in ``param_names`` to its
            numeric value.

    Returns:
        A dict keyed by metric name ("hartmann6", "l2norm"). Each value is a
        ``(noisy_value, noise_sd)`` tuple, where the value has independent
        N(0, noise_sd) noise added.

    Raises:
        KeyError: If ``parameterization`` lacks any name in ``param_names``.
    """
    # Index directly rather than using .get(): a missing parameter should fail
    # loudly with its name instead of silently putting None into the array.
    x = np.array([parameterization[p_name] for p_name in param_names])
    noise1, noise2 = np.random.normal(0, noise_sd, 2)
    return {
        "hartmann6": (hartmann6(x) + noise1, noise_sd),
        "l2norm": (np.sqrt((x ** 2).sum()) + noise2, noise_sd),
    }
# Six continuous float parameters, each bounded to [0.0, 1.0].
hartmann_parameters = [
    RangeParameter(
        name=name,
        parameter_type=ParameterType.FLOAT,
        lower=0.0,
        upper=1.0,
    )
    for name in param_names
]
hartmann_search_space = SearchSpace(parameters=hartmann_parameters)
# Experiment that minimizes the noisy Hartmann6 objective subject to the
# outcome constraint l2norm <= 1.25 (absolute, not relative). Observations are
# produced by `noisy_hartmann_evaluation_function`.
exp = SimpleExperiment(
    # Named after the function actually being optimized.
    name="test_hartmann6",
    search_space=hartmann_search_space,
    evaluation_function=noisy_hartmann_evaluation_function,
    objective_name="hartmann6",
    minimize=True,
    outcome_constraints=[
        OutcomeConstraint(
            metric=L2NormMetric(
                name="l2norm",
                param_names=param_names,
                # Same noise level the evaluation function applies to "l2norm".
                noise_sd=noise_sd,
            ),
            op=ComparisonOp.LEQ,
            bound=1.25,
            relative=False,
        )
    ],
)
After doing `N_BATCHES=15` rounds of optimization, fit a final GP using all data to feed into the plots.
# Optimization budget: N_RANDOM Sobol points to initialize, then N_BATCHES
# rounds of GP-EI, each adding one trial generated with BATCH_SIZE arms.
N_RANDOM = 5
# NOTE(review): exp.new_trial presumably holds a single arm, so this likely
# must stay 1; use exp.new_batch_trial for larger batches — TODO confirm.
BATCH_SIZE = 1
N_BATCHES = 15

# Initial batch of Sobol-generated arms.
sobol = Models.SOBOL(exp.search_space)
exp.new_batch_trial(generator_run=sobol.gen(N_RANDOM))

# Each round refits GP-EI on all data evaluated so far and adds a new trial.
for _ in range(N_BATCHES):
    intermediate_gp = Models.GPEI(experiment=exp, data=exp.eval())
    exp.new_trial(generator_run=intermediate_gp.gen(BATCH_SIZE))

# Final model fit on all collected data; every plot below uses it.
model = Models.GPEI(experiment=exp, data=exp.eval())
The plot below shows the response surface for the `hartmann6` metric as a function of the `x1` and `x2` parameters.
The other parameters are fixed in the middle of their respective ranges, which in this example is 0.5 for all of them.
# Model-predicted hartmann6 contour over the (x1, x2) plane.
contour_fig = plot_contour(
    model=model,
    param_x="x1",
    param_y="x2",
    metric_name="hartmann6",
)
render(contour_fig)
The plot below allows toggling between different pairs of parameters to view the contours.
# hartmann6 contours with a selector for which parameter pair is plotted.
interactive_contour_fig = interact_contour(model=model, metric_name="hartmann6")
render(interactive_contour_fig)
This plot illustrates the trade-offs achievable for two different metrics. The plot takes the x-axis metric as input (usually the objective) and allows toggling among all other metrics for the y-axis.
This is useful to get a sense of the Pareto frontier (i.e., the best objective value achievable for different bounds on the constraint).
# hartmann6 (x-axis) against each other metric, shown with rel=False.
objective_vs_constraints_fig = plot_objective_vs_constraints(
    model, "hartmann6", rel=False
)
render(objective_vs_constraints_fig)
Cross-validation (CV) plots are useful for checking how well the model's predictions agree with the actual measurements. If all points are close to the dashed line, then the model is a good predictor of the real data.
# Cross-validate the final model, then plot its predictions against observations.
cv_results = cross_validate(model)
cv_fig = interact_cross_validation(cv_results)
render(cv_fig)
Slice plots show the metric outcome as a function of one parameter while fixing the others. They serve a similar function to contour plots.
# One-dimensional slice of the hartmann6 prediction along x2, others held fixed.
# NOTE(review): in the recorded run this call raised
#   ValueError: Invalid value of type 'builtins.str' received for the 'color'
#   property of scatter.line ... Received value: 'transparent'
# from plotly's validators inside ax/plot/slice.py. This looks like a version
# mismatch between the installed Ax and plotly; TODO confirm the plotly
# version Ax requires.
slice_fig = plot_slice(model=model, param_name="x2", metric_name="hartmann6")
render(slice_fig)
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-9-77cae64bb37d> in <module> ----> 1 render(plot_slice(model, "x2", "hartmann6")) ~/anaconda3/lib/python3.7/site-packages/ax/plot/slice.py in plot_slice(model, param_name, metric_name, generator_runs_dict, relative, density, slice_values, fixed_features) 215 } 216 --> 217 fig = go.Figure(data=traces, layout=layout) # pyre-ignore[16] 218 return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) 219 ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/_figure.py in __init__(self, data, layout, frames, skip_invalid, **kwargs) 550 """ 551 super(Figure, --> 552 self).__init__(data, layout, frames, skip_invalid, **kwargs) 553 554 def add_area( ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __init__(self, data, layout_plotly, frames, skip_invalid, **kwargs) 154 # ### Import traces ### 155 data = self._data_validator.validate_coerce(data, --> 156 skip_invalid=skip_invalid) 157 158 # ### Save tuple of trace objects ### ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, skip_invalid) 2333 else: 2334 trace = self.class_map[trace_type]( -> 2335 skip_invalid=skip_invalid, **v_copy) 2336 res.append(trace) 2337 else: ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/__init__.py in __init__(self, arg, cliponaxis, connectgaps, customdata, customdatasrc, dx, dy, error_x, error_y, fill, fillcolor, groupnorm, hoverinfo, hoverinfosrc, hoverlabel, hoveron, hovertemplate, hovertemplatesrc, hovertext, hovertextsrc, ids, idssrc, legendgroup, line, marker, meta, metasrc, mode, name, opacity, orientation, r, rsrc, selected, selectedpoints, showlegend, stackgaps, stackgroup, stream, t, text, textfont, textposition, textpositionsrc, textsrc, tsrc, uid, uirevision, unselected, visible, x, x0, xaxis, xcalendar, xsrc, y, y0, yaxis, ycalendar, ysrc, **kwargs) 39591 
self['legendgroup'] = legendgroup if legendgroup is not None else _v 39592 _v = arg.pop('line', None) > 39593 self['line'] = line if line is not None else _v 39594 _v = arg.pop('marker', None) 39595 self['marker'] = marker if marker is not None else _v ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __setitem__(self, prop, value) 3306 # ### Handle compound property ### 3307 if isinstance(validator, CompoundValidator): -> 3308 self._set_compound_prop(prop, value) 3309 3310 # ### Handle compound array property ### ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_compound_prop(self, prop, val) 3619 validator = self._validators.get(prop) 3620 # type: BasePlotlyType -> 3621 val = validator.validate_coerce(val, skip_invalid=self._skip_invalid) 3622 3623 # Save deep copies of current and new states ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, skip_invalid) 2129 2130 elif isinstance(v, dict): -> 2131 v = self.data_class(v, skip_invalid=skip_invalid) 2132 2133 elif isinstance(v, self.data_class): ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/scatter/__init__.py in __init__(self, arg, color, dash, shape, simplify, smoothing, width, **kwargs) 2436 # ---------------------------------- 2437 _v = arg.pop('color', None) -> 2438 self['color'] = color if color is not None else _v 2439 _v = arg.pop('dash', None) 2440 self['dash'] = dash if dash is not None else _v ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __setitem__(self, prop, value) 3315 # ### Handle simple property ### 3316 else: -> 3317 self._set_prop(prop, value) 3318 3319 # Handle non-scalar case ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_prop(self, prop, val) 3560 return 3561 else: -> 3562 raise err 3563 3564 # val is None ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_prop(self, prop, val) 3555 validator = self._validators.get(prop) 3556 
try: -> 3557 val = validator.validate_coerce(val) 3558 except ValueError as err: 3559 if self._skip_invalid: ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, should_raise) 1162 validated_v = self.vc_scalar(v) 1163 if validated_v is None and should_raise: -> 1164 self.raise_invalid_val(v) 1165 1166 v = validated_v ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in raise_invalid_val(self, v, inds) 275 typ=type_str(v), 276 v=repr(v), --> 277 valid_clr_desc=self.description())) 278 279 def raise_invalid_elements(self, invalid_els): ValueError: Invalid value of type 'builtins.str' received for the 'color' property of scatter.line Received value: 'transparent' The 'color' property is a color and may be specified as: - A hex string (e.g. '#ff0000') - An rgb/rgba string (e.g. 'rgb(255,0,0)') - An hsl/hsla string (e.g. 'hsl(0,100%,50%)') - An hsv/hsva string (e.g. 'hsv(0,100%,100%)') - A named CSS color: aliceblue, antiquewhite, aqua, aquamarine, azure, beige, bisque, black, blanchedalmond, blue, blueviolet, brown, burlywood, cadetblue, chartreuse, chocolate, coral, cornflowerblue, cornsilk, crimson, cyan, darkblue, darkcyan, darkgoldenrod, darkgray, darkgrey, darkgreen, darkkhaki, darkmagenta, darkolivegreen, darkorange, darkorchid, darkred, darksalmon, darkseagreen, darkslateblue, darkslategray, darkslategrey, darkturquoise, darkviolet, deeppink, deepskyblue, dimgray, dimgrey, dodgerblue, firebrick, floralwhite, forestgreen, fuchsia, gainsboro, ghostwhite, gold, goldenrod, gray, grey, green, greenyellow, honeydew, hotpink, indianred, indigo, ivory, khaki, lavender, lavenderblush, lawngreen, lemonchiffon, lightblue, lightcoral, lightcyan, lightgoldenrodyellow, lightgray, lightgrey, lightgreen, lightpink, lightsalmon, lightseagreen, lightskyblue, lightslategray, lightslategrey, lightsteelblue, lightyellow, lime, limegreen, linen, magenta, maroon, mediumaquamarine, mediumblue, mediumorchid, 
mediumpurple, mediumseagreen, mediumslateblue, mediumspringgreen, mediumturquoise, mediumvioletred, midnightblue, mintcream, mistyrose, moccasin, navajowhite, navy, oldlace, olive, olivedrab, orange, orangered, orchid, palegoldenrod, palegreen, paleturquoise, palevioletred, papayawhip, peachpuff, peru, pink, plum, powderblue, purple, red, rosybrown, royalblue, saddlebrown, salmon, sandybrown, seagreen, seashell, sienna, silver, skyblue, slateblue, slategray, slategrey, snow, springgreen, steelblue, tan, teal, thistle, tomato, turquoise, violet, wheat, white, whitesmoke, yellow, yellowgreen
Tile plots are useful for viewing the effect of each arm.
# Fitted values per arm from the final model, with rel=False (not relative).
fitted_fig = interact_fitted(model, rel=False)
render(fitted_fig)