⚠ INFO ⚠
This document discusses non-API components of Ax, which may be subject to backwards compatibility breaking changes between major library versions.
Utilizing and Creating Ax Analyses
Ax’s Analysis module provides a framework for producing plots, tables, messages, and more to help users understand their experiments. This is facilitated via the `Analysis` protocol and its various subclasses. `Analysis` classes implement a `compute` method which consumes an `Experiment`, `GenerationStrategy`, and/or `Adapter` and outputs a collection of `AnalysisCard`s.
These cards contain a dataframe with relevant data, a “blob” which contains data to be rendered (e.g. a plot), and miscellaneous metadata like a title, subtitle, and priority level used for sorting. `compute` returns a collection of cards so that analyses can be composed together. For example, the `TopSurfacesPlot` computes a `SensitivityAnalysisPlot` to understand which parameters in the search space are most relevant, then produces `SlicePlot`s and `ContourPlot`s for the most important surfaces.
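To make the composition idea concrete, here is a minimal, self-contained sketch of the pattern in plain Python. These are stand-in classes, not the actual Ax implementations (the real `TopSurfacesPlot` ranks parameters via a sensitivity analysis; here the "top" parameters are hard-coded for illustration) — the point is only that a composite analysis's `compute` can delegate to other analyses and concatenate the cards they return.

```python
from dataclasses import dataclass


# Stand-in for AnalysisCard: just a title and a priority used for sorting.
@dataclass
class Card:
    title: str
    priority: int = 0


class SensitivitySketch:
    """Stand-in for a sensitivity analysis: returns a single card."""

    def compute(self) -> list[Card]:
        return [Card(title="Sensitivity Analysis", priority=2)]


class SliceSketch:
    """Stand-in for a per-parameter slice plot."""

    def __init__(self, parameter: str) -> None:
        self.parameter = parameter

    def compute(self) -> list[Card]:
        return [Card(title=f"Slice: {self.parameter}", priority=1)]


class TopSurfacesSketch:
    """Composite analysis: run the sensitivity sketch, then produce a slice
    card for each 'important' parameter, and return all cards together."""

    def compute(self) -> list[Card]:
        cards = SensitivitySketch().compute()
        for parameter in ("x1", "x2"):  # pretend these ranked highest
            cards.extend(SliceSketch(parameter).compute())
        return cards


cards = TopSurfacesSketch().compute()
print([card.title for card in cards])
```

Because every `compute` returns a list of cards, composing analyses is just list concatenation; downstream consumers never need to know whether a card came from a leaf analysis or a composite one.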
Ax currently provides implementations for three base classes: (1) `Analysis` -- for creating tables, (2) `PlotlyAnalysis` -- for producing plots using the Plotly library, and (3) `MarkdownAnalysis` -- for producing messages. Importantly, Ax can save these cards to the database using `save_analysis_cards`, allowing analyses to be pre-computed and displayed at a later time. This is done automatically when `Client.compute_analyses` is called.
Using Analyses
The simplest way to use an `Analysis` is to call `Client.compute_analyses`. This will heuristically select the most relevant analyses to compute, save the cards to the database, return them, and display them in your IPython environment if possible. Users can also specify which analyses to compute and pass them in manually, for example: `client.compute_analyses(analyses=[TopSurfacesPlot(), Summary(), ...])`.
When developing a new `Analysis` it can be useful to compute an analysis à la carte. To do this, manually instantiate the `Analysis` and call its `compute` method. This will return a collection of `AnalysisCard`s which can be displayed.
```python
from ax import Client, RangeParameterConfig

# Create a Client and populate it with some data
client = Client()
client.configure_experiment(
    name="booth_function",
    parameters=[
        RangeParameterConfig(
            name="x1",
            bounds=(-10.0, 10.0),
            parameter_type="float",
        ),
        RangeParameterConfig(
            name="x2",
            bounds=(-10.0, 10.0),
            parameter_type="float",
        ),
    ],
)
client.configure_optimization(objective="-1 * booth")

for _ in range(10):
    for trial_index, parameters in client.get_next_trials(max_trials=1).items():
        client.complete_trial(
            trial_index=trial_index,
            raw_data={
                "booth": (parameters["x1"] + 2 * parameters["x2"] - 7) ** 2
                + (2 * parameters["x1"] + parameters["x2"] - 5) ** 2
            },
        )
```

```python
from ax.analysis.analysis import display_cards
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot

analysis = ParallelCoordinatesPlot()
cards = analysis.compute(
    experiment=client._experiment,
    generation_strategy=client._generation_strategy,
    # compute can optionally take an Adapter directly instead of a GenerationStrategy
    adapter=None,
)

# display_cards can be useful to group and sort AnalysisCards by type and level respectively
display_cards(cards=cards)
```
Parallel Coordinates for booth
The parallel coordinates plot displays multi-dimensional data by representing each parameter as a parallel axis. This plot helps in assessing how thoroughly the search space has been explored and in identifying patterns or clusterings associated with high-performing (good) or low-performing (bad) arms. By tracing lines across the axes, one can observe correlations and interactions between parameters, gaining insights into the relationships that contribute to the success or failure of different configurations within the experiment.
Creating a new Analysis
Let's implement a simple `Analysis` that returns a table counting the number of trials in each `TrialStatus`. We'll make a new class that implements the `Analysis` protocol (i.e. it defines a `compute` method).
```python
import random
from typing import Sequence

import pandas as pd

from ax.analysis.analysis import (
    Analysis,
    AnalysisCard,
    AnalysisCardCategory,
    AnalysisCardLevel,
)
from ax.core.experiment import Experiment
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter


class TrialStatusTable(Analysis):
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> Sequence[AnalysisCard]:
        trials_by_status = experiment.trials_by_status

        records = [
            {"status": status.name, "count": len(trials)}
            for status, trials in trials_by_status.items()
            if len(trials) > 0
        ]

        return [
            self._create_analysis_card(
                title="Trials by Status",
                subtitle="How many trials are in each status?",
                level=AnalysisCardLevel.LOW,
                category=AnalysisCardCategory.INSIGHT,
                df=pd.DataFrame.from_records(records),
            )
        ]
```
```python
# Let's add some more trials of miscellaneous statuses before computing the new Analysis
for _ in range(10):
    for trial_index, parameters in client.get_next_trials(max_trials=1).items():
        roll = random.random()
        if roll < 0.2:
            client.mark_trial_failed(trial_index=trial_index)
        elif roll < 0.5:
            client.mark_trial_abandoned(trial_index=trial_index)
        else:
            client.complete_trial(
                trial_index=trial_index,
                raw_data={
                    "booth": (parameters["x1"] + 2 * parameters["x2"] - 7) ** 2
                    + (2 * parameters["x1"] + parameters["x2"] - 5) ** 2
                },
            )

# Client.compute_analyses will call display_cards internally if display=True
cards = client.compute_analyses(analyses=[TrialStatusTable()], display=True)
```
Trials by Status
How many trials are in each status?
| | status | count |
|---|---|---|
| 0 | FAILED | 1 |
| 1 | COMPLETED | 12 |
| 2 | ABANDONED | 7 |
Adding options to an Analysis
Imagine we want to add an option to change how this analysis is computed, say to toggle whether the analysis reports the number of trials in a given state or the fraction of trials in a given state. We cannot change the input arguments to `compute`, so this option must be added elsewhere.

The analysis' initializer is a natural place to put additional settings. We'll create a `TrialStatusTable.__init__` method which takes in the option as a boolean, then modify `compute` to consume this option as well. Following this pattern allows users to specify all relevant settings before calling `Client.compute_analyses` while still allowing the underlying `compute` call to remain unchanged. Standardization of the `compute` call simplifies logic elsewhere in the stack.
```python
class TrialStatusTable(Analysis):
    def __init__(self, as_fraction: bool) -> None:
        super().__init__()
        self.as_fraction = as_fraction

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> Sequence[AnalysisCard]:
        trials_by_status = experiment.trials_by_status
        denominator = len(experiment.trials) if self.as_fraction else 1

        records = [
            {"status": status.name, "count": len(trials) / denominator}
            for status, trials in trials_by_status.items()
            if len(trials) > 0
        ]

        return [
            # Use _create_analysis_card rather than AnalysisCard to automatically
            # populate relevant metadata
            self._create_analysis_card(
                title="Trials by Status",
                subtitle="How many trials are in each status?",
                level=AnalysisCardLevel.LOW,
                category=AnalysisCardCategory.INSIGHT,
                df=pd.DataFrame.from_records(records),
            )
        ]


cards = client.compute_analyses(
    analyses=[TrialStatusTable(as_fraction=True)], display=True
)
```
Trials by Status
How many trials are in each status?
| | status | count |
|---|---|---|
| 0 | FAILED | 0.05 |
| 1 | COMPLETED | 0.6 |
| 2 | ABANDONED | 0.35 |
Plotly Analyses
Analyses do not just have to produce pandas dataframes. Ax also defines a `PlotlyAnalysis` class, whose `compute` method returns `PlotlyAnalysisCard`s containing both a dataframe and a Plotly `Figure`.

Implementing a `PlotlyAnalysis` is not significantly different from creating a base `Analysis`. Let's create a bar chart based on `TrialStatusTable`.
```python
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from plotly import express as px


class TrialStatusTable(PlotlyAnalysis):
    def __init__(self, as_fraction: bool) -> None:
        super().__init__()
        self.as_fraction = as_fraction

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> Sequence[PlotlyAnalysisCard]:
        trials_by_status = experiment.trials_by_status
        denominator = len(experiment.trials) if self.as_fraction else 1

        records = [
            {"status": status.name, "count": len(trials) / denominator}
            for status, trials in trials_by_status.items()
            if len(trials) > 0
        ]
        df = pd.DataFrame.from_records(records)

        # Create a Plotly figure using the df we generated before
        fig = px.bar(df, x="status", y="count")

        return [
            # Use _create_plotly_analysis_card rather than AnalysisCard to
            # automatically populate relevant metadata
            self._create_plotly_analysis_card(
                title="Trials by Status",
                subtitle="How many trials are in each status?",
                level=AnalysisCardLevel.LOW,
                category=AnalysisCardCategory.INSIGHT,
                df=df,
                fig=fig,
            )
        ]


cards = client.compute_analyses(
    analyses=[TrialStatusTable(as_fraction=True)], display=True
)
```
Trials by Status
How many trials are in each status?
Miscellaneous tips
- Many analyses rely on the same infrastructure and utility functions -- check to see if what you need has already been implemented somewhere.
- Many analyses require an `Adapter` but can use either the `Adapter` provided directly or the current `Adapter` on the `GenerationStrategy` -- `extract_relevant_adapter` handles this in a consistent way.
- Analyses which use an `Arm` as the fundamental unit of analysis will find the `prepare_arm_data` utility useful; using it will also lend the `Analysis` useful features like relativization for free.
- When writing a new `PlotlyAnalysis`, check out `ax.analysis.plotly.utils` for guidance on using color schemes and unified tooltips.
- Try to follow consistent design patterns; many analyses take an optional list of `metric_names` on initialization, and interpret `None` to mean the user wants to compute a card for each metric present. Following these conventions makes things easier for downstream consumers.