Source code for ax.plot.base
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import enum
import json
from typing import Any, NamedTuple
from ax.core.types import TParameterization
from ax.utils.common.serialization import named_tuple_to_dict
from plotly import utils
# Module-level constants shared by many of the plotting helpers.
# Presumably the opacity applied to confidence-interval shading (per the name).
CI_OPACITY: float = 0.4
# Presumably the number of decimal places used when rounding displayed values.
DECIMALS: int = 3
# 1.96 is the 0.975 quantile of the standard normal distribution; presumably
# used to turn standard errors into 95% confidence-interval half-widths.
Z: float = 1.96
[docs]
@enum.unique
class AxPlotTypes(enum.Enum):
    """Identifiers for the plot types Ax works with.

    Every member maps to a distinct, fixed integer value (0-6).
    ``@enum.unique`` enforces that no two members share a value.
    """

    CONTOUR = 0
    GENERIC = 1
    SLICE = 2
    INTERACT_CONTOUR = 3
    BANDIT_ROLLOUT = 4
    INTERACT_SLICE = 5
    HTML = 6
# Base record shared by every plot configuration.
class _AxPlotConfigBase(NamedTuple):
    """Private storage layer for plot configs: a payload dict and a plot type."""

    # Arbitrary plot payload, keyed by string.
    data: dict[str, Any]
    # Which kind of plot this config describes.
    plot_type: enum.Enum
[docs]
class AxPlotConfig(_AxPlotConfigBase):
    """Plot configuration whose ``data`` is normalized to JSON-safe values.

    On construction, ``data`` is serialized to JSON text and parsed back, so
    the stored payload contains only plain JSON-compatible containers.
    """

    def __new__(cls, data: dict[str, Any], plot_type: enum.Enum) -> "AxPlotConfig":
        # Step 1: turn NamedTuples in the payload into dicts.
        plain_data = named_tuple_to_dict(data)
        # Step 2: encode with plotly's JSON encoder, which handles values the
        # stdlib encoder rejects (e.g. numpy arrays).
        serialized = json.dumps(plain_data, cls=utils.PlotlyJSONEncoder)
        # Step 3: decode back into plain Python objects. This round-trip is
        # lossy: anything JSON cannot represent is not restored (e.g. tuples
        # come back as lists).
        json_safe_data = json.loads(serialized)
        # pyre-fixme[7]: Expected `AxPlotConfig` but got `NamedTuple`.
        return super().__new__(cls, json_safe_data, plot_type)
# Structs for plot data
[docs]
class PlotInSampleArm(NamedTuple):
    """Plot record for an in-sample arm, holding observed and predicted data.

    Observed values/errors live in ``y``/``se``. Model predictions live in
    ``y_hat``/``se_hat``. The per-value dicts are presumably keyed by metric
    name (confirm against the code that builds them).
    """

    # Arm name.
    name: str
    # Parameterization of the arm.
    parameters: TParameterization
    # Observed values.
    y: dict[str, float]
    # Predicted values.
    y_hat: dict[str, float]
    # Observed standard errors (presumably; per the `se` name).
    se: dict[str, float]
    # Predicted standard errors (presumably; per the `se_hat` name).
    se_hat: dict[str, float]
    # Optional context-stratum mapping; may be None.
    context_stratum: dict[str, str | float] | None
[docs]
class PlotOutOfSampleArm(NamedTuple):
    """Plot record for an out-of-sample arm, which carries predictions only.

    The value and error dicts are presumably keyed by metric name (confirm
    against the code that builds them).
    """

    # Arm name.
    name: str
    # Parameterization of the arm.
    parameters: TParameterization
    # Predicted values.
    y_hat: dict[str, float]
    # Predicted standard errors (presumably; per the `se_hat` name).
    se_hat: dict[str, float]
    # Optional context-stratum mapping; may be None.
    context_stratum: dict[str, str | float] | None
[docs]
class PlotData(NamedTuple):
    """Plot-ready data bundle covering both in-sample and out-of-sample arms."""

    # Names of the metrics included in this bundle.
    metrics: list[str]
    # In-sample arm records (presumably keyed by arm name; confirm).
    in_sample: dict[str, PlotInSampleArm]
    # Optional two-level mapping to out-of-sample arm records. The meaning
    # of the outer and inner keys is not visible here; confirm with callers.
    out_of_sample: dict[str, dict[str, PlotOutOfSampleArm]] | None
    # Name of the status-quo arm, if there is one.
    status_quo_name: str | None
[docs]
class PlotMetric(NamedTuple):
"""Struct for metric"""
# @TODO T40555279: metric --> metric_name everywhere in plotting
metric: str
pred: bool
rel: bool