Source code for ax.plot.bandit_rollout
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Any
import plotly.graph_objs as go
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.color import MIXED_SCALE, rgba
[docs]
def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
    """Plot the bandit rollout of an experiment as a stacked bar chart.

    Each trial becomes one x-axis category labelled ``Round <trial index>``,
    with trials visited in ascending index order. Within a round, every arm
    contributes one bar segment whose height is the arm's weight normalized
    so that the weights of the trial total 100.

    Args:
        experiment: Experiment whose trials are plotted.

    Returns:
        An ``AxPlotConfig`` of type ``AxPlotTypes.GENERIC`` wrapping the
        plotly figure.

    Raises:
        ValueError: If any trial of the experiment is not a ``BatchTrial``.
    """
    round_labels: list[str] = []
    # One partial bar trace per arm name, kept in first-seen order. That
    # insertion order is what later determines each arm's color.
    traces_by_arm: dict[str, dict[str, Any]] = {}

    ordered_trials = sorted(experiment.trials.values(), key=lambda t: t.index)
    for trial in ordered_trials:
        if not isinstance(trial, BatchTrial):
            raise ValueError("Bandit rollout graph is not supported for BaseTrial.")

        label = f"Round {trial.index}"
        round_labels.append(label)

        weights = trial.normalized_arm_weights(total=100)
        for arm, weight in weights.items():
            arm_name = arm.name
            trace = traces_by_arm.get(arm_name)
            if trace is None:
                trace = {"name": arm_name, "x": [], "y": [], "text": []}
                traces_by_arm[arm_name] = trace
            trace["x"].append(label)
            trace["y"].append(weight)
            trace["text"].append(f"{weight:.2f}%")

    palette = [rgba(c) for c in MIXED_SCALE]
    shared_bar_options = {"type": "bar", "hoverinfo": "name+text", "width": 0.5}
    # The enumeration position equals the order in which the arm was first
    # seen, so colors cycle through the palette in that order.
    bar_traces = [
        {
            **shared_bar_options,
            "marker": {"color": palette[position % len(palette)]},
            **trace,
        }
        for position, trace in enumerate(traces_by_arm.values())
    ]

    x_axis = {
        "title": "Rounds",
        "zeroline": False,
        # Pin the category order so rounds appear in trial-index order.
        "categoryorder": "array",
        "categoryarray": round_labels,
    }
    y_axis = {"title": "Percent", "showline": False}
    layout = go.Layout(
        title="Rollout Process<br>Bandit Weight Graph",
        xaxis=x_axis,
        yaxis=y_axis,
        barmode="stack",
        showlegend=False,
        margin={"r": 40},
    )

    fig = go.Figure(data=bar_traces, layout=layout)
    # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)