Source code for ax.plot.slice

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy
from typing import Any, Optional, Union

import numpy as np
import numpy.typing as npt
from ax.core.observation import ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData
from ax.plot.helper import (
    axis_range,
    get_fixed_values,
    get_grid_for_parameter,
    get_plot_data,
    get_range_parameter,
    get_range_parameters,
    slice_config_to_trace,
    TNullableGeneratorRunsDict,
)
from plotly import graph_objs as go
from pyre_extensions import none_throws


# type aliases
# Positional layout of the tuple returned by `_get_slice_predictions`. Callers
# unpack it positionally, so the element order here must match that function's
# `return` statement:
#   0. PlotData from `get_plot_data`
#   1. mapping {arm/condition name: parameters} from `get_plot_data`
#   2. model predictions for the metric at each grid point
#   3. raw data from `get_plot_data`
#   4. grid of values for the sliced parameter
#   5. metric name
#   6. sliced parameter name
#   7. the `relative` flag, passed through unchanged
#   8. fixed values used for the non-sliced parameters
#   9. predicted standard deviation (sqrt of the predicted variance) per point
#  10. whether the sliced parameter is log-scaled
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
SlicePredictions = tuple[
    PlotData,
    dict[str, dict[str, Optional[Union[str, bool, float, int]]]],
    list[float],
    list[dict[str, Union[str, float]]],
    np.ndarray,
    str,
    str,
    bool,
    dict[str, Optional[Union[str, bool, float, int]]],
    np.ndarray,
    bool,
]


def _get_slice_predictions(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: dict[str, Any] | None = None,
    fixed_features: ObservationFeatures | None = None,
    trial_index: int | None = None,
) -> SlicePredictions:
    """Computes slice prediction configuration values for a single metric name.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo. This value is not used
            for the predictions here; it is returned unchanged for the plotting
            code.
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index passed to ``get_fixed_values`` when
            determining the values of the non-sliced parameters.

    Returns:
        Configuration values for AxPlotConfig, laid out as described by
        ``SlicePredictions``: (plot_data, arm_name_to_parameters, predictions,
        raw_data, grid, metric_name, param_name, relative, fixed_values,
        prediction_sd, is_log_scale).
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model,
        generator_runs_dict=generator_runs_dict,
        metric_names={metric_name},
        fixed_features=fixed_features,
    )

    # Explicit fixed_features take precedence over slice_values.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    # One observation per grid point: the fixed values for every other
    # parameter, with the sliced parameter set to the grid value. Each point
    # gets its own deep copy so non-parameter features are not shared.
    prediction_features = []
    for x in grid:
        predf = deepcopy(fixed_features)
        predf.parameters = fixed_values.copy()
        predf.parameters[param_name] = x
        prediction_features.append(predf)

    f, cov = model.predict(prediction_features)
    f_plt = f[metric_name]
    # Standard deviation from the metric's own variance entry.
    sd_plt = np.sqrt(cov[metric_name][metric_name])
    return (
        plot_data,
        cond_name_to_parameters,
        f_plt,
        raw_data,
        grid,
        metric_name,
        param_name,
        relative,
        fixed_values,
        sd_plt,
        parameter.log_scale,
    )


def plot_slice_plotly(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: dict[str, Any] | None = None,
    fixed_features: ObservationFeatures | None = None,
    trial_index: int | None = None,
) -> go.Figure:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index forwarded to
            ``_get_slice_predictions`` for choosing the fixed parameter values.

    Returns:
        go.Figure: plot of objective vs. parameter value
    """
    (
        arm_plot_data,
        arm_to_params,
        predictions,
        fit_raw_data,
        slice_grid,
        _metric,
        _param,
        _relative,
        fixed_vals,
        predictions_sd,
        log_scaled,
    ) = _get_slice_predictions(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )
    # Pass everything through AxPlotConfig and read it back from `.data`; the
    # trace builder below consumes that normalized form.
    # NOTE(review): presumably this converts numpy values to plain Python
    # types — confirm in ax.plot.base.
    cfg = AxPlotConfig(
        {
            "arm_data": arm_plot_data,
            "arm_name_to_parameters": arm_to_params,
            "f": predictions,
            "fit_data": fit_raw_data,
            "grid": slice_grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_vals,
            "sd": predictions_sd,
            "is_log": log_scaled,
        },
        plot_type=AxPlotTypes.GENERIC,
    ).data

    traces = slice_config_to_trace(
        cfg["arm_data"],
        cfg["arm_name_to_parameters"],
        cfg["f"],
        cfg["fit_data"],
        cfg["grid"],
        cfg["metric"],
        cfg["param"],
        cfg["rel"],
        cfg["setx"],
        cfg["sd"],
        cfg["is_log"],
        True,
    )

    use_log_axis = cfg["is_log"]
    x_axis = {
        "anchor": "y",
        "autorange": False,
        "exponentformat": "e",
        "range": axis_range(cfg["grid"], use_log_axis),
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": cfg["param"],
        "type": "log" if use_log_axis else "linear",
    }
    y_axis = {
        "anchor": "x",
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": cfg["metric"],
    }
    return go.Figure(
        data=traces,
        layout={"hovermode": "closest", "xaxis": x_axis, "yaxis": y_axis},
    )
def plot_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: dict[str, Any] | None = None,
    fixed_features: ObservationFeatures | None = None,
    trial_index: int | None = None,
) -> AxPlotConfig:
    """Plot predictions for a 1-d slice of the parameter space.

    Thin wrapper that builds the figure with ``plot_slice_plotly`` and wraps it
    in an ``AxPlotConfig`` of type ``GENERIC``.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index forwarded to ``plot_slice_plotly``.

    Returns:
        AxPlotConfig: plot of objective vs. parameter value
    """
    figure = plot_slice_plotly(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
    #  `Figure`.
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def interact_slice_plotly(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: dict[str, Any] | None = None,
    fixed_features: ObservationFeatures | None = None,
    trial_index: int | None = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    The figure has two dropdown menus: one chooses which range parameter is on
    the x-axis, the other chooses which metric is shown.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index passed to ``get_fixed_values`` when
            computing the values of the non-sliced parameters.

    Returns:
        go.Figure: interactive plot of objective vs. parameter

    Raises:
        ValueError: If ``get_range_parameters`` yields no parameters, so there
            is nothing to slice along.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)

    range_parameters = get_range_parameters(model, min_num_values=5)
    param_names = [parameter.name for parameter in range_parameters]
    if not param_names:
        # Without this check the metric-button code below used to fail with an
        # UnboundLocalError on `metrics` / `arm_data`.
        raise ValueError(
            "Cannot create an interactive slice plot: the model has no range "
            "parameters eligible for slicing."
        )

    # The fixed values of the non-sliced parameters and the per-metric arm
    # data do not depend on which parameter is sliced, so compute them once
    # rather than once per (parameter, metric) pair.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        # NOTE(review): `plot_slice_plotly` passes None to `get_plot_data` in
        # this case, while this function has always passed an empty
        # ObservationFeatures. Kept as-is for backward compatibility — confirm
        # whether `get_plot_data` treats the two identically.
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    plot_data_dict: dict[str, PlotData] = {}
    raw_data_dict: dict[str, Any] = {}
    cond_name_to_parameters_dict: dict[str, Any] = {}
    for metric_name in metric_names:
        plot_data, raw_data, cond_name_to_parameters = get_plot_data(
            model=model,
            generator_runs_dict=generator_runs_dict,
            metric_names={metric_name},
            fixed_features=fixed_features,
        )
        plot_data_dict[metric_name] = plot_data
        raw_data_dict[metric_name] = raw_data
        cond_name_to_parameters_dict[metric_name] = cond_name_to_parameters

    # Populate `pbuttons`, which allows the user to select 1D slices of
    # parameter space with the chosen parameter on the x-axis.
    pbuttons = []
    init_traces = []
    xaxis_init_format = {}
    # Normalized arm data from the most recent parameter; it is reused for the
    # metric buttons after the loop.
    arm_data: dict[str, Any] = {}
    for param_index, param_name in enumerate(param_names):
        pbutton_data_args = {"x": [], "y": [], "error_y": []}
        parameter = get_range_parameter(model, param_name)
        grid = get_grid_for_parameter(parameter, density)

        prediction_features = []
        for x in grid:
            predf = deepcopy(none_throws(fixed_features))
            predf.parameters = fixed_values.copy()
            predf.parameters[param_name] = x
            prediction_features.append(predf)

        # A single prediction along the slice covers every metric.
        f, cov = model.predict(prediction_features)
        sd_plt_dict: dict[str, npt.NDArray] = {
            metric_name: np.sqrt(cov[metric_name][metric_name])
            for metric_name in metric_names
        }
        # Log scaling is a property of the sliced parameter, shared by all
        # metrics.
        is_log_dict: dict[str, bool] = {
            metric_name: parameter.log_scale for metric_name in metric_names
        }

        config = {
            "arm_data": plot_data_dict,
            "arm_name_to_parameters": cond_name_to_parameters_dict,
            "f": f,
            "fit_data": raw_data_dict,
            "grid": grid,
            "metrics": metric_names,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": sd_plt_dict,
            "is_log": is_log_dict,
        }
        config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

        arm_data = config["arm_data"]
        arm_name_to_parameters = config["arm_name_to_parameters"]
        f = config["f"]
        fit_data = config["fit_data"]
        grid = config["grid"]
        metrics = config["metrics"]
        param = config["param"]
        rel = config["rel"]
        setx = config["setx"]
        sd = config["sd"]
        is_log = config["is_log"]

        # layout
        xrange = axis_range(grid, is_log[metrics[0]])
        xtype = "log" if is_log_dict[metrics[0]] else "linear"

        for i, metric in enumerate(metrics):
            # Only the first metric's traces are visible initially.
            traces = slice_config_to_trace(
                arm_data[metric],
                arm_name_to_parameters[metric],
                f[metric],
                fit_data[metric],
                grid,
                metric,
                param,
                rel,
                setx,
                sd[metric],
                is_log[metric],
                i == 0,
            )
            pbutton_data_args["x"] += [trace["x"] for trace in traces]
            pbutton_data_args["y"] += [trace["y"] for trace in traces]
            pbutton_data_args["error_y"] += [
                (
                    {
                        "type": "data",
                        "array": trace["error_y"]["array"],
                        "visible": True,
                        "color": "black",
                    }
                    if "error_y" in trace and "array" in trace["error_y"]
                    else []
                )
                for trace in traces
            ]
            if param_index == 0:
                init_traces.extend(traces)

        pbutton_args = [
            pbutton_data_args,
            {
                "xaxis.title": param_name,
                "xaxis.range": xrange,
                "xaxis.type": xtype,
            },
        ]
        pbuttons.append({"args": pbutton_args, "label": param_name, "method": "update"})

        if param_index == 0:
            xaxis_init_format = {
                "anchor": "y",
                "autorange": False,
                "exponentformat": "e",
                "range": xrange,
                "tickfont": {"size": 11},
                "tickmode": "auto",
                "title": param_name,
                "type": xtype,
            }

    # Populate mbuttons, which allows the user to select which metric to plot.
    # Each metric owns a contiguous run of `trace_cnt` traces.
    mbuttons = []
    for i, metric in enumerate(metric_names):
        trace_cnt = 3 + len(arm_data[metric]["out_of_sample"].keys())
        visible = [False] * (len(metric_names) * trace_cnt)
        for j in range(i * trace_cnt, (i + 1) * trace_cnt):
            visible[j] = True
        mbuttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.455,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "showarrow": False,
                "text": "Choose parameter:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.305,
                "yanchor": "bottom",
                "yref": "paper",
            },
        ],
        "updatemenus": [
            {
                "y": -0.35,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": mbuttons,
                "direction": "up",
            },
            {
                "y": -0.2,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": pbuttons,
                "direction": "up",
            },
        ],
        "hovermode": "closest",
        "xaxis": xaxis_init_format,
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": metric_names[0],
        },
    }

    return go.Figure(data=init_traces, layout=layout)
def interact_slice(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: dict[str, Any] | None = None,
    fixed_features: ObservationFeatures | None = None,
    trial_index: int | None = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    Thin wrapper that builds the figure with ``interact_slice_plotly`` and
    wraps it in an ``AxPlotConfig`` of type ``GENERIC``.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Optional trial index forwarded to
            ``interact_slice_plotly``.

    Returns:
        AxPlotConfig: interactive plot of objective vs. parameter
    """
    figure = interact_slice_plotly(
        model=model,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
    #  `Figure`.
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)