Source code for ax.plot.slice

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, Dict, Optional

import numpy as np
from ax.core.observation import ObservationFeatures
from ax.modelbridge.base import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.helper import (
    TNullableGeneratorRunsDict,
    get_fixed_values,
    get_grid_for_parameter,
    get_plot_data,
    get_range_parameter,
)


[docs]def plot_slice( model: ModelBridge, param_name: str, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, slice_values: Optional[Dict[str, Any]] = None, ) -> AxPlotConfig: """Plot predictions for a 1-d slice of the parameter space. Args: model: ModelBridge that contains model for predictions param_name: Name of parameter that will be sliced metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. relative: Predictions relative to status quo density: Number of points along slice to evaluate predictions. slice_values: A dictionary {name: val} for the fixed values of the other parameters. If not provided, then the status quo values will be used if there is a status quo, otherwise the mean of numeric parameters or the mode of choice parameters. """ if generator_runs_dict is None: generator_runs_dict = {} parameter = get_range_parameter(model, param_name) grid = get_grid_for_parameter(parameter, density) plot_data, raw_data, cond_name_to_parameters = get_plot_data( model=model, generator_runs_dict=generator_runs_dict, metric_names={metric_name} ) fixed_values = get_fixed_values(model, slice_values) prediction_features = [] for x in grid: parameters = fixed_values.copy() parameters[param_name] = x # Here we assume context is None prediction_features.append(ObservationFeatures(parameters=parameters)) f, cov = model.predict(prediction_features) f_plt = f[metric_name] sd_plt = np.sqrt(cov[metric_name][metric_name]) config = { "arm_data": plot_data, "arm_name_to_parameters": cond_name_to_parameters, "f": f_plt, "fit_data": raw_data, "grid": grid, "metric": metric_name, "param": param_name, "rel": relative, "setx": fixed_values, "sd": sd_plt, "is_log": parameter.log_scale, } return AxPlotConfig(config, plot_type=AxPlotTypes.SLICE)