# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Any
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from pyre_extensions import none_throws


def _get_no_metadata_msg(trial_index: int) -> str:
return f"No metadata available for trial {trial_index}."
def _get_no_metadata_err(trial: BaseTrial) -> Err[Data, MetricFetchE]:
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
    )


def _validate_trial_and_kwargs(
trial: BaseTrial, class_name: str, **kwargs: Any
) -> None:
"""
Validate that:
- Kwargs are empty
- No arms within a BatchTrial have been abandoned
"""
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
"BenchmarkMetric does not support abandoned arms in batch trials."
)


class BenchmarkMetric(Metric):
"""A generic metric used for observed values produced by Ax Benchmarks.
Compatible with results generated by `BenchmarkRunner`.
"""

    def __init__(
self,
name: str,
        # Needs to be a boolean (not None) for validation of MOO opt configs
lower_is_better: bool,
observe_noise_sd: bool = True,
) -> None:
"""
Args:
name: Name of the metric.
lower_is_better: If `True`, lower metric values are considered better.
observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data.
If `False`, `sem` is set to `None` (meaning that the model will
have to infer the noise level).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
# Declare `lower_is_better` as bool (rather than optional as in the base class)
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd
def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)
df = trial.run_metadata["benchmark_metadata"].dfs[self.name]
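        # Non-map benchmark data is expected to be recorded at a single step,
        # t == 0; rows at later steps indicate time-series data that requires
        # `BenchmarkMapMetric`.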
if (df["t"] > 0).any():
raise ValueError(
f"Trial {trial.index} has data from multiple time steps. This is"
" not supported by `BenchmarkMetric`; use `BenchmarkMapMetric`."
)
df = df.drop(columns=["t"])
if not self.observe_noise_sd:
df["sem"] = None
return Ok(value=Data(df=df))
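
# A minimal usage sketch (hypothetical names, not part of this module): build a
# `BenchmarkMetric` and fetch data for a trial whose `run_metadata` was
# populated by a `BenchmarkRunner`. Assumes the metric name matches a key in
# `benchmark_metadata.dfs` and that `Ok` exposes the wrapped `Data` as `.value`:
#
#     metric = BenchmarkMetric(name="objective", lower_is_better=True)
#     result = metric.fetch_trial_data(trial=trial)
#     if isinstance(result, Ok):
#         df = result.value.df  # columns include "mean" and "sem"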


class BenchmarkMapMetric(MapMetric):
    """A generic map metric for partial (per-step) results produced by Ax
    Benchmarks. Compatible with results generated by `BenchmarkRunner`;
    returns `MapData` keyed by the progression column ``t``.
    """

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="t", default_value=0)

    def __init__(
self,
name: str,
        # Needs to be a boolean (not None) for validation of MOO opt configs
lower_is_better: bool,
observe_noise_sd: bool = True,
) -> None:
"""
Args:
name: Name of the metric.
lower_is_better: If `True`, lower metric values are considered better.
observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data.
If `False`, `sem` is set to `None` (meaning that the model will
have to infer the noise level).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
# Declare `lower_is_better` as bool (rather than optional as in the base class)
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd
@classmethod
def is_available_while_running(cls) -> bool:
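        # Map data accrues over time, so partial results can be fetched while
        # the trial is still running.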
return True
def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
If the trial has been completed, look up the ``sim_start_time`` and
``sim_completed_time`` on the corresponding ``SimTrial``, and return all
        data from keys 0, ..., ``sim_completed_time - sim_start_time``. If the
        trial has not completed, return all data from keys 0, ...,
        ``backend_simulator.time - sim_start_time``.

Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)
metadata = trial.run_metadata["benchmark_metadata"]
backend_simulator = metadata.backend_simulator
if backend_simulator is None:
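            # Without a backend simulator there is no notion of elapsed time,
            # so all recorded observations are returned.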
max_t = float("inf")
else:
sim_trial = none_throws(
backend_simulator.get_sim_trial_by_index(trial.index)
)
# The BackendSimulator distinguishes between queued and running
# trials "for testing particular initialization cases", but these
# are all "running" to Scheduler.
# start_time = none_throws(sim_trial.sim_queued_time)
start_time = none_throws(sim_trial.sim_start_time)
if sim_trial.sim_completed_time is None: # Still running
max_t = backend_simulator.time - start_time
else:
if sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
# Completed, may have stopped early
max_t = none_throws(sim_trial.sim_completed_time) - start_time
df = (
metadata.dfs[self.name]
.loc[lambda x: x["t"] <= max_t]
.rename(columns={"t": self.map_key_info.key})
)
if not self.observe_noise_sd:
df["sem"] = None
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
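
# A minimal usage sketch (hypothetical names, not part of this module): a
# `BenchmarkMapMetric` fetched mid-trial returns `MapData` truncated at the
# current simulated time. Assumes `MapData` exposes its dataframe as `.map_df`:
#
#     map_metric = BenchmarkMapMetric(name="objective", lower_is_better=True)
#     result = map_metric.fetch_trial_data(trial=trial)
#     if isinstance(result, Ok):
#         map_df = result.value.map_df  # includes the progression column "t"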