# Source code for ax.plot.bandit_rollout

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import plotly.graph_objs as go
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.color import MIXED_SCALE, rgba


def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
    """Plot the bandit rollout of an experiment as a stacked bar chart.

    Trials are visited in order of ``trial.index``; each one becomes an x-axis
    category labelled ``"Round <index>"``. Every arm contributes one bar trace
    whose segments are its weights, as returned by
    ``trial.normalized_arm_weights(total=100)``, in each round it appears in.

    Args:
        experiment: The experiment to plot. Every trial must be a ``BatchTrial``.

    Returns:
        An ``AxPlotConfig`` of type ``AxPlotTypes.GENERIC`` wrapping the figure.

    Raises:
        ValueError: If any trial of ``experiment`` is not a ``BatchTrial``.
    """
    categories: List[str] = []
    # Per-arm trace data, keyed by arm name. Insertion order is the order in
    # which arms are first seen, which also determines each arm's color slot.
    arms: Dict[str, Dict[str, Any]] = {}

    for trial in sorted(experiment.trials.values(), key=lambda trial: trial.index):
        if not isinstance(trial, BatchTrial):
            raise ValueError("Bandit rollout graph is not supported for BaseTrial.")
        category = f"Round {trial.index}"
        categories.append(category)
        for arm, weight in trial.normalized_arm_weights(total=100).items():
            if arm.name not in arms:
                arms[arm.name] = {"name": arm.name, "x": [], "y": [], "text": []}
            trace = arms[arm.name]
            trace["x"].append(category)
            trace["y"].append(weight)
            trace["text"].append("{:.2f}%".format(weight))

    colors = [rgba(c) for c in MIXED_SCALE]
    layout = go.Layout(
        title="Rollout Process<br>Bandit Weight Graph",
        xaxis={
            "title": "Rounds",
            "zeroline": False,
            # Keep rounds in trial-index order rather than plotly's default.
            "categoryorder": "array",
            "categoryarray": categories,
        },
        yaxis={"title": "Percent", "showline": False},
        barmode="stack",
        showlegend=False,
        margin={"r": 40},
    )

    bandit_config = {"type": "bar", "hoverinfo": "name+text", "width": 0.5}
    # The enumeration position equals the arm's first-seen order, so colors
    # cycle through the palette without storing a bookkeeping key in the trace
    # (plotly rejects unknown keys such as "index").
    bandits = [
        dict(bandit_config, marker={"color": colors[i % len(colors)]}, **trace)
        for i, trace in enumerate(arms.values())
    ]

    fig = go.Figure(data=bandits, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)