Source code for ax.models.discrete.eb_thompson
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import logging
import numpy as np
from ax.models.discrete.thompson import ThompsonSampler
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import positive_part_james_stein
logger: logging.Logger = get_logger(__name__)
[docs]
class EmpiricalBayesThompsonSampler(ThompsonSampler):
"""Generator for Thompson sampling using Empirical Bayes estimates.
The generator applies positive-part James-Stein Estimator to the data
passed in via `fit` and then performs Thompson Sampling.
"""
def _fit_Ys_and_Yvars(
self, Ys: list[list[float]], Yvars: list[list[float]], outcome_names: list[str]
) -> tuple[list[list[float]], list[list[float]]]:
newYs = []
newYvars = []
for i, (Y, Yvar) in enumerate(zip(Ys, Yvars)):
newY, newYvar = self._apply_shrinkage(Y, Yvar, i)
newYs.append(newY)
newYvars.append(newYvar)
return newYs, newYvars
def _apply_shrinkage(
self, Y: list[float], Yvar: list[float], outcome: int
) -> tuple[list[float], list[float]]:
npY = np.array(Y)
npYvar = np.array(Yvar)
npYsem = np.sqrt(Yvar)
try:
npY, npYsem = positive_part_james_stein(means=npY, sems=npYsem)
except ValueError as e:
logger.warning(
str(e) + f" Raw (unshrunk) estimates used for outcome: {outcome}"
)
Y = npY.tolist()
npYvar = npYsem**2
Yvar = npYvar.tolist()
return Y, Yvar