Source code for ax.plot.base

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import enum
from typing import Any, Dict, List, NamedTuple, Optional, Union

import simplejson
from ax.core.types import TParameterization
from plotly import utils


# Constants used for numerous plots
CI_OPACITY = 0.4
DECIMALS = 3
Z = 1.96


[docs]class AxPlotTypes(enum.Enum): """Enum of Ax plot types.""" CONTOUR = 0 GENERIC = 1 SLICE = 2 INTERACT_CONTOUR = 3 BANDIT_ROLLOUT = 4
# Configuration for all plots class _AxPlotConfigBase(NamedTuple): data: Dict[str, Any] plot_type: enum.Enum
[docs]class AxPlotConfig(_AxPlotConfigBase): """Config for plots""" def __new__(cls, data: Dict[str, Any], plot_type: enum.Enum) -> "AxPlotConfig": # Convert data to json-encodable form (strips out NamedTuple and numpy # array). This is a lossy conversion. dict_data = simplejson.loads( simplejson.dumps( data, cls=utils.PlotlyJSONEncoder, namedtuple_as_object=True, # uses NamesTuple's `_asdict()` ) ) # pyre-fixme[7]: Expected `AxPlotConfig` but got `_AxPlotConfigBase`. return super(AxPlotConfig, cls).__new__(cls, dict_data, plot_type)
# Structs for plot data
[docs]class PlotInSampleArm(NamedTuple): """Struct for in-sample arms (both observed and predicted data)""" name: str parameters: TParameterization y: Dict[str, float] y_hat: Dict[str, float] se: Dict[str, float] se_hat: Dict[str, float] context_stratum: Optional[Dict[str, Union[str, float]]]
[docs]class PlotOutOfSampleArm(NamedTuple): """Struct for out-of-sample arms (only predicted data)""" name: str parameters: TParameterization y_hat: Dict[str, float] se_hat: Dict[str, float] context_stratum: Optional[Dict[str, Union[str, float]]]
[docs]class PlotData(NamedTuple): """Struct for plot data, including both in-sample and out-of-sample arms""" metrics: List[str] in_sample: Dict[str, PlotInSampleArm] out_of_sample: Optional[Dict[str, Dict[str, PlotOutOfSampleArm]]] status_quo_name: Optional[str]
[docs]class PlotMetric(NamedTuple): """Struct for metric""" # @TODO T40555279: metric --> metric_name everywhere in plotting metric: str pred: bool