Source code for ax.metrics.factorial

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.types import TParameterization, TParamValue
from ax.utils.common.result import Err, Ok
from ax.utils.stats.statstools import agresti_coull_sem


[docs]class FactorialMetric(Metric): """Metric for testing factorial designs assuming a main effects only logit model. """ def __init__( self, name: str, coefficients: Dict[str, Dict[TParamValue, float]], batch_size: int = 10000, noise_var: float = 0.0, ) -> None: """ Args: name: name of the metric. coefficients: a dictionary mapping factors to levels to main effects. batch_size: the sample size for one batch, distributed between arms proportionally to the design. noise_var: used in calculating the probability of each arm. """ super(FactorialMetric, self).__init__(name) self.coefficients = coefficients self.batch_size = batch_size self.noise_var = noise_var
[docs] @classmethod def is_available_while_running(cls) -> bool: # This metric does not require a trial to complete to fetch its # data, since there is no actual "data" to be fetched –– its # fabricated from parameterizations. return True
[docs] def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult: try: if not isinstance(trial, BatchTrial): raise ValueError( "Factorial metric can only fetch data for batch trials." ) if not trial.status.expecting_data: raise ValueError("Can only fetch data if trial is expecting data.") data = [] normalized_arm_weights = trial.normalized_arm_weights() for name, arm in trial.arms_by_name.items(): weight = normalized_arm_weights[arm] mean, sem = evaluation_function( parameterization=arm.parameters, weight=weight, coefficients=self.coefficients, batch_size=self.batch_size, noise_var=self.noise_var, ) n = np.random.binomial(self.batch_size, weight) data.append( { "arm_name": name, "metric_name": self.name, "mean": mean, "sem": sem, "trial_index": trial.index, "n": n, "frac_nonnull": mean, } ) return Ok(value=Data(df=pd.DataFrame(data))) except Exception as e: return Err( MetricFetchE(message=f"Failed to fetch {self.name}", exception=e) )
[docs]def evaluation_function( parameterization: TParameterization, coefficients: Dict[str, Dict[TParamValue, float]], weight: float = 1.0, batch_size: int = 10000, noise_var: float = 0.0, ) -> Tuple[float, float]: probability = _parameterization_probability( parameterization=parameterization, coefficients=coefficients, noise_var=noise_var, ) plays = np.random.binomial(batch_size, weight) successes = np.random.binomial(plays, probability) mean = float(successes) / plays sem = agresti_coull_sem(successes, plays) assert isinstance(sem, float) return mean, sem
def _parameterization_probability( parameterization: TParameterization, coefficients: Dict[str, Dict[TParamValue, float]], noise_var: float = 0.0, ) -> float: z = 0.0 for factor, level in parameterization.items(): if factor not in coefficients.keys(): raise ValueError("{} not in supplied coefficients".format(factor)) if level not in coefficients[factor].keys(): raise ValueError("{} not a valid level of {}".format(level, factor)) z += coefficients[factor][level] z += np.sqrt(noise_var) * np.random.randn() return np.exp(z) / (1 + np.exp(z))