# Source code for ax.plot.contour

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, Dict, Optional, Tuple

import numpy as np
from ax.core.observation import ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData
from ax.plot.color import BLUE_SCALE, GREEN_PINK_SCALE, GREEN_SCALE
from ax.plot.helper import (
    TNullableGeneratorRunsDict,
    get_fixed_values,
    get_grid_for_parameter,
    get_plot_data,
    get_range_parameter,
    get_range_parameters,
)


# type aliases
ContourPredictions = Tuple[
    PlotData,  # arm data returned by get_plot_data
    np.ndarray,  # mean prediction at each grid point (mu[metric])
    np.ndarray,  # sqrt of cov[metric][metric] at each grid point
    np.ndarray,  # grid values along the x-axis parameter
    np.ndarray,  # grid values along the y-axis parameter
    Dict[str, bool],  # {"x": log_scale, "y": log_scale} per axis
]


def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
) -> ContourPredictions:
    """Predict ``metric`` over a 2-d grid spanned by two range parameters.

    slice_values is a dictionary {param_name: value} for the parameters that
    are being sliced on.

    Args:
        model: ModelBridge used to make predictions.
        x_param_name: Name of the range parameter placed on the x-axis.
        y_param_name: Name of the range parameter placed on the y-axis.
        metric: Name of the metric to predict.
        generator_runs_dict: Generator runs whose arms are included in the
            returned plot data; ``None`` is treated as an empty dict.
        density: Number of grid points requested per axis (passed to
            ``get_grid_for_parameter``).
        slice_values: Fixed values for the remaining parameters, resolved
            through ``get_fixed_values``.

    Returns:
        ``(plot_data, f, sd, grid_x, grid_y, scales)``:

        - ``f``/``sd`` are ``mu[metric]`` and ``sqrt(cov[metric][metric])``
          from ``model.predict``. There is one entry per grid point, ordered
          as the row-major flattening of ``np.meshgrid(grid_x, grid_y)``
          (the x value varies fastest).
        - ``scales`` maps ``"x"``/``"y"`` to that parameter's ``log_scale``.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    plot_data, _, _ = get_plot_data(model, generator_runs_dict or {}, {metric})

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    # Default "xy" indexing: after row-major flattening, x varies fastest.
    # Zipping the flattened arrays visits every grid point exactly once
    # without assuming each grid has exactly ``density`` points.
    grid2_x, grid2_y = np.meshgrid(grid_x, grid_y)

    fixed_values = get_fixed_values(model, slice_values)

    param_grid_obsf = []
    for x_value, y_value in zip(grid2_x.flatten(), grid2_y.flatten()):
        parameters = fixed_values.copy()
        parameters[x_param_name] = x_value
        # If both names are equal, this overwrites the x value set above.
        parameters[y_param_name] = y_value
        param_grid_obsf.append(ObservationFeatures(parameters))

    mu, cov = model.predict(param_grid_obsf)

    f_plt = mu[metric]
    # Presumably the per-point predictive variance; its square root is
    # reported as the standard deviation.
    sd_plt = np.sqrt(cov[metric][metric])
    return plot_data, f_plt, sd_plt, grid_x, grid_y, scales


def plot_contour(
    model: ModelBridge,
    param_x: str,
    param_y: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
) -> AxPlotConfig:
    """Plot predictions for a 2-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_x: Name of parameter that will be sliced on x-axis
        param_y: Name of parameter that will be sliced on y-axis
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator
            runs whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo. This flag is only
            forwarded to the plot config as ``"rel"``; no transformation of
            the predictions happens in this function.
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.

    Returns:
        An ``AxPlotConfig`` with ``plot_type=AxPlotTypes.CONTOUR``.

    Raises:
        ValueError: If ``param_x`` and ``param_y`` name the same parameter.
    """
    if param_x == param_y:
        raise ValueError(
            "Please select different parameters for x- and y-dimensions."
        )
    data, f_plt, sd_plt, grid_x, grid_y, scales = _get_contour_predictions(
        model=model,
        x_param_name=param_x,
        y_param_name=param_y,
        metric=metric_name,
        generator_runs_dict=generator_runs_dict,
        density=density,
        slice_values=slice_values,
    )
    config = {
        "arm_data": data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f": f_plt,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_x": grid_x,
        "grid_y": grid_y,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd": sd_plt,
        "xvar": param_x,
        "yvar": param_y,
        "x_is_log": scales["x"],
        "y_is_log": scales["y"],
    }
    return AxPlotConfig(config, plot_type=AxPlotTypes.CONTOUR)


def interact_contour(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 2-d slice of the parameter space.

    Predictions are computed for every ordered pair of range parameters,
    including each parameter paired with itself. They are stored as
    ``f_dict[x_name][y_name]`` / ``sd_dict[x_name][y_name]``.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator
            runs whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo. This flag is only
            forwarded to the plot config as ``"rel"``.
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.

    Returns:
        An ``AxPlotConfig`` with ``plot_type=AxPlotTypes.INTERACT_CONTOUR``.
    """
    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(model, generator_runs_dict or {}, {metric_name})

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]
    is_log_dict: Dict[str, bool] = {
        parameter.name: parameter.log_scale for parameter in range_parameters
    }
    grid_dict: Dict[str, np.ndarray] = {
        parameter.name: get_grid_for_parameter(parameter, density)
        for parameter in range_parameters
    }

    # Build each row from real predictions instead of pre-filling with empty
    # lists, so the values match the declared types without pyre suppressions.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {}
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {}
    for param1 in param_names:
        f_row: Dict[str, np.ndarray] = {}
        sd_row: Dict[str, np.ndarray] = {}
        for param2 in param_names:
            # NOTE(review): diagonal entries (param1 == param2) set the same
            # parameter twice in the helper; confirm the front end ignores them.
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=param1,
                y_param_name=param2,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
            )
            f_row[param2] = f_plt
            sd_row[param2] = sd_plt
        f_dict[param1] = f_row
        sd_dict[param1] = sd_row

    config = {
        "arm_data": plot_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f_dict": f_dict,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_dict": grid_dict,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd_dict": sd_dict,
        "is_log_dict": is_log_dict,
        "param_names": param_names,
    }
    return AxPlotConfig(config, plot_type=AxPlotTypes.INTERACT_CONTOUR)