#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Any
import numpy as np
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.types import TParameterization, TParamValue
from ax.utils.common.result import Err, Ok
from ax.utils.stats.statstools import agresti_coull_sem
class FactorialMetric(Metric):
"""Metric for testing factorial designs assuming a main effects only
logit model.
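
    A hypothetical example (factor and level names are illustrative):

        FactorialMetric(
            name="success_rate",
            coefficients={
                "x1": {"a": 0.1, "b": -0.2},
                "x2": {"low": 0.0, "high": 0.3},
            },
        )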
"""
def __init__(
self,
name: str,
coefficients: dict[str, dict[TParamValue, float]],
batch_size: int = 10000,
noise_var: float = 0.0,
) -> None:
"""
        Args:
            name: Name of the metric.
            coefficients: A dictionary mapping each factor to a dictionary
                mapping that factor's levels to main-effect coefficients.
            batch_size: The sample size for one batch, distributed among
                arms in proportion to their weights in the design.
            noise_var: Variance of the Gaussian noise added to the linear
                predictor when computing each arm's success probability.
        """
super().__init__(name)
self.coefficients = coefficients
self.batch_size = batch_size
self.noise_var = noise_var
@classmethod
def is_available_while_running(cls) -> bool:
        # This metric does not require a trial to complete before its data
        # can be fetched, since there is no actual "data" to fetch; it is
        # fabricated from parameterizations.
return True
def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
try:
if not isinstance(trial, BatchTrial):
raise ValueError(
"Factorial metric can only fetch data for batch trials."
)
if not trial.status.expecting_data:
raise ValueError("Can only fetch data if trial is expecting data.")
data = []
normalized_arm_weights = trial.normalized_arm_weights()
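            # Arm weights are normalized so they sum to 1 across the batch.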
for name, arm in trial.arms_by_name.items():
weight = normalized_arm_weights[arm]
mean, sem = evaluation_function(
parameterization=arm.parameters,
weight=weight,
coefficients=self.coefficients,
batch_size=self.batch_size,
noise_var=self.noise_var,
)
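                # Simulated sample size for this arm; this binomial draw is
                # independent of the one made inside evaluation_function.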
n = np.random.binomial(self.batch_size, weight)
data.append(
{
"arm_name": name,
"metric_name": self.name,
"mean": mean,
"sem": sem,
"trial_index": trial.index,
"n": n,
"frac_nonnull": mean,
}
)
return Ok(value=Data(df=pd.DataFrame(data)))
except Exception as e:
return Err(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)
def evaluation_function(
parameterization: TParameterization,
coefficients: dict[str, dict[TParamValue, float]],
weight: float = 1.0,
batch_size: int = 10000,
noise_var: float = 0.0,
) -> tuple[float, float]:
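    """Simulate one arm's outcomes for a single batch.

    Draws plays ~ Binomial(batch_size, weight), then successes ~
    Binomial(plays, probability) with the probability given by the
    main-effects logit model, and returns the observed mean together
    with its Agresti-Coull standard error.
    """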
probability = _parameterization_probability(
parameterization=parameterization,
coefficients=coefficients,
noise_var=noise_var,
)
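    # Two-stage simulation: the arm receives its share of the batch, and
    # each play independently succeeds with the model's probability.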
plays = np.random.binomial(batch_size, weight)
successes = np.random.binomial(plays, probability)
mean = float(successes) / plays
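    # Standard error of the mean from the Agresti-Coull adjusted proportion,
    # which remains nonzero even when all plays succeed or all fail.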
sem = agresti_coull_sem(successes, plays)
assert isinstance(sem, float)
return mean, sem
def _parameterization_probability(
parameterization: TParameterization,
coefficients: dict[str, dict[TParamValue, float]],
noise_var: float = 0.0,
) -> float:
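    """Compute the success probability of a parameterization as the
    sigmoid of the summed main-effect coefficients, with optional
    Gaussian noise (variance noise_var) added to the linear predictor.
    """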
z = 0.0
for factor, level in parameterization.items():
        if factor not in coefficients:
            raise ValueError(f"{factor} not in supplied coefficients")
        if level not in coefficients[factor]:
            raise ValueError(f"{level} not a valid level of {factor}")
z += coefficients[factor][level]
z += np.sqrt(noise_var) * np.random.randn()
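    # Logistic (sigmoid) transform maps the linear predictor onto (0, 1).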
return np.exp(z) / (1 + np.exp(z))