Source code for ax.metrics.noisy_function
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Any, List, Optional
import numpy as np
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric
class NoisyFunctionMetric(Metric):
    """A metric defined by a generic deterministic function, with normal noise
    with mean 0 and ``noise_sd`` scale added to the result.

    Subclasses supply the deterministic function by overriding ``f``.
    """

    def __init__(
        self,
        name: str,
        param_names: List[str],
        noise_sd: float = 0.0,
        lower_is_better: Optional[bool] = None,
    ) -> None:
        """
        Metric is computed by evaluating a deterministic function, implemented
        in f.

        f will expect an array x, which is constructed from the arm
        parameters by extracting the values of the parameter names given in
        param_names, in that order.

        Args:
            name: Name of the metric
            param_names: An ordered list of names of parameters to be passed
                to the deterministic function.
            noise_sd: Scale of normal noise added to the function result.
            lower_is_better: Flag for metrics which should be minimized.
        """
        self.param_names = param_names
        self.noise_sd = noise_sd
        super().__init__(name=name, lower_is_better=lower_is_better)

    def clone(self) -> "NoisyFunctionMetric":
        """Return a new metric of the same class with the same settings.

        ``param_names`` is copied so that mutating the clone's list does not
        affect this metric, and vice versa.
        """
        return self.__class__(
            name=self._name,
            param_names=list(self.param_names),
            noise_sd=self.noise_sd,
            lower_is_better=self.lower_is_better,
        )

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> Data:
        """Evaluate ``f`` on every arm of ``trial`` and package the results.

        Args:
            trial: Trial whose arms are evaluated.
            noisy: If True, Gaussian noise with standard deviation
                ``self.noise_sd`` is added to each value and ``noise_sd`` is
                reported in the "sem" column. If False, no noise is added and
                "sem" is 0.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Data wrapping a frame with one row per arm of the trial. A trial
            with no arms produces a frame with no rows instead of raising
            ZeroDivisionError.
        """
        noise_sd = self.noise_sd if noisy else 0.0
        arm_names = []
        mean = []
        for name, arm in trial.arms_by_name.items():
            arm_names.append(name)
            x = np.array([arm.parameters[p] for p in self.param_names])
            # One draw from NumPy's global RNG per arm, even when noise_sd is
            # 0, so the RNG stream advances identically regardless of `noisy`.
            mean.append(self.f(x) + np.random.randn() * noise_sd)
        # "n" splits a fixed total of 10000 evenly across the arms; guard the
        # division so an arm-less trial does not raise ZeroDivisionError.
        n = 10000 / len(arm_names) if arm_names else 0.0
        df = pd.DataFrame(
            {
                "arm_name": arm_names,
                "metric_name": self.name,
                "mean": mean,
                "sem": noise_sd,
                "trial_index": trial.index,
                "n": n,
                # NOTE(review): frac_nonnull is populated with the metric
                # means (unchanged prior behavior); confirm this is intended.
                "frac_nonnull": mean,
            }
        )
        return Data(df=df)

    def f(self, x: np.ndarray) -> float:
        """The deterministic function that produces the metric outcomes.

        Args:
            x: Array of the arm's parameter values, ordered as in
                ``self.param_names``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError