Source code for ax.plot.scatter

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numbers
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import plotly.graph_objs as go
from ax.core.observation import Observation, ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.plot.base import (
    CI_OPACITY,
    DECIMALS,
    AxPlotConfig,
    AxPlotTypes,
    PlotInSampleArm,
    PlotMetric,
    PlotOutOfSampleArm,
    Z,
)
from ax.plot.color import COLORS, DISCRETE_COLOR_SCALE, rgba
from ax.plot.helper import (
    TNullableGeneratorRunsDict,
    _format_CI,
    _format_dict,
    _wrap_metric,
    arm_name_to_tuple,
    get_plot_data,
    infer_is_relative,
    resize_subtitles,
)
from ax.utils.stats.statstools import relativize
from plotly import subplots


# type aliases
Traces = List[Dict[str, Any]]


def _error_scatter_data(
    arms: List[Union[PlotInSampleArm, PlotOutOfSampleArm]],
    y_axis_var: PlotMetric,
    x_axis_var: Optional[PlotMetric] = None,
    status_quo_arm: Optional[PlotInSampleArm] = None,
) -> Tuple[List[float], Optional[List[float]], List[float], List[float]]:
    y_metric_key = "y_hat" if y_axis_var.pred else "y"
    y_sd_key = "se_hat" if y_axis_var.pred else "se"

    arm_names = [a.name for a in arms]
    y = [getattr(a, y_metric_key).get(y_axis_var.metric, np.nan) for a in arms]
    y_se = [getattr(a, y_sd_key).get(y_axis_var.metric, np.nan) for a in arms]

    # Delta method if relative to status quo arm
    if y_axis_var.rel:
        if status_quo_arm is None:
            raise ValueError("`status_quo_arm` cannot be None for relative effects.")
        y_rel, y_se_rel = relativize(
            means_t=y,
            sems_t=y_se,
            mean_c=getattr(status_quo_arm, y_metric_key).get(y_axis_var.metric),
            sem_c=getattr(status_quo_arm, y_sd_key).get(y_axis_var.metric),
            as_percent=True,
        )
        y = y_rel.tolist()
        y_se = y_se_rel.tolist()

    # x can be metric for a metric or arm names
    if x_axis_var is None:
        x = arm_names
        x_se = None
    else:
        x_metric_key = "y_hat" if x_axis_var.pred else "y"
        x_sd_key = "se_hat" if x_axis_var.pred else "se"
        x = [getattr(a, x_metric_key).get(x_axis_var.metric, np.nan) for a in arms]
        x_se = [getattr(a, x_sd_key).get(x_axis_var.metric, np.nan) for a in arms]

        if x_axis_var.rel:
            # Delta method if relative to status quo arm
            x_rel, x_se_rel = relativize(
                means_t=x,
                sems_t=x_se,
                mean_c=getattr(status_quo_arm, x_metric_key).get(x_axis_var.metric),
                sem_c=getattr(status_quo_arm, x_sd_key).get(x_axis_var.metric),
                as_percent=True,
            )
            x = x_rel.tolist()
            x_se = x_se_rel.tolist()
    return x, x_se, y, y_se


def _error_scatter_trace(
    arms: List[Union[PlotInSampleArm, PlotOutOfSampleArm]],
    y_axis_var: PlotMetric,
    x_axis_var: Optional[PlotMetric] = None,
    y_axis_label: Optional[str] = None,
    x_axis_label: Optional[str] = None,
    status_quo_arm: Optional[PlotInSampleArm] = None,
    show_CI: bool = True,
    name: str = "In-sample",
    color: Tuple[int] = COLORS.STEELBLUE.value,
    visible: bool = True,
    legendgroup: Optional[str] = None,
    showlegend: bool = True,
    hoverinfo: str = "text",
    show_arm_details_on_hover: bool = True,
    show_context: bool = False,
    arm_noun: str = "arm",
) -> Dict[str, Any]:
    """Plot scatterplot with error bars.

    Args:
        arms (List[Union[PlotInSampleArm, PlotOutOfSampleArm]]):
            a list of in-sample or out-of-sample arms.
            In-sample arms have observed data, while out-of-sample arms
            just have predicted data. As a result,
            when passing out-of-sample arms, pred must be True.
        y_axis_var: name of metric for y-axis, along with whether
            it is observed or predicted.
        x_axis_var: name of metric for x-axis,
            along with whether it is observed or predicted. If None, arm names
            are automatically used.
        y_axis_label: custom label to use for y axis.
            If None, use metric name from `y_axis_var`.
        x_axis_label: custom label to use for x axis.
            If None, use metric name from `x_axis_var` if that is not None.
        status_quo_arm: the status quo
            arm. Necessary for relative metrics.
        show_CI: if True, plot confidence intervals.
        name: name of trace. Default is "In-sample".
        color: color as rgb tuple. Default is
            (128, 177, 211), which corresponds to COLORS.STEELBLUE.
        visible: if True, trace is visible (default).
        legendgroup: group for legends.
        showlegend: if True, legend if rendered.
        hoverinfo: information to show on hover. Default is
            custom text.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True and show_arm_details_on_hover,
            context will be included in the hover.
        arm_noun: noun to use instead of "arm" (e.g. group)
    """
    x, x_se, y, y_se = _error_scatter_data(
        arms=arms,
        y_axis_var=y_axis_var,
        x_axis_var=x_axis_var,
        status_quo_arm=status_quo_arm,
    )
    labels = []

    arm_names = [a.name for a in arms]

    # No relativization if no x variable.
    rel_x = x_axis_var.rel if x_axis_var else False
    rel_y = y_axis_var.rel

    for i in range(len(arm_names)):
        heading = f"<b>{arm_noun.title()} {arm_names[i]}</b><br>"
        x_lab = (
            "{name}: {estimate}{perc} {ci}<br>".format(
                name=x_axis_var.metric if x_axis_label is None else x_axis_label,
                estimate=(
                    round(x[i], DECIMALS) if isinstance(x[i], numbers.Number) else x[i]
                ),
                ci="" if x_se is None else _format_CI(x[i], x_se[i], rel_x),
                perc="%" if rel_x else "",
            )
            if x_axis_var is not None
            else ""
        )
        y_lab = "{name}: {estimate}{perc} {ci}<br>".format(
            name=y_axis_var.metric if y_axis_label is None else y_axis_label,
            estimate=(
                round(y[i], DECIMALS) if isinstance(y[i], numbers.Number) else y[i]
            ),
            ci="" if y_se is None else _format_CI(y[i], y_se[i], rel_y),
            perc="%" if rel_y else "",
        )

        parameterization = (
            _format_dict(arms[i].parameters, "Parameterization")
            if show_arm_details_on_hover
            else ""
        )

        context = (
            # Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous
            # parameter to call `ax.plot.helper._format_dict` but got
            # `Optional[Dict[str, Union[float, str]]]`.
            # pyre-fixme[6]:
            _format_dict(arms[i].context_stratum, "Context")
            if show_arm_details_on_hover
            and show_context  # noqa W503
            and arms[i].context_stratum  # noqa W503
            else ""
        )

        labels.append(
            "{arm_name}<br>{xlab}{ylab}{param_blob}{context}".format(
                arm_name=heading,
                xlab=x_lab,
                ylab=y_lab,
                param_blob=parameterization,
                context=context,
            )
        )
        i += 1
    trace = go.Scatter(
        x=x,
        y=y,
        marker={"color": rgba(color)},
        mode="markers",
        name=name,
        text=labels,
        hoverinfo=hoverinfo,
    )

    if show_CI:
        if x_se is not None:
            trace.update(
                error_x={
                    "type": "data",
                    "array": np.multiply(x_se, Z),
                    "color": rgba(color, CI_OPACITY),
                }
            )
        if y_se is not None:
            trace.update(
                error_y={
                    "type": "data",
                    "array": np.multiply(y_se, Z),
                    "color": rgba(color, CI_OPACITY),
                }
            )
    if visible is not None:
        trace.update(visible=visible)
    if legendgroup is not None:
        trace.update(legendgroup=legendgroup)
    if showlegend is not None:
        trace.update(showlegend=showlegend)
    return trace


def _multiple_metric_traces(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel_x: bool,
    rel_y: bool,
    fixed_features: Optional[ObservationFeatures] = None,
    data_selector: Optional[Callable[[Observation], bool]] = None,
) -> Traces:
    """Plot traces for multiple metrics given a model and metrics.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.

    """
    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict if generator_runs_dict is not None else {},
        {metric_x, metric_y},
        fixed_features=fixed_features,
        data_selector=data_selector,
    )

    status_quo_arm = (
        None
        if plot_data.status_quo_name is None
        else plot_data.in_sample.get(plot_data.status_quo_name)
    )

    traces = [
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=False, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=False, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=False,
        ),
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=True,
        ),
    ]

    for i, (generator_run_name, cand_arms) in enumerate(
        (plot_data.out_of_sample or {}).items(), start=1
    ):
        traces.append(
            _error_scatter_trace(
                # pyre-fixme[6]: Expected `List[Union[PlotInSampleArm,
                #  PlotOutOfSampleArm]]` for 1st param but got
                #  `List[PlotOutOfSampleArm]`.
                list(cand_arms.values()),
                x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
                y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
                status_quo_arm=status_quo_arm,
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[i],
            )
        )
    return traces


[docs]def plot_multiple_metrics( model: ModelBridge, metric_x: str, metric_y: str, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Plot raw values or predictions of two metrics for arms. All arms used in the model are included in the plot. Additional arms can be passed through the `generator_runs_dict` argument. Args: model: model to draw predictions from. metric_x: metric to plot on the x-axis. metric_y: metric to plot on the y-axis. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. data_selector: Function for selecting observations for plotting. """ traces = _multiple_metric_traces( model, metric_x, metric_y, generator_runs_dict, rel_x=rel, rel_y=rel, fixed_features=fixed_features, data_selector=data_selector, ) num_cand_traces = len(generator_runs_dict) if generator_runs_dict is not None else 0 layout = go.Layout( title="Objective Tradeoffs", hovermode="closest", updatemenus=[ { "x": 1.25, "y": 0.67, "buttons": [ { "args": [ { "error_x.width": 4, "error_x.thickness": 2, "error_y.width": 4, "error_y.thickness": 2, } ], "label": "Yes", "method": "restyle", }, { "args": [ { "error_x.width": 0, "error_x.thickness": 0, "error_y.width": 0, "error_y.thickness": 0, } ], "label": "No", "method": "restyle", }, ], "yanchor": "middle", "xanchor": "left", }, { "x": 1.25, "y": 0.57, "buttons": [ { "args": [ {"visible": ([False, True] + [True] * num_cand_traces)} ], "label": "Modeled", "method": "restyle", }, { "args": [ {"visible": ([True, False] + [False] * num_cand_traces)} ], "label": "Observed", "method": "restyle", }, ], "yanchor": "middle", "xanchor": "left", }, ], annotations=[ { "x": 1.18, "y": 0.7, "xref": "paper", "yref": "paper", "text": "Show CI", "showarrow": False, "yanchor": "middle", }, { "x": 1.18, "y": 0.6, "xref": "paper", "yref": "paper", "text": "Type", "showarrow": False, "yanchor": "middle", }, ], xaxis={ "title": metric_x + (" (%)" if rel else ""), "zeroline": True, "zerolinecolor": "red", }, yaxis={ "title": metric_y + (" (%)" if rel else ""), "zeroline": True, "zerolinecolor": "red", }, width=800, height=600, font={"size": 10}, ) fig = go.Figure(data=traces, layout=layout) return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]def plot_objective_vs_constraints( model: ModelBridge, objective: str, subset_metrics: Optional[List[str]] = None, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, infer_relative_constraints: Optional[bool] = False, fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Plot the tradeoff between an objetive and all other metrics in a model. All arms used in the model are included in the plot. Additional arms can be passed through via the `generator_runs_dict` argument. Fixed features input can be used to override fields of the insample arms when making model predictions. Args: model: model to draw predictions from. objective: metric to optimize. Plotted on the x-axis. subset_metrics: list of metrics to plot on the y-axes if need a subset of all metrics in the model. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. infer_relative_constraints: if True, read relative spec from model's optimization config. Absolute constraints will not be relativized; relative ones will be. Objectives will respect the `rel` parameter. Metrics that are not constraints will be relativized. fixed_features: Fixed features to use when making model predictions. data_selector: Function for selecting observations for plotting. """ if subset_metrics is not None: metrics = subset_metrics else: metrics = [m for m in model.metric_names if m != objective] metric_dropdown = [] if infer_relative_constraints: rels = infer_is_relative(model, metrics, non_constraint_rel=rel) if rel: rels[objective] = True else: rels[objective] = False else: if rel: rels = {metric: True for metric in metrics} rels[objective] = True else: rels = {metric: False for metric in metrics} rels[objective] = False # set plotted data to the first outcome plot_data = _multiple_metric_traces( model, objective, metrics[0], generator_runs_dict, rel_x=rels[objective], rel_y=rels[metrics[0]], fixed_features=fixed_features, data_selector=data_selector, ) for metric in metrics: otraces = _multiple_metric_traces( model, objective, metric, generator_runs_dict, rel_x=rels[objective], rel_y=rels[metric], fixed_features=fixed_features, data_selector=data_selector, ) # Current version of Plotly does not allow updating the yaxis label # on dropdown (via relayout) simultaneously with restyle metric_dropdown.append( { "args": [ { "y": [t["y"] for t in otraces], "error_y.array": [t["error_y"]["array"] for t in otraces], "text": [t["text"] for t in otraces], }, {"yaxis.title": metric + (" (%)" if rels[metric] else "")}, ], "label": metric, "method": "update", } ) num_cand_traces = len(generator_runs_dict) if generator_runs_dict is not None else 0 layout = go.Layout( title="Objective Tradeoffs", hovermode="closest", updatemenus=[ { "x": 1.25, "y": 0.62, "buttons": [ { "args": [ { "error_x.width": 4, "error_x.thickness": 2, "error_y.width": 4, "error_y.thickness": 2, } ], "label": "Yes", "method": "restyle", }, { "args": [ { "error_x.width": 0, "error_x.thickness": 0, "error_y.width": 0, "error_y.thickness": 0, } ], "label": "No", "method": "restyle", }, ], "yanchor": "middle", "xanchor": "left", }, { "x": 1.25, "y": 0.52, "buttons": [ { "args": [ {"visible": ([False, True] + [True] * num_cand_traces)} ], "label": "Modeled", "method": "restyle", }, { "args": [ {"visible": ([True, False] + [False] * num_cand_traces)} ], "label": "Observed", "method": "restyle", }, ], "yanchor": "middle", "xanchor": "left", }, { "x": 1.25, "y": 0.72, "yanchor": "middle", "xanchor": "left", "buttons": metric_dropdown, }, ], annotations=[ { "x": 1.18, "y": 0.72, "xref": "paper", "yref": "paper", "text": "Y-Axis", "showarrow": False, "yanchor": "middle", }, { "x": 1.18, "y": 0.62, "xref": "paper", "yref": "paper", "text": "Show CI", "showarrow": False, "yanchor": "middle", }, { "x": 1.18, "y": 0.52, "xref": "paper", "yref": "paper", "text": "Type", "showarrow": False, "yanchor": "middle", }, ], xaxis={ "title": objective + (" (%)" if rels[objective] else ""), "zeroline": True, "zerolinecolor": "red", }, yaxis={ "title": metrics[0] + (" (%)" if rels[metrics[0]] else ""), "zeroline": True, "zerolinecolor": "red", }, width=900, height=600, font={"size": 10}, ) fig = go.Figure(data=plot_data, layout=layout) return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]def lattice_multiple_metrics( model: ModelBridge, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = False, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Plot raw values or predictions of combinations of two metrics for arms. Args: model: model to draw predictions from. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. show_arm_details_on_hover: if True, display parameterizations of arms on hover. Default is False. data_selector: Function for selecting observations for plotting. """ metrics = model.metric_names fig = subplots.make_subplots( rows=len(metrics), cols=len(metrics), print_grid=False, shared_xaxes=False, shared_yaxes=False, ) plot_data, _, _ = get_plot_data( model, generator_runs_dict if generator_runs_dict is not None else {}, metrics, data_selector=data_selector, ) status_quo_arm = ( None if plot_data.status_quo_name is None else plot_data.in_sample.get(plot_data.status_quo_name) ) # iterate over all combinations of metrics and generate scatter traces for i, o1 in enumerate(metrics, start=1): for j, o2 in enumerate(metrics, start=1): if o1 != o2: # in-sample observed and predicted obs_insample_trace = _error_scatter_trace( # Expected `List[Union[PlotInSampleArm, # PlotOutOfSampleArm]]` for 1st anonymous parameter to call # `ax.plot.scatter._error_scatter_trace` but got # `List[PlotInSampleArm]`. # pyre-fixme[6]: list(plot_data.in_sample.values()), x_axis_var=PlotMetric(o1, pred=False, rel=rel), y_axis_var=PlotMetric(o2, pred=False, rel=rel), status_quo_arm=status_quo_arm, showlegend=(i == 1 and j == 2), legendgroup="In-sample", visible=False, show_arm_details_on_hover=show_arm_details_on_hover, ) predicted_insample_trace = _error_scatter_trace( # Expected `List[Union[PlotInSampleArm, # PlotOutOfSampleArm]]` for 1st anonymous parameter to call # `ax.plot.scatter._error_scatter_trace` but got # `List[PlotInSampleArm]`. # pyre-fixme[6]: list(plot_data.in_sample.values()), x_axis_var=PlotMetric(o1, pred=True, rel=rel), y_axis_var=PlotMetric(o2, pred=True, rel=rel), status_quo_arm=status_quo_arm, legendgroup="In-sample", showlegend=(i == 1 and j == 2), visible=True, show_arm_details_on_hover=show_arm_details_on_hover, ) fig.append_trace(obs_insample_trace, j, i) fig.append_trace(predicted_insample_trace, j, i) # iterate over models here for k, (generator_run_name, cand_arms) in enumerate( (plot_data.out_of_sample or {}).items(), start=1 ): fig.append_trace( _error_scatter_trace( # pyre-fixme[6]: Expected `List[Union[PlotInSampleArm, # PlotOutOfSampleArm]]` for 1st param but got # `List[PlotOutOfSampleArm]`. list(cand_arms.values()), x_axis_var=PlotMetric(o1, pred=True, rel=rel), y_axis_var=PlotMetric(o2, pred=True, rel=rel), status_quo_arm=status_quo_arm, name=generator_run_name, color=DISCRETE_COLOR_SCALE[k], showlegend=(i == 1 and j == 2), legendgroup=generator_run_name, show_arm_details_on_hover=show_arm_details_on_hover, ), j, i, ) else: # if diagonal is set to True, add box plots fig.append_trace( go.Box( y=[arm.y[o1] for arm in plot_data.in_sample.values()], name=None, marker={"color": rgba(COLORS.STEELBLUE.value)}, showlegend=False, legendgroup="In-sample", visible=False, hoverinfo="none", ), j, i, ) fig.append_trace( go.Box( y=[arm.y_hat[o1] for arm in plot_data.in_sample.values()], name=None, marker={"color": rgba(COLORS.STEELBLUE.value)}, showlegend=False, legendgroup="In-sample", hoverinfo="none", ), j, i, ) for k, (generator_run_name, cand_arms) in enumerate( (plot_data.out_of_sample or {}).items(), start=1 ): fig.append_trace( go.Box( y=[arm.y_hat[o1] for arm in cand_arms.values()], name=None, marker={"color": rgba(DISCRETE_COLOR_SCALE[k])}, showlegend=False, legendgroup=generator_run_name, hoverinfo="none", ), j, i, ) fig["layout"].update( height=800, width=960, font={"size": 10}, hovermode="closest", legend={ "orientation": "h", "x": 0, "y": 1.05, "xanchor": "left", "yanchor": "middle", }, updatemenus=[ { "x": 0.35, "y": 1.08, "xanchor": "left", "yanchor": "middle", "buttons": [ { "args": [ { "error_x.width": 0, "error_x.thickness": 0, "error_y.width": 0, "error_y.thickness": 0, } ], "label": "No", "method": "restyle", }, { "args": [ { "error_x.width": 4, "error_x.thickness": 2, "error_y.width": 4, "error_y.thickness": 2, } ], "label": "Yes", "method": "restyle", }, ], }, { "x": 0.1, "y": 1.08, "xanchor": "left", "yanchor": "middle", "buttons": [ { "args": [ { "visible": ( ( [False, True] + [True] * len(plot_data.out_of_sample or {}) ) * (len(metrics) ** 2) ) } ], "label": "Modeled", "method": "restyle", }, { "args": [ { "visible": ( ( [True, False] + [False] * len(plot_data.out_of_sample or {}) ) * (len(metrics) ** 2) ) } ], "label": "In-sample", "method": "restyle", }, ], }, ], annotations=[ { "x": 0.02, "y": 1.1, "xref": "paper", "yref": "paper", "text": "Type", "showarrow": False, "yanchor": "middle", "xanchor": "left", }, { "x": 0.30, "y": 1.1, "xref": "paper", "yref": "paper", "text": "Show CI", "showarrow": False, "yanchor": "middle", "xanchor": "left", }, ], ) # add metric names to axes - add to each subplot if boxplots on the # diagonal and axes are not shared; else, add to the leftmost y-axes # and bottom x-axes. for i, o in enumerate(metrics): pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1 pos_y = 1 + (len(metrics) * i) fig["layout"]["xaxis{}".format(pos_x)].update( title=_wrap_metric(o), titlefont={"size": 10} ) fig["layout"]["yaxis{}".format(pos_y)].update( title=_wrap_metric(o), titlefont={"size": 10} ) # do not put x-axis ticks for boxplots boxplot_xaxes = [] for trace in fig["data"]: if trace["type"] == "box": # stores the xaxes which correspond to boxplot subplots # since we use xaxis1, xaxis2, etc, in plotly.py boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:])) else: # clear all error bars since default is no CI trace["error_x"].update(width=0, thickness=0) trace["error_y"].update(width=0, thickness=0) for xaxis in boxplot_xaxes: fig["layout"][xaxis]["showticklabels"] = False return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
# Single metric fitted values def _single_metric_traces( model: ModelBridge, metric: str, generator_runs_dict: TNullableGeneratorRunsDict, rel: bool, show_arm_details_on_hover: bool = True, showlegend: bool = True, show_CI: bool = True, arm_noun: str = "arm", fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> Traces: """Plot scatterplots with errors for a single metric (y-axis). Arms are plotted on the x-axis. Args: model: model to draw predictions from. metric: name of metric to plot. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, plot relative predictions. show_arm_details_on_hover: if True, display parameterizations of arms on hover. Default is True. show_legend: if True, show legend for trace. show_CI: if True, render confidence intervals. arm_noun: noun to use instead of "arm" (e.g. group) fixed_features: Fixed features to use when making model predictions. data_selector: Function for selecting observations for plotting. """ plot_data, _, _ = get_plot_data( model, generator_runs_dict or {}, {metric}, fixed_features=fixed_features, data_selector=data_selector, ) status_quo_arm = ( None if plot_data.status_quo_name is None else plot_data.in_sample.get(plot_data.status_quo_name) ) traces = [ _error_scatter_trace( # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]` # for 1st anonymous parameter to call # `ax.plot.scatter._error_scatter_trace` but got # `List[PlotInSampleArm]`. # pyre-fixme[6]: list(plot_data.in_sample.values()), x_axis_var=None, y_axis_var=PlotMetric(metric, pred=True, rel=rel), status_quo_arm=status_quo_arm, legendgroup="In-sample", showlegend=showlegend, show_arm_details_on_hover=show_arm_details_on_hover, show_CI=show_CI, arm_noun=arm_noun, ) ] # Candidates for i, (generator_run_name, cand_arms) in enumerate( (plot_data.out_of_sample or {}).items(), start=1 ): traces.append( _error_scatter_trace( # pyre-fixme[6]: Expected `List[Union[PlotInSampleArm, # PlotOutOfSampleArm]]` for 1st param but got # `List[PlotOutOfSampleArm]`. list(cand_arms.values()), x_axis_var=None, y_axis_var=PlotMetric(metric, pred=True, rel=rel), status_quo_arm=status_quo_arm, name=generator_run_name, color=DISCRETE_COLOR_SCALE[i], legendgroup=generator_run_name, showlegend=showlegend, show_arm_details_on_hover=show_arm_details_on_hover, show_CI=show_CI, arm_noun=arm_noun, ) ) return traces
[docs]def plot_fitted( model: ModelBridge, metric: str, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, custom_arm_order: Optional[List[str]] = None, custom_arm_order_name: str = "Custom", show_CI: bool = True, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Plot fitted metrics. Args: model: model to use for predictions. metric: metric to plot predictions for. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. custom_arm_order: a list of arm names in the order corresponding to how they should be plotted on the x-axis. If not None, this is the default ordering. custom_arm_order_name: name for custom ordering to show in the ordering dropdown. Default is 'Custom'. show_CI: if True, render confidence intervals. data_selector: Function for selecting observations for plotting. """ traces = _single_metric_traces( model, metric, generator_runs_dict, rel, show_CI=show_CI, data_selector=data_selector, ) # order arm name sorting arm numbers within batch names_by_arm = sorted( np.unique(np.concatenate([d["x"] for d in traces])), key=lambda x: arm_name_to_tuple(x), ) # get arm names sorted by effect size names_by_effect = list( OrderedDict.fromkeys( np.concatenate([d["x"] for d in traces]) .flatten() .take(np.argsort(np.concatenate([d["y"] for d in traces]).flatten())) ) ) # options for ordering arms (x-axis) xaxis_categoryorder = "array" xaxis_categoryarray = names_by_arm order_options = [ { "args": [ {"xaxis.categoryorder": "array", "xaxis.categoryarray": names_by_arm} ], "label": "Name", "method": "relayout", }, { "args": [ {"xaxis.categoryorder": "array", "xaxis.categoryarray": names_by_effect} ], "label": "Effect Size", "method": "relayout", }, ] # if a custom order has been passed, default to that if custom_arm_order is not None: xaxis_categoryorder = "array" xaxis_categoryarray = custom_arm_order order_options = [ { "args": [ { "xaxis.categoryorder": "array", "xaxis.categoryarray": custom_arm_order, } ], "label": custom_arm_order_name, "method": "relayout", } # Union[List[str... ] + order_options layout = go.Layout( title="Predicted Outcomes", hovermode="closest", updatemenus=[ { "x": 1.25, "y": 0.67, "buttons": list(order_options), "yanchor": "middle", "xanchor": "left", } ], yaxis={ "zerolinecolor": "red", "title": "{}{}".format(metric, " (%)" if rel else ""), }, xaxis={ "tickangle": 45, "categoryorder": xaxis_categoryorder, "categoryarray": xaxis_categoryarray, }, annotations=[ { "x": 1.18, "y": 0.72, "xref": "paper", "yref": "paper", "text": "Sort By", "showarrow": False, "yanchor": "middle", } ], font={"size": 10}, ) fig = go.Figure(data=traces, layout=layout) return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]def tile_fitted( model: ModelBridge, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = False, show_CI: bool = True, arm_noun: str = "arm", metrics: Optional[List[str]] = None, fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Tile version of fitted outcome plots. Args: model: model to use for predictions. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. show_arm_details_on_hover: if True, display parameterizations of arms on hover. Default is False. show_CI: if True, render confidence intervals. arm_noun: noun to use instead of "arm" (e.g. group) metrics: List of metric names to restrict to when plotting. fixed_features: Fixed features to use when making model predictions. data_selector: Function for selecting observations for plotting. """ metrics = metrics or list(model.metric_names) nrows = int(np.ceil(len(metrics) / 2)) ncols = min(len(metrics), 2) # make subplots (plot per row) fig = subplots.make_subplots( rows=nrows, cols=ncols, print_grid=False, shared_xaxes=False, shared_yaxes=False, subplot_titles=tuple(metrics), horizontal_spacing=0.05, vertical_spacing=0.30 / nrows, ) name_order_args: Dict[str, Any] = {} name_order_axes: Dict[str, Dict[str, Any]] = {} effect_order_args: Dict[str, Any] = {} for i, metric in enumerate(metrics): data = _single_metric_traces( model, metric, generator_runs_dict, rel, showlegend=i == 0, show_arm_details_on_hover=show_arm_details_on_hover, show_CI=show_CI, arm_noun=arm_noun, fixed_features=fixed_features, data_selector=data_selector, ) # order arm name sorting arm numbers within batch names_by_arm = sorted( np.unique(np.concatenate([d["x"] for d in data])), key=lambda x: arm_name_to_tuple(x), ) # get arm names sorted by effect size names_by_effect = list( OrderedDict.fromkeys( np.concatenate([d["x"] for d in data]) .flatten() .take(np.argsort(np.concatenate([d["y"] for d in data]).flatten())) ) ) # options for ordering arms (x-axis) # Note that xaxes need to be references as xaxis, xaxis2, xaxis3, etc. # for the purposes of updatemenus argument (dropdown) in layout. # However, when setting the initial ordering layout, the keys should be # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial # axis. label = "" if i == 0 else i + 1 name_order_args["xaxis{}.categoryorder".format(label)] = "array" name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm effect_order_args["xaxis{}.categoryorder".format(label)] = "array" effect_order_args["xaxis{}.categoryarray".format(label)] = names_by_effect name_order_axes["xaxis{}".format(i + 1)] = { "categoryorder": "array", "categoryarray": names_by_arm, "type": "category", } name_order_axes["yaxis{}".format(i + 1)] = { "ticksuffix": "%" if rel else "", "zerolinecolor": "red", } for d in data: fig.append_trace(d, int(np.floor(i / ncols)) + 1, i % ncols + 1) order_options = [ {"args": [name_order_args], "label": "Name", "method": "relayout"}, {"args": [effect_order_args], "label": "Effect Size", "method": "relayout"}, ] # if odd number of plots, need to manually remove the last blank subplot # generated by `subplots.make_subplots` if len(metrics) % 2 == 1: fig["layout"].pop("xaxis{}".format(nrows * ncols)) fig["layout"].pop("yaxis{}".format(nrows * ncols)) # allocate 400 px per plot fig["layout"].update( margin={"t": 0}, hovermode="closest", updatemenus=[ { "x": 0.15, "y": 1 + 0.40 / nrows, "buttons": order_options, "xanchor": "left", "yanchor": "middle", } ], font={"size": 10}, width=650 if ncols == 1 else 950, height=300 * nrows, legend={ "orientation": "h", "x": 0, "y": 1 + 0.20 / nrows, "xanchor": "left", "yanchor": "middle", }, **name_order_axes, ) # append dropdown annotations fig["layout"]["annotations"] = fig["layout"]["annotations"] + ( { "x": 0.5, "y": 1 + 0.40 / nrows, "xref": "paper", "yref": "paper", "font": {"size": 14}, "text": "Predicted Outcomes", "showarrow": False, "xanchor": "center", "yanchor": "middle", }, { "x": 0.05, "y": 1 + 0.40 / nrows, "xref": "paper", "yref": "paper", "text": "Sort By", "showarrow": False, "xanchor": "left", "yanchor": "middle", }, ) fig = resize_subtitles(figure=fig, size=10) return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
[docs]def interact_fitted( model: ModelBridge, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = True, show_CI: bool = True, arm_noun: str = "arm", metrics: Optional[List[str]] = None, fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, ) -> AxPlotConfig: """Interactive fitted outcome plots for each arm used in fitting the model. Choose the outcome to plot using a dropdown. Args: model: model to use for predictions. generator_runs_dict: a mapping from generator run name to generator run. rel: if True, use relative effects. Default is True. show_arm_details_on_hover: if True, display parameterizations of arms on hover. Default is True. show_CI: if True, render confidence intervals. arm_noun: noun to use instead of "arm" (e.g. group) metrics: List of metric names to restrict to when plotting. fixed_features: Fixed features to use when making model predictions. data_selector: Function for selecting observations for plotting. """ traces_per_metric = ( 1 if generator_runs_dict is None else len(generator_runs_dict) + 1 ) metrics = sorted(metrics or model.metric_names) traces = [] dropdown = [] for i, metric in enumerate(metrics): data = _single_metric_traces( model, metric, generator_runs_dict, rel, showlegend=i == 0, show_arm_details_on_hover=show_arm_details_on_hover, show_CI=show_CI, arm_noun=arm_noun, fixed_features=fixed_features, data_selector=data_selector, ) for d in data: d["visible"] = i == 0 traces.append(d) # only the first two traces are visible (corresponding to first outcome # in dropdown) is_visible = [False] * (len(metrics) * traces_per_metric) for j in range((traces_per_metric * i), (traces_per_metric * (i + 1))): is_visible[j] = True # on dropdown change, restyle dropdown.append( {"args": ["visible", is_visible], "label": metric, "method": "restyle"} ) layout = go.Layout( xaxis={"title": arm_noun.title(), "zeroline": False, "type": "category"}, yaxis={ "ticksuffix": "%" if rel else "", "title": ("Relative " if rel else "") + "Effect", "zeroline": True, "zerolinecolor": "red", }, hovermode="closest", updatemenus=[ { "buttons": dropdown, "x": 0.075, "xanchor": "left", "y": 1.1, "yanchor": "middle", } ], annotations=[ { "font": {"size": 12}, "showarrow": False, "text": "Metric", "x": 0.05, "xanchor": "right", "xref": "paper", "y": 1.1, "yanchor": "middle", "yref": "paper", } ], legend={ "orientation": "h", "x": 0.065, "xanchor": "left", "y": 1.2, "yanchor": "middle", }, height=500, ) if traces_per_metric > 1: layout["annotations"] = layout["annotations"] + ( { "font": {"size": 12}, "showarrow": False, "text": "Arm Source", "x": 0.05, "xanchor": "right", "xref": "paper", "y": 1.2, "yanchor": "middle", "yref": "paper", }, ) return AxPlotConfig( data=go.Figure(data=traces, layout=layout), plot_type=AxPlotTypes.GENERIC )