Source code for ax.benchmark.metrics.benchmark

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from typing import Any, Optional

from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
from ax.benchmark.metrics.utils import _fetch_trial_data
from ax.core.base_trial import BaseTrial
from ax.core.metric import MetricFetchResult


class BenchmarkMetric(BenchmarkMetricBase):
    """A generic metric used for observed values produced by Ax Benchmarks.

    Compatible e.g. with results generated by `BotorchTestProblemRunner` and
    `SurrogateRunner`.

    Attributes:
        has_ground_truth: Whether or not there exists a ground truth for this
            metric, i.e. whether each observation has an associated ground truth
            value. This is trivially true for deterministic metrics, and is also
            true for metrics where synthetic observation noise is added to its
            (deterministic) values. This is not true for metrics that are
            inherently noisy.
    """

    has_ground_truth: bool = True

    def __init__(
        self,
        name: str,
        lower_is_better: bool,  # TODO: Do we need to define this here?
        observe_noise_sd: bool = True,
        outcome_index: Optional[int] = None,
    ) -> None:
        """
        Args:
            name: Name of the metric.
            lower_is_better: If `True`, lower metric values are considered better.
            observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data. If
                `False`, `sem` is set to `None` (meaning that the model will have
                to infer the noise level).
            outcome_index: The index of the output. This is applicable in settings
                where the underlying test problem is evaluated in a vectorized
                fashion across multiple outputs, without providing a name for each
                output. In such cases, `outcome_index` is used in `fetch_trial_data`
                to extract `Ys` and `Yvars`, and `name` is the name of the metric.
        """
        super().__init__(name=name, lower_is_better=lower_is_better)
        # Declare `lower_is_better` as bool (rather than optional as in the base class)
        self.lower_is_better: bool = lower_is_better
        self.observe_noise_sd = observe_noise_sd
        self.outcome_index = outcome_index

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        if len(kwargs) > 0:
            raise NotImplementedError(
                f"Arguments {set(kwargs)} are not supported in "
                f"{self.__class__.__name__}.fetch_trial_data."
            )
        return _fetch_trial_data(
            trial=trial,
            metric_name=self.name,
            outcome_index=self.outcome_index,
            include_noise_sd=self.observe_noise_sd,
            ground_truth=False,
        )

    def make_ground_truth_metric(self) -> BenchmarkMetricBase:
        """Create a ground truth version of this metric."""
        return GroundTruthBenchmarkMetric(original_metric=self)
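

# Illustrative sketch (not part of the original module): how a `BenchmarkMetric`
# might be constructed and wired into an Ax optimization config. The metric name
# "branin" and the helper function name are arbitrary placeholders.
def _example_benchmark_metric_usage() -> None:
    from ax.core.objective import Objective
    from ax.core.optimization_config import OptimizationConfig

    # A benchmark outcome with synthetic observation noise; the known noise SD
    # is reported to the model via the `sem` column.
    metric = BenchmarkMetric(
        name="branin", lower_is_better=True, observe_noise_sd=True
    )
    OptimizationConfig(objective=Objective(metric=metric, minimize=True))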


class GroundTruthBenchmarkMetric(BenchmarkMetric, GroundTruthMetricMixin):
    def __init__(self, original_metric: BenchmarkMetric) -> None:
        """
        Args:
            original_metric: The original BenchmarkMetric to which this metric
                corresponds.
        """
        super().__init__(
            name=self.get_ground_truth_name(original_metric),
            lower_is_better=original_metric.lower_is_better,
            observe_noise_sd=False,
            outcome_index=original_metric.outcome_index,
        )
        self.original_metric = original_metric

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        if len(kwargs) > 0:
            raise NotImplementedError(
                f"Arguments {set(kwargs)} are not supported in "
                f"{self.__class__.__name__}.fetch_trial_data."
            )
        return _fetch_trial_data(
            trial=trial,
            metric_name=self.name,
            outcome_index=self.outcome_index,
            include_noise_sd=False,
            ground_truth=True,
        )

    def make_ground_truth_metric(self) -> BenchmarkMetricBase:
        """Create a ground truth version of this metric."""
        return self
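

# Illustrative sketch (not part of the original module): the ground-truth
# conversion drops the observed noise SD and is idempotent. The metric name
# "branin" and the helper function name are arbitrary placeholders.
def _example_ground_truth_round_trip() -> None:
    metric = BenchmarkMetric(name="branin", lower_is_better=True)
    gt_metric = metric.make_ground_truth_metric()
    assert isinstance(gt_metric, GroundTruthBenchmarkMetric)
    assert gt_metric.original_metric is metric
    # Ground-truth values are noiseless, so no noise SD is observed.
    assert gt_metric.observe_noise_sd is False
    # Converting an already-ground-truth metric returns the same object.
    assert gt_metric.make_ground_truth_metric() is gt_metric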