Source code for ax.benchmark.benchmark_metric

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

import pandas as pd
from ax.core.base_trial import BaseTrial

from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok


class BenchmarkMetric(Metric):
    """A generic metric used for observed values produced by Ax Benchmarks.

    Compatible with results generated by `BenchmarkRunner`.
    """

    def __init__(
        self,
        name: str,
        lower_is_better: bool,  # TODO: Do we need to define this here?
        observe_noise_sd: bool = True,
        outcome_index: int | None = None,
    ) -> None:
        """
        Args:
            name: Name of the metric.
            lower_is_better: If `True`, lower metric values are considered better.
            observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data. If
                `False`, `sem` is set to `NaN` (meaning that the model will have
                to infer the noise level).
            outcome_index: The index of the output. This is applicable in settings
                where the underlying test problem is evaluated in a vectorized
                fashion across multiple outputs, without providing a name for
                each output. In such cases, `outcome_index` is used in
                `fetch_trial_data` to extract `Ys` and `Yvars`, and `name` is
                the name of the metric.
        """
        super().__init__(name=name, lower_is_better=lower_is_better)
        # Declare `lower_is_better` as bool (rather than optional as in the base class).
        self.lower_is_better: bool = lower_is_better
        self.observe_noise_sd = observe_noise_sd
        self.outcome_index = outcome_index
    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """
        Args:
            trial: The trial from which to fetch data.
            kwargs: Unsupported and will raise an exception.

        Returns:
            A MetricFetchResult containing the data for the requested metric.
        """
        if len(kwargs) > 0:
            raise NotImplementedError(
                f"Arguments {set(kwargs)} are not supported in "
                f"{self.__class__.__name__}.fetch_trial_data."
            )
        outcome_index = self.outcome_index
        if outcome_index is None:
            # Look up the index based on the outcome name under which we track
            # the data as part of `run_metadata`.
            outcome_names = trial.run_metadata.get("outcome_names")
            if outcome_names is None:
                raise RuntimeError(
                    "Trials' `run_metadata` must contain `outcome_names` if "
                    "no `outcome_index` is provided."
                )
            outcome_index = outcome_names.index(self.name)

        try:
            arm_names = list(trial.arms_by_name.keys())
            all_Ys = trial.run_metadata["Ys"]
            Ys = [all_Ys[arm_name][outcome_index] for arm_name in arm_names]
            if self.observe_noise_sd:
                stdvs = [
                    trial.run_metadata["Ystds"][arm_name][outcome_index]
                    for arm_name in arm_names
                ]
            else:
                stdvs = [float("nan")] * len(Ys)

            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": Ys,
                    "sem": stdvs,
                    "trial_index": trial.index,
                }
            )
            return Ok(value=Data(df=df))

        except Exception as e:
            return Err(
                MetricFetchE(
                    message=f"Failed to obtain data for trial {trial.index}",
                    exception=e,
                )
            )
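

A minimal usage sketch (not part of the module): it assumes a trial-like stub whose `run_metadata` already carries the per-arm `Ys`, `Ystds`, and `outcome_names` entries that `BenchmarkRunner` would normally populate. The `SimpleNamespace` stand-in and the `unwrap()` accessor are illustrative assumptions, not part of the class above; only the attributes that `fetch_trial_data` actually reads (`index`, `arms_by_name`, `run_metadata`) are provided.

from types import SimpleNamespace

from ax.benchmark.benchmark_metric import BenchmarkMetric

metric = BenchmarkMetric(name="objective", lower_is_better=True)

# Stand-in for an Ax trial: only the attributes read by `fetch_trial_data`
# are supplied, with `run_metadata` shaped as `BenchmarkRunner` would shape it.
trial_stub = SimpleNamespace(
    index=0,
    arms_by_name={"0_0": None},  # only the arm names (keys) are used
    run_metadata={
        "outcome_names": ["objective"],
        "Ys": {"0_0": [1.23]},  # per-arm list of means, one entry per outcome
        "Ystds": {"0_0": [0.05]},  # per-arm observation-noise standard deviations
    },
)

result = metric.fetch_trial_data(trial_stub)
# On success the result wraps a `Data` object; `unwrap()` is assumed here based
# on the Rust-style `Ok`/`Err` result types imported above.
df = result.unwrap().df
print(df[["arm_name", "metric_name", "mean", "sem", "trial_index"]])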