Source code for ax.benchmark.benchmark_metric
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

import pandas as pd

from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
class BenchmarkMetric(Metric):
    """A generic metric used for observed values produced by Ax Benchmarks.

    Compatible with results generated by `BenchmarkRunner`.
    """

    def __init__(
        self,
        name: str,
        lower_is_better: bool,  # TODO: Do we need to define this here?
        observe_noise_sd: bool = True,
        outcome_index: int | None = None,
    ) -> None:
        """
        Args:
            name: Name of the metric.
            lower_is_better: If `True`, lower metric values are considered better.
            observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data.
                If `False`, `sem` is set to `None` (meaning that the model will
                have to infer the noise level).
            outcome_index: The index of the output. This is applicable in settings
                where the underlying test problem is evaluated in a vectorized
                fashion across multiple outputs, without providing a name for each
                output. In such cases, `outcome_index` is used in `fetch_trial_data`
                to extract `Ys` and `Ystds`, and `name` is the name of the metric.
        """
        super().__init__(name=name, lower_is_better=lower_is_better)
        # Declare `lower_is_better` as bool (rather than optional as in the base class).
        self.lower_is_better: bool = lower_is_better
        self.observe_noise_sd = observe_noise_sd
        self.outcome_index = outcome_index
    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """
        Args:
            trial: The trial from which to fetch data.
            kwargs: Unsupported and will raise an exception.

        Returns:
            A MetricFetchResult containing the data for the requested metric.
        """
        if len(kwargs) > 0:
            raise NotImplementedError(
                f"Arguments {set(kwargs)} are not supported in "
                f"{self.__class__.__name__}.fetch_trial_data."
            )
        outcome_index = self.outcome_index
        if outcome_index is None:
            # Look up the index based on the outcome name under which we track
            # the data as part of `run_metadata`.
            outcome_names = trial.run_metadata.get("outcome_names")
            if outcome_names is None:
                raise RuntimeError(
                    "Trials' `run_metadata` must contain `outcome_names` if "
                    "no `outcome_index` is provided."
                )
            outcome_index = outcome_names.index(self.name)
        try:
            arm_names = list(trial.arms_by_name.keys())
            all_Ys = trial.run_metadata["Ys"]
            Ys = [all_Ys[arm_name][outcome_index] for arm_name in arm_names]
            if self.observe_noise_sd:
                stdvs = [
                    trial.run_metadata["Ystds"][arm_name][outcome_index]
                    for arm_name in arm_names
                ]
            else:
                stdvs = [float("nan")] * len(Ys)
            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": Ys,
                    "sem": stdvs,
                    "trial_index": trial.index,
                }
            )
            return Ok(value=Data(df=df))
        except Exception as e:
            return Err(
                MetricFetchE(
                    message=f"Failed to obtain data for trial {trial.index}",
                    exception=e,
                )
            )
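
A minimal usage sketch (not part of the module above). It constructs a `BenchmarkMetric` and fetches data from a mocked stand-in for a completed trial, supplying only the attributes that `fetch_trial_data` reads: `index`, `arms_by_name`, and a `run_metadata` dict with "outcome_names", "Ys", and "Ystds" keyed by arm name. The arm name "0_0" and metric name "objective" are arbitrary placeholders; in a real benchmark the trial and its metadata would come from an experiment run with `BenchmarkRunner`, and the `Ok` result is assumed to expose the wrapped `Data` via `unwrap()`.

from unittest import mock

from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.utils.common.result import Ok

metric = BenchmarkMetric(name="objective", lower_is_better=True, observe_noise_sd=True)

# Stand-in trial exposing only what `fetch_trial_data` accesses; the arm objects
# themselves are never inspected, only the arm names (dict keys) are used.
trial = mock.Mock()
trial.index = 0
trial.arms_by_name = {"0_0": mock.Mock()}
trial.run_metadata = {
    "outcome_names": ["objective"],
    "Ys": {"0_0": [0.123]},
    "Ystds": {"0_0": [0.01]},
}

result = metric.fetch_trial_data(trial=trial)
if isinstance(result, Ok):
    # Assumes `Ok.unwrap()` returns the wrapped `Data`; its `.df` holds one row
    # per arm with "mean" and "sem" columns for this metric.
    print(result.unwrap().df)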