Source code for ax.core.metric

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional

from ax.core.base import Base
from ax.core.data import Data
from ax.utils.common.serialization import extract_init_args, serialize_init_args


if TYPE_CHECKING:  # pragma: no cover
    # import as module to make sphinx-autodoc-typehints happy
    from ax import core  # noqa F401


class Metric(Base):
    """Base class for representing metrics.

    Attributes:
        lower_is_better: Flag for metrics which should be minimized.
    """

    def __init__(self, name: str, lower_is_better: Optional[bool] = None) -> None:
        """Inits Metric.

        Args:
            name: Name of metric.
            lower_is_better: Flag for metrics which should be minimized.
        """
        self._name = name
        self.lower_is_better = lower_is_better

    @property
    def name(self) -> str:
        """Get name of metric."""
        return self._name

    @classmethod
    def serialize_init_args(cls, metric: "Metric") -> Dict[str, Any]:
        """Serialize the properties needed to initialize the metric.

        Used for storage. The fields ``name``, ``lower_is_better`` and
        ``precomp_config`` are excluded from the serialized result.
        """
        # Inside this body, ``serialize_init_args`` resolves to the module-level
        # helper imported from ax.utils.common.serialization, not to this
        # classmethod: class attributes are not in a method's local scope.
        return serialize_init_args(
            object=metric, exclude_fields=["name", "lower_is_better", "precomp_config"]
        )

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Given a dictionary, extract the properties needed to initialize the metric.

        Used for storage.
        """
        return extract_init_args(args=args, class_=cls)

    def fetch_trial_data(self, trial: core.base_trial.BaseTrial, **kwargs: Any) -> Data:
        """Fetch data for one trial.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                implement data fetching.
        """
        raise NotImplementedError(
            f"Metric {self.name} does not implement data-fetching logic."
        )  # pragma: no cover

    def fetch_experiment_data(
        self, experiment: core.experiment.Experiment, **kwargs: Any
    ) -> Data:
        """Fetch this metric's data for an experiment.

        Default behavior is to fetch data from all trials expecting data
        and concatenate the results. Trials not expecting data contribute
        an empty ``Data()``.
        """
        return Data.from_multiple_data(
            [
                self.fetch_trial_data(trial, **kwargs)
                if trial.status.expecting_data
                else Data()
                for trial in experiment.trials.values()
            ]
        )

    @classmethod
    def fetch_trial_data_multi(
        cls, trial: core.base_trial.BaseTrial, metrics: Iterable[Metric], **kwargs: Any
    ) -> Data:
        """Fetch multiple metrics data for one trial.

        Default behavior calls `fetch_trial_data` for each metric.
        Subclasses should override this to batch trial data computation
        across multiple metrics.
        """
        return Data.from_multiple_data(
            [metric.fetch_trial_data(trial, **kwargs) for metric in metrics]
        )

    @classmethod
    def fetch_experiment_data_multi(
        cls,
        experiment: core.experiment.Experiment,
        metrics: Iterable[Metric],
        trials: Optional[Iterable[core.base_trial.BaseTrial]] = None,
        **kwargs: Any,
    ) -> Data:
        """Fetch multiple metrics data for an experiment.

        Default behavior calls `fetch_trial_data_multi` for each trial.
        Subclasses should override to batch data computation across
        trials + metrics.

        Args:
            experiment: Experiment whose trials are used when ``trials`` is None.
            metrics: Metrics to fetch. May be any iterable, including a
                single-pass iterator or generator.
            trials: Trials to fetch data for; defaults to all of
                ``experiment.trials``.
            kwargs: Forwarded to `fetch_trial_data_multi`.

        Returns:
            Concatenated data for all requested trials. Trials not expecting
            data contribute an empty ``Data()``.
        """
        # ``metrics`` is passed to ``fetch_trial_data_multi`` once per trial.
        # A one-shot iterator would be exhausted after the first trial, and
        # every later trial would silently get no metric data. So materialize
        # it once up front.
        metrics = list(metrics)
        if trials is None:
            trials = experiment.trials.values()
        return Data.from_multiple_data(
            [
                cls.fetch_trial_data_multi(trial, metrics, **kwargs)
                if trial.status.expecting_data
                else Data()
                for trial in trials
            ]
        )

    def clone(self) -> "Metric":
        """Create a copy of this Metric.

        Note: this always constructs a base ``Metric`` from ``name`` and
        ``lower_is_better``. Subclasses with additional state need to override
        it to preserve their own type and attributes.
        """
        return Metric(name=self.name, lower_is_better=self.lower_is_better)

    def __repr__(self) -> str:
        return "{class_name}('{metric_name}')".format(
            class_name=self.__class__.__name__, metric_name=self.name
        )