Factorial design with empirical Bayes and Thompson Sampling
This tutorial illustrates how to run a factorial experiment. In such an experiment, each parameter (factor) can be assigned one of multiple discrete values (levels). A full-factorial experiment design explores all possible combinations of factors and levels.
For instance, consider a banner with a title and an image. We are considering two different titles and three different images. A full-factorial experiment will compare all 2*3=6 possible combinations of title and image, to see which version of the banner performs the best.
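As a minimal sketch of the banner example, the full set of combinations can be enumerated with a Cartesian product. The factor and level names below are placeholders chosen for illustration, not part of the tutorial's experiment.

```python
from itertools import product

# Hypothetical banner factors; the level names are placeholders.
factors = {
    "title": ["title_a", "title_b"],
    "image": ["image_a", "image_b", "image_c"],
}

# The Cartesian product of all levels is the full-factorial design.
arms = [dict(zip(factors, combo)) for combo in product(*factors.values())]

for arm in arms:
    print(arm)
print(len(arms))  # 2 * 3 = 6 combinations
```

Adding a third factor with four levels would multiply the count again, which is why full-factorial designs grow quickly with the number of factors.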
In this example, we first run an exploratory batch to collect data on all possible combinations. Then we use empirical Bayes to model the data and shrink noisy estimates toward the mean. Next, we use Thompson Sampling to suggest a set of arms (combinations of factors and levels) on which to collect more data. We repeat the process until we have identified the best performing combination(s).
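To make the shrinkage step concrete, here is a rough sketch of a generic normal-normal empirical Bayes estimator. It is an illustration under stated assumptions (a shared prior estimated by the method of moments, and a hypothetical helper named shrink), not a description of how Ax implements its empirical Bayes model.

```python
from statistics import mean, pvariance


def shrink(means, sems):
    """Shrink observed arm means toward the grand mean (illustrative only)."""
    grand_mean = mean(means)
    noise_var = mean([s ** 2 for s in sems])
    # Method-of-moments estimate of the between-arm variance, floored at zero.
    prior_var = max(0.0, pvariance(means) - noise_var)
    shrunk = []
    for y, s in zip(means, sems):
        # Fraction of the observed deviation that is kept:
        # close to 1 for precisely measured arms, close to 0 for noisy ones.
        keep = prior_var / (prior_var + s ** 2) if prior_var > 0 else 0.0
        shrunk.append(grand_mean + keep * (y - grand_mean))
    return shrunk


observed = [0.50, 0.60, 0.70]  # made-up arm means
sems = [0.01, 0.05, 0.10]      # made-up standard errors
shrunk = shrink(observed, sems)
print(shrunk)
```

In this toy input the last arm has the largest standard error, so its estimate is pulled furthest toward the grand mean, while the precisely measured first arm barely moves.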
import sys
in_colab = 'google.colab' in sys.modules
if in_colab:
    %pip install ax-platform
import numpy as np
import pandas as pd
import sklearn as skl
from typing import Dict, Optional, Tuple, Union
from ax import (
    Arm,
    ChoiceParameter,
    Models,
    ParameterType,
    SearchSpace,
    Experiment,
    OptimizationConfig,
    Objective,
)
from ax.plot.scatter import plot_fitted
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.stats.statstools import agresti_coull_sem
import plotly.io as pio
init_notebook_plotting()
if in_colab:
    pio.renderers.default = "colab"
1. Define the search space
First, we define our search space. A factorial search space contains a ChoiceParameter for each factor, where the values of the parameter are its levels.
search_space = SearchSpace(
    parameters=[
        ChoiceParameter(
            name="factor1",
            parameter_type=ParameterType.STRING,
            values=["level11", "level12", "level13"],
        ),
        ChoiceParameter(
            name="factor2",
            parameter_type=ParameterType.STRING,
            values=["level21", "level22"],
        ),
        ChoiceParameter(
            name="factor3",
            parameter_type=ParameterType.STRING,
            values=["level31", "level32", "level33", "level34"],
        ),
    ]
)
/tmp/ipykernel_9125/3645676115.py:3: AxParameterWarning:
is_ordered is not specified for ChoiceParameter "factor1". Defaulting to False since the parameter is a string with more than 2 choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.
/tmp/ipykernel_9125/3645676115.py:3: AxParameterWarning:
sort_values is not specified for ChoiceParameter "factor1". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
/tmp/ipykernel_9125/3645676115.py:8: AxParameterWarning:
is_ordered is not specified for ChoiceParameter "factor2". Defaulting to True since there are exactly two choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.
/tmp/ipykernel_9125/3645676115.py:8: AxParameterWarning:
sort_values is not specified for ChoiceParameter "factor2". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
/tmp/ipykernel_9125/3645676115.py:13: AxParameterWarning:
is_ordered is not specified for ChoiceParameter "factor3". Defaulting to False since the parameter is a string with more than 2 choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.
/tmp/ipykernel_9125/3645676115.py:13: AxParameterWarning:
sort_values is not specified for ChoiceParameter "factor3". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.
2. Define a custom metric
Second, we define a custom metric, which is responsible for computing the mean and standard error of a given arm.
In this example, each possible parameter value is given a coefficient. The higher the level, the higher the coefficient, and the higher the coefficients, the greater the mean.
The standard error of each arm is determined by the weight passed into the evaluation function, which represents the size of the population on which this arm was evaluated. The higher the weight, the greater the sample size, and thus the lower the standard error.
from ax import Data, Metric
from ax.utils.common.result import Ok
from sklearn.preprocessing import OneHotEncoder

# One-hot encode each arm using the levels declared in the search space.
one_hot_encoder = OneHotEncoder(
    categories=[par.values for par in search_space.parameters.values()],
)


class FactorialMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            batch_size = 10000
            noise_level = 0.0
            weight = trial.normalized_arm_weights().get(arm, 1.0)
            # One coefficient per level, ordered as factor1, factor2, factor3.
            coefficients = np.array([0.1, 0.2, 0.3, 0.1, 0.2, 0.1, 0.2, 0.3, 0.4])
            features = np.array(list(params.values())).reshape(1, -1)
            encoded_features = one_hot_encoder.fit_transform(features)
            z = (
                coefficients @ encoded_features.T
                + np.sqrt(noise_level) * np.random.randn()
            )
            # Logistic link; .item() extracts the single value as a Python float,
            # which avoids NumPy's array-to-scalar DeprecationWarning below.
            p = (np.exp(z) / (1 + np.exp(z))).item()
            # The arm's share of the batch determines its sample size.
            plays = np.random.binomial(batch_size, weight)
            successes = np.random.binomial(plays, p)
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    "mean": float(successes) / plays,
                    "sem": agresti_coull_sem(successes, plays),
                }
            )
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))
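The relationship between levels, coefficients, and the mean can be checked by hand. The sketch below copies the coefficients from FactorialMetric and computes the noise-free success rate of an arm, together with a plain binomial standard error. The helper names true_success_rate and binomial_sem are hypothetical, and the plain binomial formula is a simplified stand-in for agresti_coull_sem, used here only to show how the standard error falls as the sample size grows.

```python
import math

# Coefficients copied from FactorialMetric, grouped by factor.
coefficients = {
    "factor1": {"level11": 0.1, "level12": 0.2, "level13": 0.3},
    "factor2": {"level21": 0.1, "level22": 0.2},
    "factor3": {"level31": 0.1, "level32": 0.2, "level33": 0.3, "level34": 0.4},
}


def true_success_rate(params):
    """Logistic transform of the summed coefficients (noise_level = 0)."""
    z = sum(coefficients[factor][level] for factor, level in params.items())
    return 1.0 / (1.0 + math.exp(-z))


def binomial_sem(p, n):
    """Plain binomial standard error; a simplification of agresti_coull_sem."""
    return math.sqrt(p * (1.0 - p) / n)


status_quo = {"factor1": "level11", "factor2": "level21", "factor3": "level31"}
best = {"factor1": "level13", "factor2": "level22", "factor3": "level34"}

print(true_success_rate(status_quo))  # about 0.574
print(true_success_rate(best))        # about 0.711

# Quadrupling the sample size halves the standard error.
print(binomial_sem(0.5, 400), binomial_sem(0.5, 1600))
```

The gap between the best arm and the status quo is roughly 14 percentage points, which is large relative to the standard errors at these sample sizes.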
3. Define the experiment
We now set up our experiment and define the status quo arm, in which each parameter is assigned to the lowest level.
from ax import Runner


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


exp = Experiment(
    name="my_factorial_closed_loop_experiment",
    search_space=search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(metric=FactorialMetric(name="success_metric"), minimize=False)
    ),
    runner=MyRunner(),
)
exp.status_quo = Arm(
    parameters={"factor1": "level11", "factor2": "level21", "factor3": "level31"}
)
4. Run an exploratory batch
We then generate a set of arms that covers the full space of the factorial design, including the status quo. There are three parameters, with three, two, and four values, respectively, so there are 3*2*4=24 possible arms.
factorial = Models.FACTORIAL(search_space=exp.search_space)
# Number of arms to generate is derived from the search space.
factorial_run = factorial.gen(n=-1)
print(len(factorial_run.arms))
24
Now we create a trial including all of these arms, so that we can collect data and evaluate the performance of each.
trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(
    factorial_run, multiplier=1
)
By default, the weight of each arm in factorial_run will be 1. However, to optimize for power on the contrasts of k groups against the status quo, the status quo should be sqrt(k) times larger than any of the treatment groups. Since our 24 arms consist of the status quo plus k = 23 treatment arms, the status quo should be roughly five times larger (sqrt(23) is about 4.8). Ax sets that larger weight automatically if the optimize_for_power kwarg is set to True when the batch trial is created.
trial._status_quo_weight_override
4.795831523312719
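The printed value can be reproduced by hand. The sketch below assumes that the status quo weight is sqrt(k) with k = 23 treatment arms of weight 1, which matches the value above. The share calculation is plain normalization and is not taken from Ax internals.

```python
import math

num_arms = 24                     # all arms in the design, including the status quo
k = num_arms - 1                  # treatment arms contrasted against the status quo
status_quo_weight = math.sqrt(k)  # each treatment arm keeps weight 1
print(status_quo_weight)          # about 4.796

# Normalized shares of the population under these weights.
total = k * 1.0 + status_quo_weight
status_quo_share = status_quo_weight / total
treatment_share = 1.0 / total
print(status_quo_share, treatment_share)
```

Under these assumptions the status quo receives about 17% of the population and each treatment arm about 3.6%.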
5. Iterate using Thompson Sampling
Next, we run multiple trials (iterations of the experiment) to home in on the optimal arm(s).
In each iteration, we first collect data about all arms in that trial by calling trial.run() and trial.mark_completed(). Then we run Thompson Sampling, which assigns a weight to each arm that is proportional to the probability of that arm being the best. Arms whose weight exceeds min_weight are added to the next trial, so that we can gather more data on their performance.
models = []
for i in range(4):
    print(f"Running trial {i+1}...")
    trial.run()
    trial.mark_completed()
    thompson = Models.THOMPSON(experiment=exp, data=trial.fetch_data(), min_weight=0.01)
    models.append(thompson)
    thompson_run = thompson.gen(n=-1)
    trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(thompson_run)
Running trial 1...
Running trial 2...
Running trial 3...
Running trial 4...
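The weighting idea behind Thompson Sampling can be sketched in a few lines of plain Python. This is a generic Monte Carlo illustration that assumes independent normal posteriors for each arm. The function name thompson_weights and the input numbers are hypothetical, and this is not the implementation behind Models.THOMPSON.

```python
import random


def thompson_weights(means, sems, min_weight=0.01, num_draws=10000, seed=0):
    """Estimate P(arm is best) by sampling, then drop low-weight arms."""
    rng = random.Random(seed)
    wins = [0] * len(means)
    for _ in range(num_draws):
        # Draw one plausible value per arm and record which arm comes out on top.
        draws = [rng.gauss(m, s) for m, s in zip(means, sems)]
        wins[draws.index(max(draws))] += 1
    weights = [w / num_draws for w in wins]
    # Keep only arms whose weight exceeds the threshold, then renormalize.
    kept = [w if w > min_weight else 0.0 for w in weights]
    total = sum(kept)
    return [w / total for w in kept]


# Made-up posterior means and standard errors for three arms.
weights = thompson_weights(means=[0.57, 0.70, 0.71], sems=[0.005, 0.01, 0.01])
print(weights)
```

In this toy input the first arm is far behind and is dropped, while the two close contenders split the weight, with the slightly better arm receiving the larger share.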
Plot 1: Predicted outcomes for each arm in initial trial
The plot below shows the mean and standard error for each arm in the first trial. We can see that the standard error for the status quo is the smallest, since this arm was assigned roughly 5x the weight of any other arm.
render(plot_fitted(models[0], metric="success_metric", rel=False))
Plot 2: Predicted outcomes for arms in last trial
The following plot shows the mean and standard error for each arm that made it to the last trial (as well as the status quo, which appears throughout).
render(plot_fitted(models[-1], metric="success_metric", rel=False))
As expected given our evaluation function, arms with higher levels perform better and are given higher weight. Below we see the arms that made it to the final trial.
results = pd.DataFrame(
    [
        {"values": ",".join(arm.parameters.values()), "weight": weight}
        for arm, weight in trial.normalized_arm_weights().items()
    ]
)
print(results)
values weight
0 level13,level22,level34 0.416585
1 level12,level22,level33 0.205725
2 level13,level22,level33 0.011665
3 level11,level21,level31 0.366025
Plot 3: Rollout Process
We can also visualize the progression of the experiment in the following rollout chart. Each bar represents a trial, and the width of each band within a bar is proportional to the weight of the corresponding arm in that trial.
In the first trial, all arms appear with equal weight, except for the status quo. By the last trial, we have narrowed our focus to only four arms, with arm 0_22 (the arm with the highest levels) having the greatest weight.
from ax.plot.bandit_rollout import plot_bandit_rollout
from ax.utils.notebook.plotting import render
render(plot_bandit_rollout(exp))
Plot 4: Marginal Effects
Finally, we can examine which parameter values had the greatest effect on the overall arm value. As we see in the diagram below, arms whose parameters were assigned the lower level values (such as level11, level12, level31 and level32) performed worse than average, whereas arms with higher levels performed better than average.
from ax.plot.marginal_effects import plot_marginal_effects
render(plot_marginal_effects(models[0], "success_metric"))
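For intuition about what a marginal effect measures, the sketch below computes one for every level from the noise-free success rates implied by the metric's coefficients: the average rate over all arms containing that level, minus the average over all 24 arms. This is an illustration built on the known coefficients and is an assumption about the general idea, not a reproduction of what plot_marginal_effects computes from the fitted model.

```python
import math
from itertools import product

# Coefficients copied from FactorialMetric, grouped by factor.
coefficients = {
    "factor1": {"level11": 0.1, "level12": 0.2, "level13": 0.3},
    "factor2": {"level21": 0.1, "level22": 0.2},
    "factor3": {"level31": 0.1, "level32": 0.2, "level33": 0.3, "level34": 0.4},
}


def success_rate(arm):
    """Noise-free success probability of an arm under the logistic model."""
    z = sum(coefficients[factor][level] for factor, level in arm.items())
    return 1.0 / (1.0 + math.exp(-z))


factors = list(coefficients)
arms = [dict(zip(factors, combo)) for combo in product(*(coefficients[f] for f in factors))]
rates = [success_rate(arm) for arm in arms]
overall = sum(rates) / len(rates)

# Marginal effect of a level: mean rate of arms with that level minus the overall mean.
marginal = {}
for factor in factors:
    for level in coefficients[factor]:
        level_rates = [r for arm, r in zip(arms, rates) if arm[factor] == level]
        marginal[level] = sum(level_rates) / len(level_rates) - overall

for level, effect in marginal.items():
    print(f"{level}: {effect:+.4f}")
```

Levels with the smallest coefficients come out negative and those with the largest come out positive, which is the same qualitative pattern the plot is expected to show.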