Source code for ax.benchmark.benchmark_metric

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Metric classes for Ax benchmarking.

Metrics vary on two dimensions: Whether they are `MapMetric`s or not, and
whether they are available while running or not.

There are four Metric classes:

- `BenchmarkMetric`: A non-Map metric that is not available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
    data is available while running.
- `BenchmarkTimeVaryingMetric`: For when outputs should be `Data` and the metric
  is available while running.
- `BenchmarkMapUnavailableWhileRunningMetric`: For when outputs should be
  `MapData` and the metric is not available while running.

Any of these can be used with or without a simulator. However,
`BenchmarkMetric.fetch_trial_data` cannot take in data with multiple time
steps: since only one value would be used, passing such data is assumed to be
an error. The table below enumerates the use cases; a short construction
sketch follows it.

.. list-table:: Benchmark Metrics Table
   :widths: 5 25 5 5 5 50
   :header-rows: 1

   * -
     - Metric
     - Map
     - Available while running
     - Simulator
     - Reason/use case
   * - 1
     - BenchmarkMetric
     - No
     - No
     - No
     - Vanilla
   * - 2
     - BenchmarkMetric
     - No
     - No
     - Yes
     - Asynchronous, data read only at end
   * - 3
     - BenchmarkTimeVaryingMetric
     - No
     - Yes
     - No
     - Behaves like #1 because it will never be RUNNING
   * - 4
     - BenchmarkTimeVaryingMetric
     - No
     - Yes
     - Yes
     - Scalar data that changes over time
   * - 5
     - BenchmarkMapUnavailableWhileRunningMetric
     - Yes
     - No
     - No
     - MapData that returns immediately; could be used for getting baseline
   * - 6
     - BenchmarkMapUnavailableWhileRunningMetric
     - Yes
     - No
     - Yes
     - Asynchronicity with MapData read only at end
   * - 7
     - BenchmarkMapMetric
     - Yes
     - Yes
     - No
     - Behaves the same as #5
   * - 8
     - BenchmarkMapMetric
     - Yes
     - Yes
     - Yes
     - Early stopping
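
A minimal construction sketch (the metric names here are hypothetical; all four
classes share the constructor defined on `BenchmarkMetricBase`):

.. code-block:: python

    # Scalar metric, fetched only once the trial completes (cases 1-2).
    objective = BenchmarkMetric(name="objective", lower_is_better=True)

    # Same, but leave `sem` unobserved so the model must infer the noise level.
    noisy_objective = BenchmarkMetric(
        name="objective", lower_is_better=True, observe_noise_sd=False
    )

    # MapData metric available while running, e.g. a learning curve used for
    # early stopping with a backend simulator (case 8).
    curve = BenchmarkMapMetric(name="learning_curve", lower_is_better=True)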
"""

from abc import abstractmethod
from typing import Any

from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata

from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data

from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from pandas import DataFrame
from pyre_extensions import none_throws


def _get_no_metadata_msg(trial_index: int) -> str:
    return f"No metadata available for trial {trial_index}."


class BenchmarkMetricBase(Metric):
    def __init__(
        self,
        name: str,
        # Needed to be boolean (not None) for validation of MOO opt configs
        lower_is_better: bool,
        observe_noise_sd: bool = True,
    ) -> None:
        """
        Args:
            name: Name of the metric.
            lower_is_better: If `True`, lower metric values are considered
                better.
            observe_noise_sd: If `True`, the standard deviation of the
                observation noise is included in the `sem` column of the
                returned data. If `False`, `sem` is set to `None` (meaning that
                the model will have to infer the noise level).
        """
        super().__init__(name=name, lower_is_better=lower_is_better)
        # Declare `lower_is_better` as bool (rather than optional as in the base class)
        self.lower_is_better: bool = lower_is_better
        self.observe_noise_sd: bool = observe_noise_sd

    def _class_specific_metdata_validation(
        self, metadata: BenchmarkTrialMetadata | None
    ) -> None:
        return

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """
        Args:
            trial: The trial from which to fetch data.
            kwargs: Unsupported and will raise an exception.

        Returns:
            A MetricFetchResult containing the data for the requested metric.
        """
        class_name = self.__class__.__name__
        if len(kwargs) > 0:
            raise NotImplementedError(
                f"Arguments {set(kwargs)} are not supported in "
                f"{class_name}.fetch_trial_data."
            )
        if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support abandoned arms in "
                "batch trials."
            )
        if len(trial.run_metadata) == 0:
            return Err(
                MetricFetchE(
                    message=_get_no_metadata_msg(trial_index=trial.index),
                    exception=None,
                )
            )

        metadata = trial.run_metadata["benchmark_metadata"]
        self._class_specific_metdata_validation(metadata=metadata)
        backend_simulator = metadata.backend_simulator
        df = metadata.dfs[self.name]

        # Filter down to the data that is observable at the current time
        if backend_simulator is None:
            # If there's no backend simulator then no filtering is needed; the
            # trial will complete immediately, with all data available.
            available_data = df
        else:
            sim_trial = none_throws(
                backend_simulator.get_sim_trial_by_index(trial.index)
            )
            # The BackendSimulator distinguishes between queued and running
            # trials "for testing particular initialization cases", but these
            # are all "running" to the Scheduler.
            start_time = none_throws(sim_trial.sim_start_time)
            if sim_trial.sim_completed_time is None:  # Still running
                max_t = backend_simulator.time - start_time
            elif sim_trial.sim_completed_time > backend_simulator.time:
                raise RuntimeError(
                    "The trial's completion time is in the future! This is "
                    f"unexpected. {sim_trial.sim_completed_time=}, "
                    f"{backend_simulator.time=}"
                )
            else:
                # Completed, may have stopped early -- can't assume all data
                # is available
                completed_time = none_throws(sim_trial.sim_completed_time)
                max_t = completed_time - start_time
            available_data = df[df["virtual runtime"] <= max_t]

        if not self.observe_noise_sd:
            available_data.loc[:, "sem"] = None
        return self._df_to_result(df=available_data.drop(columns=["virtual runtime"]))

    @abstractmethod
    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        """
        Convert a DataFrame of observable data to Data or MapData, as
        appropriate for the class.
        """
        ...


class BenchmarkMetric(BenchmarkMetricBase):
    """
    Non-map Metric for benchmarking that is not available while running.

    It cannot process data with multiple time steps, as it would only return
    one value -- the value it has at completion time -- regardless.
    """

    def _class_specific_metdata_validation(
        self, metadata: BenchmarkTrialMetadata | None
    ) -> None:
        if metadata is not None:
            df = metadata.dfs[self.name]
            if df["step"].nunique() > 1:
                raise ValueError(
                    "Trial has data from multiple time steps. This is"
                    f" not supported by `{self.__class__.__name__}`; use "
                    "`BenchmarkMapMetric`."
                )

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        return Ok(value=Data(df=df.drop(columns=["step"])))


class BenchmarkTimeVaryingMetric(BenchmarkMetricBase):
    """
    Non-Map Metric for benchmarking that is available while running.

    It can produce different values at different times depending on when it is
    called, using the `time` on a `BackendSimulator`.
    """

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        return Ok(
            value=Data(df=df[df["step"] == df["step"].max()].drop(columns=["step"]))
        )


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
    """MapMetric for benchmarking. It is available while running."""

    # pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
    # defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
    # subtype of the overridden attribute `MapKeyInfo[float]`.
    map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        # Just in case the key was renamed by a subclass
        df = df.rename(columns={"step": self.map_key_info.key})
        return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))


class BenchmarkMapUnavailableWhileRunningMetric(MapMetric, BenchmarkMetricBase):
    """MapMetric for benchmarking that is not available while running."""

    # pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
    # defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
    # subtype of the overridden attribute `MapKeyInfo[float]`.
    map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        # Just in case the key was renamed by a subclass
        df = df.rename(columns={"step": self.map_key_info.key})
        return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
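

# --- Illustrative sketch (not part of the module above) -----------------------
# A hypothetical subclass demonstrating the rename hook in `_df_to_result`: if
# a subclass declares a different `map_key_info`, the "step" column of the
# benchmark DataFrames is renamed to that key before being wrapped in MapData.
# The class and key name below are examples only.
class ExampleEpochMapMetric(BenchmarkMapMetric):
    """Hypothetical MapMetric whose progression key is "epoch" rather than "step"."""

    map_key_info: MapKeyInfo[int] = MapKeyInfo(key="epoch", default_value=0)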