Source code for ax.metrics.tensorboard
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import itertools
import logging
from typing import Iterable, Dict, List, Optional, Union
import pandas as pd
from ax.core.map_data import MapKeyInfo
from ax.metrics.curve import AbstractCurveMetric
from ax.utils.common.logger import get_logger
# Module-level logger, obtained through Ax's `get_logger` helper.
logger = get_logger(__name__)
# Key string "vis_metrics". It is not referenced by the code in this module
# as shown here; presumably callers elsewhere use it — TODO confirm.
RESULTS_KEY = "vis_metrics"
try:
from tensorboard.backend.event_processing import (
plugin_event_multiplexer as event_multiplexer,
)
logging.getLogger("tensorboard").setLevel(logging.CRITICAL)
class TensorboardCurveMetric(AbstractCurveMetric):
    """A `CurveMetric` that reads its curves from Tensorboard event logs.

    Curve data is keyed by the "steps" map key, whose default value is 0.
    """

    MAP_KEY = MapKeyInfo(key="steps", default_value=0)

    @classmethod
    def get_curves_from_ids(
        cls, ids: Iterable[Union[int, str]]
    ) -> Dict[Union[int, str], Dict[str, pd.Series]]:
        """Load curve data from Tensorboard logs, one log directory per id.

        Args:
            ids: Identifiers whose string form (``str(id_)``) is the path of
                a Tensorboard log directory.

        Returns:
            A dict mapping each id to the result of ``get_tb_from_posix`` for
            its path. Ids for which that call returns ``None`` are omitted.

        NOTE: If the ids are not simple paths/posix locations, subclass this
        metric and replace this method with an appropriate one that
        retrieves the log results.
        """
        # Lazily pair every id with its loaded logs, in input order.
        loaded = ((id_, get_tb_from_posix(str(id_))) for id_ in ids)
        return {id_: curves for id_, curves in loaded if curves is not None}
def get_tb_from_posix(path: str) -> Optional[Dict[str, pd.Series]]:
    r"""Get Tensorboard scalar data from a posix path.

    Args:
        path: The posix path for the directory that contains the
            tensorboard logs.

    Returns:
        A dictionary mapping each scalar tag name to a pandas Series of its
        values indexed by step. Only events from the most recent training
        run (as determined by ``_get_latest_start_time``) are kept. NaN
        values are dropped, and repeated observations of the same step are
        averaged.

        NOTE: despite the ``Optional`` return annotation, this implementation
        always returns a dict. It is empty when no scalar data is found.
    """
    logger.info(f"Reading TB logs from {path}.")
    mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
    mul.AddRunsFromDirectory(path, None)
    mul.Reload()
    scalar_dict = mul.PluginRunToTagToContent("scalars")
    raw_result = [
        {"tag": tag, "event": mul.Tensors(run, tag)}
        for run, run_dict in scalar_dict.items()
        for tag in run_dict
    ]
    tb_run_data = {}
    for item in raw_result:
        events = item["event"]
        # `_get_latest_start_time` sorts `events` in place by wall time and
        # returns the start of the latest (re)started run. Earlier runs are
        # filtered out below.
        latest_start_time = _get_latest_start_time(events)
        kept = [e for e in events if e.wall_time >= latest_start_time]
        steps = [e.step for e in kept]
        # Assumes each scalar tensor stores exactly one value in `float_val`
        # (TODO confirm). Otherwise `steps` and `vals` differ in length and
        # the `pd.Series` constructor raises.
        vals = list(
            itertools.chain.from_iterable(e.tensor_proto.float_val for e in kept)
        )
        key = item["tag"]
        series = pd.Series(index=steps, data=vals).dropna()
        if series.index.has_duplicates:  # pyre-ignore[16]
            # Take the average of repeated observations of the same "step".
            # Group on the series' own index, not the original `steps` list.
            # After `dropna()` the two may differ in length, and grouping by
            # `steps` would raise "Grouper and axis must be same length".
            series = series.groupby(level=0).mean()  # pyre-ignore[16]
            logger.warning(
                f"Found duplicate steps for tag {key}. "
                "Removing duplicates by averaging."
            )
        tb_run_data[key] = series
    return tb_run_data
def _get_latest_start_time(events: List) -> float:
r"""In each directory, there may be previous training runs due
to restarting training jobs.
Args:
events: A list of TensorEvents.
Returns:
The start time of the latest training run.
"""
events.sort(key=lambda e: e.wall_time)
start_time = events[0].wall_time
for i in range(1, len(events)):
# detect points in time where restarts occurred
if events[i].step < events[i - 1].step:
start_time = events[i].wall_time
return start_time
except ImportError:
logger.warning(
"tensorboard package not found. If you would like to use "
"TensorboardCurveMetric, please install tensorboard."
)
pass