#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import logging
from logging import Logger
from typing import Any
import numpy as np
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.trial import Trial
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, Ok
from pyre_extensions import assert_is_instance
logger: Logger = get_logger(__name__)
SMOOTHING_DEFAULT = 0.6 # Default in Tensorboard UI
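# Key in a trial's `run_metadata` under which the runner is expected to store
# the path to that trial's TensorBoard log directory.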
RUN_METADATA_KEY = "tb_log_dir"
try:
from tensorboard.backend.event_processing import (
plugin_event_multiplexer as event_multiplexer,
)
logging.getLogger("tensorboard").setLevel(logging.CRITICAL)
class TensorboardMetric(MapMetric):
"""A *new* `MapMetric` for getting Tensorboard metrics."""
map_key_info: MapKeyInfo[float] = MapKeyInfo(key="step", default_value=0.0)
def __init__(
self,
name: str,
tag: str,
lower_is_better: bool | None = True,
smoothing: float = SMOOTHING_DEFAULT,
cumulative_best: bool = False,
) -> None:
"""
Args:
name: The name of the metric.
tag: The name of the learning curve in the Tensorboard Scalars tab.
lower_is_better: If True, lower curve values are considered better.
smoothing: If > 0, apply exponential weighted mean to the curve. This
is the same postprocessing as the "smoothing" slider in the
Tensorboard UI.
cumulative_best: If True, for each trial, apply cumulative best to
the curve (i.e., if lower is better, then we return a curve
representing the cumulative min of the raw curve).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
self.smoothing = smoothing
self.tag = tag
self.cumulative_best = cumulative_best
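
    # Illustrative construction (a sketch; the name and tag below are
    # hypothetical and must match a curve in your TensorBoard Scalars tab):
    #
    #   metric = TensorboardMetric(
    #       name="eval_loss",
    #       tag="eval/loss",
    #       lower_is_better=True,
    #       smoothing=SMOOTHING_DEFAULT,
    #       cumulative_best=False,
    #   )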
@classmethod
def is_available_while_running(cls) -> bool:
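        # TensorBoard event files are written incrementally while a trial runs,
        # so curve data can be fetched before the trial completes.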
return True
def bulk_fetch_trial_data(
self, trial: BaseTrial, metrics: list[Metric], **kwargs: Any
) -> dict[str, MetricFetchResult]:
"""Fetch multiple metrics data for one trial, using instance attributes
of the metrics.
Returns Dict of metric_name => Result
Default behavior calls `fetch_trial_data` for each metric. Subclasses should
override this to perform trial data computation for multiple metrics.
"""
tb_metrics = [
assert_is_instance(metric, TensorboardMetric) for metric in metrics
]
trial = assert_is_instance(trial, Trial)
if trial.arm is None:
raise ValueError("Trial must have arm set.")
arm_name = trial.arm.name
try:
mul = self._get_event_multiplexer_for_trial(trial=trial)
except Exception as e:
return {
metric.name: Err(
MetricFetchE(
message=f"Failed to get event multiplexer for {trial=}",
exception=e,
)
)
for metric in tb_metrics
}
scalar_dict = mul.PluginRunToTagToContent("scalars")
if len(scalar_dict) == 0:
return {
metric.name: Err(
MetricFetchE(
message=(
"No 'scalar' data found for trial in multiplexer "
f"{mul=}"
),
exception=None,
)
)
for metric in tb_metrics
}
res = {}
for metric in tb_metrics:
try:
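                # Flatten the multiplexer's run -> tag -> events mapping into
                # long-format rows for this metric's tag only; each TensorBoard
                # event yields one (step, mean) observation with a NaN SEM.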
records = [
{
"trial_index": trial.index,
"arm_name": arm_name,
"metric_name": metric.name,
self.map_key_info.key: t.step,
"mean": (
t.tensor_proto.double_val[0]
if t.tensor_proto.double_val
else t.tensor_proto.float_val[0]
),
"sem": float("nan"),
}
for run_name, tags in scalar_dict.items()
for tag in tags
if tag == metric.tag
for t in mul.Tensors(run_name, tag)
]
# If records is empty something has gone wrong: either the tag is
# not present on the multiplexer or the content referenced is empty
if len(records) == 0:
if metric.tag not in [
j for sub in scalar_dict.values() for j in sub
]:
raise KeyError(
f"Tag {metric.tag} not found on multiplexer {mul=}. "
"Did you specify this tag exactly as it appears in "
"the TensorBoard UI's Scalars tab?"
)
else:
raise ValueError(
f"Found tag {metric.tag}, but no data found for it. Is "
"the curve empty in the TensorBoard UI?"
)
df = (
pd.DataFrame(records)
# If a metric has multiple records for the same arm, metric, and
# step (sometimes caused by restarts, etc) take the mean
.groupby(["arm_name", "metric_name", self.map_key_info.key])
.mean()
.reset_index()
)
# If there are any NaNs or Infs in the data, raise an Exception
if np.any(~np.isfinite(df["mean"])):
raise ValueError("Found NaNs or Infs in data")
# Apply per-metric post-processing
# Apply cumulative "best" (min if lower_is_better)
if metric.cumulative_best:
if metric.lower_is_better:
df["mean"] = df["mean"].cummin()
else:
df["mean"] = df["mean"].cummax()
# Apply smoothing
if metric.smoothing > 0:
df["mean"] = df["mean"].ewm(alpha=metric.smoothing).mean()
# Accumulate successfully extracted timeseries
res[metric.name] = Ok(
MapData(
df=df,
map_key_infos=[self.map_key_info],
)
)
except Exception as e:
res[metric.name] = Err(
MetricFetchE(
message=f"Failed to fetch data for {metric.name}",
exception=e,
)
)
self._clear_multiplexer_if_possible(multiplexer=mul)
return res
def fetch_trial_data(
self, trial: BaseTrial, **kwargs: Any
) -> MetricFetchResult:
"""Fetch data for one trial."""
return self.bulk_fetch_trial_data(trial=trial, metrics=[self], **kwargs)[
self.name
]
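
    # Illustrative usage (a sketch; assumes `trial.run_metadata["tb_log_dir"]`
    # points at an existing TensorBoard log directory):
    #
    #   result = metric.fetch_trial_data(trial)
    #   if isinstance(result, Ok):
    #       print(result.unwrap().map_df.head())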
def _get_event_multiplexer_for_trial(
self, trial: BaseTrial
) -> event_multiplexer.EventMultiplexer:
"""Get an event multiplexer with the logs for a given trial."""
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
mul.AddRunsFromDirectory(trial.run_metadata[RUN_METADATA_KEY], None)
mul.Reload()
return mul
def _clear_multiplexer_if_possible(
self, multiplexer: event_multiplexer.EventMultiplexer
) -> None:
"""
Clear the multiplexer of all data. This is a no-op here, but for some
Multiplexers which may implement a clearing method this method can be
important for managing memory consumption.
"""
pass
except ImportError:
logger.warning(
"tensorboard package not found. If you would like to use "
"TensorboardMetric, please install tensorboard."
)