# Source code for ax.metrics.torchx
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from logging import Logger
from typing import Any, cast
import pandas as pd
from ax.core import Trial
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, Ok
from ax.utils.common.typeutils import not_none
logger: Logger = get_logger(__name__)

# torchx is an optional dependency. Guard the imports so this module can still
# be imported without it: only a warning is logged, and any name whose import
# failed stays undefined. TorchXMetric.fetch_trial_data references these names
# inside its ``try`` block, so a missing name surfaces there as a NameError
# that is returned as an ``Err`` rather than raised.
try:
    from ax.runners.torchx import TORCHX_TRACKER_BASE
    from torchx.runtime.tracking import FsspecResultTracker
except ImportError:
    logger.warning(
        "torchx package not found. If you would like to use TorchXMetric, please "
        "install torchx."
    )
class TorchXMetric(Metric):
    """
    Fetches AppMetric (the observation returned by the trial job/app) via the
    ``torchx.runtime.tracking`` module. Assumes that the app used the tracker
    in the following manner:

    .. code-block:: python

        tracker = torchx.runtime.tracking.FsspecResultTracker(tracker_base)
        tracker[str(trial_index)] = {metric_name: value}

        # -- or --

        tracker[str(trial_index)] = {"metric_name/mean": mean_value,
                                     "metric_name/sem": sem_value}

    If the bare ``metric_name`` key is present, it takes precedence and the
    SEM is reported as ``None``. Otherwise ``metric_name/mean`` is required
    and ``metric_name/sem`` is optional.
    """

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Read this metric's observation for ``trial`` from the TorchX tracker.

        Args:
            trial: The trial whose result to fetch. Its ``run_metadata`` must
                contain the tracker base under ``TORCHX_TRACKER_BASE``. It is
                treated as a single-arm ``Trial``: its ``arm`` is read and must
                not be ``None``.
            kwargs: Unused; accepted for compatibility with the
                ``Metric.fetch_trial_data`` interface.

        Returns:
            ``Ok`` wrapping a one-row ``Data`` with columns ``arm_name``,
            ``trial_index``, ``metric_name``, ``mean`` and ``sem`` on success.
            Any exception raised while fetching (missing run metadata, missing
            observation, trial without an arm, ...) is not propagated; it is
            returned as ``Err`` wrapping a ``MetricFetchE``.
        """
        try:
            tracker_base = trial.run_metadata[TORCHX_TRACKER_BASE]
            tracker = FsspecResultTracker(tracker_base)
            # Use the same key form the app is documented to write with
            # (``tracker[str(trial_index)] = ...``).
            res = tracker[str(trial.index)]

            if self.name in res:
                # Plain ``{metric_name: value}`` form: no SEM available.
                mean = res[self.name]
                sem = None
            else:
                mean = res.get(f"{self.name}/mean")
                sem = res.get(f"{self.name}/sem")

            # The mean is mandatory. An SEM without a mean is not a usable
            # observation, so treat it the same as a missing result.
            if mean is None:
                raise KeyError(
                    f"Observation for `{self.name}` not found in tracker at base "
                    f"`{tracker_base}`. Ensure that the trial job is writing the "
                    "results at the same tracker base."
                )

            df_dict = {
                "arm_name": not_none(cast(Trial, trial).arm).name,
                "trial_index": trial.index,
                "metric_name": self.name,
                "mean": mean,
                "sem": sem,
            }
            return Ok(value=Data(df=pd.DataFrame.from_records([df_dict])))
        except Exception as e:
            # Deliberate boundary: fetch errors are reported as an Err result
            # rather than raised to the caller.
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )