Version: 0.5.0

Factorial design with empirical Bayes and Thompson Sampling

This tutorial illustrates how to run a factorial experiment. In such an experiment, each parameter (factor) can be assigned one of multiple discrete values (levels). A full-factorial experiment design explores all possible combinations of factors and levels.

For instance, consider a banner with a title and an image. We are considering two different titles and three different images. A full-factorial experiment will compare all 2*3=6 possible combinations of title and image, to see which version of the banner performs the best.
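As a minimal sketch of the banner example, the full-factorial design can be enumerated with the standard library. The title and image names below are hypothetical placeholders, not values from the tutorial.

```python
from itertools import product

# Hypothetical levels for the two factors of the banner example.
titles = ["title_a", "title_b"]
images = ["image_1", "image_2", "image_3"]

# A full-factorial design is the Cartesian product of all factor levels.
banners = [{"title": t, "image": i} for t, i in product(titles, images)]

for banner in banners:
    print(banner)
print(len(banners))  # prints 6
```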

In this example, we first run an exploratory batch to collect data on all possible combinations. Then we use empirical Bayes to model the data and shrink noisy estimates toward the mean. Next, we use Thompson Sampling to suggest a set of arms (combinations of factors and levels) on which to collect more data. We repeat the process until we have identified the best performing combination(s).
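To give some intuition for the shrinkage step, here is a simplified method-of-moments sketch of empirical Bayes shrinkage toward the grand mean. It is an illustration only and is not Ax's implementation; the function name `shrink_toward_mean` and the input numbers are assumptions made for this example.

```python
import numpy as np


def shrink_toward_mean(means, sems):
    """Shrink noisy arm means toward their grand mean (illustrative sketch).

    Assumes every standard error in `sems` is strictly positive.
    """
    means = np.asarray(means, dtype=float)
    sems = np.asarray(sems, dtype=float)
    grand_mean = means.mean()
    # Between-arm variance: observed spread minus the average noise variance,
    # floored at zero so the estimate stays valid.
    noise_var = np.mean(sems**2)
    between_var = max(means.var(ddof=1) - noise_var, 0.0)
    # Per-arm factor in [0, 1]; noisier arms are pulled harder toward the mean.
    keep = between_var / (between_var + sems**2)
    return grand_mean + keep * (means - grand_mean)


observed = [0.50, 0.60, 0.70]
standard_errors = [0.05, 0.05, 0.05]
shrunk = shrink_toward_mean(observed, standard_errors)
print(shrunk)
```

In this toy input the extreme arms move a quarter of the way toward the grand mean of 0.6, while the middle arm stays put.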

import sys

in_colab = 'google.colab' in sys.modules
if in_colab:
    %pip install ax-platform
import numpy as np
import pandas as pd
import sklearn as skl
from typing import Dict, Optional, Tuple, Union
from ax import (
    Arm,
    ChoiceParameter,
    Models,
    ParameterType,
    SearchSpace,
    Experiment,
    OptimizationConfig,
    Objective,
)
from ax.plot.scatter import plot_fitted
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.stats.statstools import agresti_coull_sem
import plotly.io as pio

init_notebook_plotting()
if in_colab:
    pio.renderers.default = "colab"
Out:

[INFO 02-03 20:33:30] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.

Out:

[INFO 02-03 20:33:30] ax.utils.notebook.plotting: Please see (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering) if visualizations are not rendering.

1. Define the search space

First, we define our search space. A factorial search space contains a ChoiceParameter for each factor, where the values of the parameter are its levels.

search_space = SearchSpace(
    parameters=[
        ChoiceParameter(
            name="factor1",
            parameter_type=ParameterType.STRING,
            values=["level11", "level12", "level13"],
        ),
        ChoiceParameter(
            name="factor2",
            parameter_type=ParameterType.STRING,
            values=["level21", "level22"],
        ),
        ChoiceParameter(
            name="factor3",
            parameter_type=ParameterType.STRING,
            values=["level31", "level32", "level33", "level34"],
        ),
    ]
)
Out:

/tmp/ipykernel_9125/3645676115.py:3: AxParameterWarning:

is_ordered is not specified for ChoiceParameter "factor1". Defaulting to False since the parameter is a string with more than 2 choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.

/tmp/ipykernel_9125/3645676115.py:3: AxParameterWarning:

sort_values is not specified for ChoiceParameter "factor1". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.

/tmp/ipykernel_9125/3645676115.py:8: AxParameterWarning:

is_ordered is not specified for ChoiceParameter "factor2". Defaulting to True since there are exactly two choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.

/tmp/ipykernel_9125/3645676115.py:8: AxParameterWarning:

sort_values is not specified for ChoiceParameter "factor2". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.

/tmp/ipykernel_9125/3645676115.py:13: AxParameterWarning:

is_ordered is not specified for ChoiceParameter "factor3". Defaulting to False since the parameter is a string with more than 2 choices.. To override this behavior (or avoid this warning), specify is_ordered during ChoiceParameter construction. Note that choice parameters with exactly 2 choices are always considered ordered and that the user-supplied is_ordered has no effect in this particular case.

/tmp/ipykernel_9125/3645676115.py:13: AxParameterWarning:

sort_values is not specified for ChoiceParameter "factor3". Defaulting to False for parameters of ParameterType STRING. To override this behavior (or avoid this warning), specify sort_values during ChoiceParameter construction.

2. Define a custom metric

Second, we define a custom metric, which is responsible for computing the mean and standard error of a given arm.

In this example, each possible parameter value is given a coefficient. The higher the level, the higher the coefficient, and the higher the coefficients, the greater the mean.

The standard error of each arm is determined by the weight passed into the evaluation function, which represents the size of the population on which this arm was evaluated. The higher the weight, the greater the sample size, and thus the lower the standard error.
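The relationship between sample size and standard error can be illustrated with a common form of the Agresti-Coull adjustment, which adds two pseudo-successes and two pseudo-failures before computing the binomial standard error. The helper below is a hypothetical stand-in written for this example; the exact formula inside `ax.utils.stats.statstools.agresti_coull_sem` may differ.

```python
import math


def agresti_coull_sem_sketch(successes, trials, prior_successes=2, prior_failures=2):
    """Binomial standard error after adding pseudo-observations (illustrative)."""
    n_adj = trials + prior_successes + prior_failures
    p_adj = (successes + prior_successes) / n_adj
    return math.sqrt(p_adj * (1 - p_adj) / n_adj)


# Same observed success rate (60%), but 100x more data in the second case.
small = agresti_coull_sem_sketch(60, 100)
large = agresti_coull_sem_sketch(6000, 10000)
print(small, large)
```

With 100 times more data the standard error shrinks by roughly a factor of ten, which is why arms with larger weights get tighter estimates.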

from ax import Data, Metric
from ax.utils.common.result import Ok
import pandas as pd
import sklearn.preprocessing  # ensures skl.preprocessing is loaded


# One-hot encoder whose categories follow the order of the search space parameters.
one_hot_encoder = skl.preprocessing.OneHotEncoder(
    categories=[par.values for par in search_space.parameters.values()],
)


class FactorialMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            batch_size = 10000
            noise_level = 0.0
            weight = trial.normalized_arm_weights().get(arm, 1.0)
            # One coefficient per level: factor1 (3), factor2 (2), factor3 (4).
            coefficients = np.array([0.1, 0.2, 0.3, 0.1, 0.2, 0.1, 0.2, 0.3, 0.4])
            features = np.array(list(params.values())).reshape(1, -1)
            encoded_features = one_hot_encoder.fit_transform(features)
            z = (
                coefficients @ encoded_features.T
                + np.sqrt(noise_level) * np.random.randn()
            )
            # Logistic link: map the linear score to a success probability.
            p = np.exp(z) / (1 + np.exp(z))
            # The arm's share of the batch determines its sample size.
            plays = np.random.binomial(batch_size, weight)
            successes = np.random.binomial(plays, p)
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    "mean": float(successes) / plays,
                    "sem": agresti_coull_sem(successes, plays),
                }
            )
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))
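Because noise_level is 0.0 in the metric above, the true success probability of each arm is just the logistic function of the sum of its level coefficients. The sketch below recomputes those probabilities without Ax; the dictionary layout and the helper name `true_success_probability` are assumptions made for illustration, with the coefficients mapped to levels in the order factor1, factor2, factor3.

```python
import math
from itertools import product

# Coefficients from the metric, keyed by factor and level (layout is illustrative).
coefficients = {
    "factor1": {"level11": 0.1, "level12": 0.2, "level13": 0.3},
    "factor2": {"level21": 0.1, "level22": 0.2},
    "factor3": {"level31": 0.1, "level32": 0.2, "level33": 0.3, "level34": 0.4},
}


def true_success_probability(arm):
    """Logistic function of the summed coefficients of the arm's levels."""
    z = sum(coefficients[factor][level] for factor, level in arm.items())
    return 1.0 / (1.0 + math.exp(-z))


# Enumerate all 24 arms of the full-factorial design.
arms = [
    dict(zip(coefficients, levels))
    for levels in product(*(c.keys() for c in coefficients.values()))
]
best = max(arms, key=true_success_probability)
worst = min(arms, key=true_success_probability)
print(best, round(true_success_probability(best), 4))
print(worst, round(true_success_probability(worst), 4))
```

The best arm combines the highest level of every factor and the worst arm is the status quo, so the optimization below should converge toward level13, level22, level34.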

3. Define the experiment

We now set up our experiment and define the status quo arm, in which each parameter is assigned to the lowest level.

from ax import Runner


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


exp = Experiment(
    name="my_factorial_closed_loop_experiment",
    search_space=search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(
            metric=FactorialMetric(name="success_metric"), minimize=False
        )
    ),
    runner=MyRunner(),
)
exp.status_quo = Arm(
    parameters={"factor1": "level11", "factor2": "level21", "factor3": "level31"}
)

4. Run an exploratory batch

We then generate a set of arms that covers the full space of the factorial design, including the status quo. There are three parameters, with three, two, and four values, respectively, so there are 3*2*4=24 possible arms.

factorial = Models.FACTORIAL(search_space=exp.search_space)
# Number of arms to generate is derived from the search space.
factorial_run = factorial.gen(n=-1)
print(len(factorial_run.arms))
Out:

24

Now we create a trial including all of these arms, so that we can collect data and evaluate the performance of each.

trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(
    factorial_run, multiplier=1
)

By default, the weight of each arm in factorial_run is 1. However, to optimize for power on the contrasts of k groups against the status quo, the status quo should be sqrt(k) times larger than any of the treatment groups. Our search space has 24 arms, one of which is the status quo, so k = 23 and the status quo should be roughly five times larger (sqrt(23) ≈ 4.8). Ax sets that larger weight automatically when the optimize_for_power kwarg is set to True at batch trial creation.

trial._status_quo_weight_override
Out:

4.795831523312719
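The value above can be reproduced by hand. The short sketch below, which does not use Ax, computes sqrt(k) for the 23 non-status-quo arms and the resulting share of the batch that the status quo receives when every other arm has weight 1; the variable names are illustrative.

```python
import math

n_arms = 24        # all arms in the full-factorial design
k = n_arms - 1     # treatment arms compared against the status quo

status_quo_weight = math.sqrt(k)
# Share of the batch allocated to the status quo when every other arm has weight 1.
status_quo_share = status_quo_weight / (k + status_quo_weight)

print(status_quo_weight)
print(round(status_quo_share, 3))
```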

5. Iterate using Thompson Sampling

Next, we run multiple trials (iterations of the experiment) to home in on the optimal arm(s).

In each iteration, we first collect data about all arms in that trial by calling trial.run() and trial.mark_completed(). Then we run Thompson Sampling, which assigns each arm a weight proportional to the probability that the arm is the best. Arms whose weight exceeds min_weight are added to the next trial, so that we can gather more data on their performance.
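The weighting step described above can be sketched with plain NumPy. This is a simplified illustration under the assumption of independent normal posteriors per arm, not Ax's Thompson sampler; the function name `thompson_weights` and the example inputs are hypothetical.

```python
import numpy as np


def thompson_weights(means, sems, min_weight=0.01, n_samples=10_000, seed=0):
    """Estimate P(arm is best) by sampling, then drop low-weight arms (sketch)."""
    rng = np.random.default_rng(seed)
    means = np.asarray(means, dtype=float)
    sems = np.asarray(sems, dtype=float)
    # One posterior draw per arm for each simulated "world".
    draws = rng.normal(loc=means, scale=sems, size=(n_samples, len(means)))
    # Fraction of simulated worlds in which each arm has the highest draw.
    wins = np.bincount(draws.argmax(axis=1), minlength=len(means))
    weights = wins / n_samples
    # Keep only arms whose weight exceeds min_weight, then renormalize.
    # Assumes at least one arm survives the cut.
    weights = np.where(weights > min_weight, weights, 0.0)
    return weights / weights.sum()


weights = thompson_weights(means=[0.55, 0.60, 0.71], sems=[0.01, 0.01, 0.01])
print(weights)
```

Arms that almost never win a simulated draw receive zero weight and drop out of the next trial, which is how the set of arms narrows from one iteration to the next.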

models = []
for i in range(4):
    print(f"Running trial {i+1}...")
    trial.run()
    trial.mark_completed()
    thompson = Models.THOMPSON(experiment=exp, data=trial.fetch_data(), min_weight=0.01)
    models.append(thompson)
    thompson_run = thompson.gen(n=-1)
    trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(thompson_run)
Out:

Running trial 1...

Running trial 2...

Running trial 3...

Running trial 4...

Out:

/tmp/ipykernel_9125/2149564789.py:35: DeprecationWarning:

Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)

Plot 1: Predicted outcomes for each arm in initial trial

The plot below shows the mean and standard error for each arm in the first trial. The standard error for the status quo is the smallest, since this arm was assigned roughly 4.8x the weight of any other arm.

render(plot_fitted(models[0], metric="success_metric", rel=False))

Plot 2: Predicted outcomes for arms in last trial

The following plot shows the mean and standard error for each arm that made it to the last trial (as well as the status quo, which appears throughout).

render(plot_fitted(models[-1], metric="success_metric", rel=False))

As expected given our evaluation function, arms with higher levels perform better and are given higher weight. Below we see the arms that made it to the final trial.

results = pd.DataFrame(
    [
        {"values": ",".join(arm.parameters.values()), "weight": weight}
        for arm, weight in trial.normalized_arm_weights().items()
    ]
)
print(results)
Out:

                    values    weight
0  level13,level22,level34  0.416585
1  level12,level22,level33  0.205725
2  level13,level22,level33  0.011665
3  level11,level21,level31  0.366025

Plot 3: Rollout Process

We can also visualize the progression of the experiment in the following rollout chart. Each bar represents a trial, and the width of each band within a bar is proportional to the weight of the corresponding arm in that trial.

In the first trial, all arms appear with equal weight, except for the status quo. By the last trial, we have narrowed our focus to only four arms, with arm 0_22 (the arm with the highest levels) having the greatest weight.

from ax.plot.bandit_rollout import plot_bandit_rollout
from ax.utils.notebook.plotting import render

render(plot_bandit_rollout(exp))

Plot 4: Marginal Effects

Finally, we can examine which parameter values had the greatest effect on the overall arm value. As we see in the diagram below, arms whose parameters were assigned the lower level values (such as level11, level12, level31 and level32) performed worse than average, whereas arms with higher levels performed better than average.

from ax.plot.marginal_effects import plot_marginal_effects

render(plot_marginal_effects(models[0], "success_metric"))
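One simple way to think about a marginal effect is the average outcome of all arms that share a level, minus the average over all arms. The sketch below applies that definition to the noise-free success probabilities implied by the metric's coefficients. It is an illustration with pandas only; `plot_marginal_effects` works from the fitted model and may define or scale the effect differently.

```python
import math
from itertools import product

import pandas as pd

# Coefficients from the metric, keyed by factor and level (layout is illustrative).
coefficients = {
    "factor1": {"level11": 0.1, "level12": 0.2, "level13": 0.3},
    "factor2": {"level21": 0.1, "level22": 0.2},
    "factor3": {"level31": 0.1, "level32": 0.2, "level33": 0.3, "level34": 0.4},
}

# Build one row per arm with its noise-free success probability.
rows = []
for levels in product(*(c.keys() for c in coefficients.values())):
    z = sum(coefficients[f][lvl] for f, lvl in zip(coefficients, levels))
    rows.append({**dict(zip(coefficients, levels)), "mean": 1 / (1 + math.exp(-z))})
arm_means = pd.DataFrame(rows)

# Marginal effect of a level: mean over arms with that level minus the overall mean.
overall = arm_means["mean"].mean()
effects = {
    factor: arm_means.groupby(factor)["mean"].mean() - overall
    for factor in coefficients
}
for factor, effect in effects.items():
    print(effect.round(4))
```

Lower levels come out with negative effects and higher levels with positive ones, matching the pattern described for the marginal effects plot.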