# Source code for ax.metrics.tensorboard

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import logging

from logging import Logger
from typing import Dict, Iterable, List, NamedTuple, Union

import pandas as pd
from ax.core.map_data import MapKeyInfo
from ax.metrics.curve import AbstractCurveMetric
from ax.utils.common.logger import get_logger

# Module-level logger for this file, created through Ax's `get_logger` helper.
logger: Logger = get_logger(__name__)

try:
    from tensorboard.backend.event_processing import (
        plugin_event_multiplexer as event_multiplexer,
    )
    from tensorboard.compat.proto import types_pb2

    # Only let CRITICAL records from the `tensorboard` logger through.
    logging.getLogger("tensorboard").setLevel(logging.CRITICAL)

    class TensorboardCurveMetric(AbstractCurveMetric):
        """A `CurveMetric` for getting Tensorboard curves."""

        map_key_info: MapKeyInfo[float] = MapKeyInfo(key="steps", default_value=0.0)

        @classmethod
        def get_curves_from_ids(
            cls, ids: Iterable[Union[int, str]]
        ) -> Dict[Union[int, str], Dict[str, pd.Series]]:
            """Get curve data from tensorboard logs.

            NOTE: If the ids are not simple paths/posix locations, subclass this
            metric and replace this method with an appropriate one that
            retrieves the log results.

            Args:
                ids: A list of string paths to tensorboard log directories.

            Returns:
                A dictionary keyed by the given ids. Each value is itself a
                dictionary mapping metric names to pandas Series of data, as
                returned by `get_tb_from_posix(str(id))`.
            """
            return {idx: get_tb_from_posix(str(idx)) for idx in ids}

    def get_tb_from_posix(path: str) -> Dict[str, pd.Series]:
        r"""Get Tensorboard data from a posix path.

        Args:
            path: The posix path for the directory that contains the
                tensorboard logs.

        Returns:
            A dictionary mapping metric names to pandas Series of data. Each
            series is indexed by step and only contains the events recorded
            since the latest detected (re)start of the run. If a tag occurs
            in several runs, the run processed last wins.
        """
        logger.debug(f"Reading TB logs from {path}.")
        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
        mul.AddRunsFromDirectory(path, None)
        mul.Reload()
        scalar_dict = mul.PluginRunToTagToContent("scalars")

        raw_result = [
            {"tag": tag, "event": mul.Tensors(run, tag)}
            for run, run_dict in scalar_dict.items()
            for tag in run_dict
        ]
        tb_run_data = {}
        for item in raw_result:
            key = item["tag"]
            # Order the events chronologically here, explicitly, instead of
            # relying on a helper sorting the list behind our back.
            events = sorted(item["event"], key=lambda e: e.wall_time)
            latest_start_time = _get_latest_start_time(events)
            # Filter once; both the index and the data come from this list.
            latest_events = [e for e in events if e.wall_time >= latest_start_time]
            steps = [e.step for e in latest_events]
            vals = [_get_event_value(e) for e in latest_events]
            series = pd.Series(index=steps, data=vals).dropna()
            if any(series.index.duplicated()):  # pyre-ignore[16]
                # take average of repeated observations of the same "step"
                series = series.groupby(series.index).mean()  # pyre-ignore[16]
                logger.debug(
                    f"Found duplicate steps for tag {key}. "
                    "Removing duplicates by averaging."
                )
            tb_run_data[key] = series
        return tb_run_data

    # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
    #  `typing.List` to avoid runtime subscripting errors.
    def _get_latest_start_time(events: List) -> float:
        """Find the wall time at which the latest training run started.

        In each directory, there may be previous training runs due to
        restarting training jobs. A restart is detected wherever, in
        chronological (wall time) order, the step counter decreases.

        Args:
            events: A list of TensorEvents. The list is not modified; a
                sorted copy is used internally.

        Returns:
            The start time of the latest training run, i.e. the wall time of
            the first event after the last detected restart (or of the
            earliest event if no restart is found).

        Raises:
            IndexError: If `events` is empty.
        """
        ordered = sorted(events, key=lambda e: e.wall_time)
        start_time = ordered[0].wall_time
        for previous, current in zip(ordered, ordered[1:]):
            # detect points in time where restarts occurred
            if current.step < previous.step:
                start_time = current.wall_time
        return start_time

    def _get_event_value(e: NamedTuple) -> float:
        r"""Helper function to check the dtype and then get the value
        stored in a TensorEvent.

        Args:
            e: A TensorEvent exposing a `tensor_proto` attribute.

        Returns:
            The first element of the value field matching the tensor's dtype
            (`float_val`, `double_val` or `int_val`).

        Raises:
            ValueError: If the dtype is not DT_FLOAT, DT_DOUBLE or DT_INT32.
        """
        tensor = e.tensor_proto  # pyre-ignore[16]
        if tensor.dtype == types_pb2.DT_FLOAT:
            return tensor.float_val[0]
        elif tensor.dtype == types_pb2.DT_DOUBLE:
            return tensor.double_val[0]
        elif tensor.dtype == types_pb2.DT_INT32:
            return tensor.int_val[0]
        else:
            raise ValueError(f"Tensorboard dtype {tensor.dtype} not supported.")

except ImportError:
    # tensorboard is an optional dependency: without it the names above are
    # simply not defined and the module still imports.
    logger.warning(
        "tensorboard package not found. If you would like to use "
        "TensorboardCurveMetric, please install tensorboard."
    )