# Source code for ax.core.metric

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Type

from ax.core.data import Data
from ax.utils.common.equality import Base
from ax.utils.common.serialization import extract_init_args, serialize_init_args


if TYPE_CHECKING:  # pragma: no cover
    # import as module to make sphinx-autodoc-typehints happy
    from ax import core  # noqa F401


class Metric(Base):
    """Base class for representing metrics.

    The `fetch_trial_data` method is the essential method to override when
    subclassing, which specifies how to retrieve a Metric, for a given trial.

    A Metric must return a Data object, which requires (at minimum) the
    following:
        https://ax.dev/api/_modules/ax/core/data.html#Data.required_columns

    Attributes:
        lower_is_better: Flag for metrics which should be minimized.
        properties: Properties specific to a particular metric.
    """

    def __init__(
        self,
        name: str,
        lower_is_better: Optional[bool] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Inits Metric.

        Args:
            name: Name of metric.
            lower_is_better: Flag for metrics which should be minimized.
            properties: Dictionary of this metric's properties. When ``None``
                (or any other falsy value, e.g. an empty dict) is passed, a
                fresh empty dict is stored instead.
        """
        self._name = name
        self.lower_is_better = lower_is_better
        self.properties = properties or {}

    @property
    def name(self) -> str:
        """Get name of metric."""
        return self._name

    @property
    def fetch_multi_group_by_metric(self) -> Type[Metric]:
        """Metric class, with which to group this metric in
        `Experiment._metrics_by_class`, which is used to combine metrics on
        experiment into groups and then fetch their data via
        `Metric.fetch_trial_data_multi` for each group.

        NOTE: By default, this property will just return the class on which it
        is defined; however, in some cases it is useful to group metrics by
        their superclass, in which case this property should return that
        superclass.
        """
        return self.__class__

    @classmethod
    def serialize_init_args(cls, metric: "Metric") -> Dict[str, Any]:
        """Serialize the properties needed to initialize the metric.
        Used for storage.

        The ``name``, ``lower_is_better`` and ``precomp_config`` fields are
        excluded from the returned dictionary.
        """
        return serialize_init_args(
            object=metric, exclude_fields=["name", "lower_is_better", "precomp_config"]
        )

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Given a dictionary, extract the properties needed to initialize the
        metric. Used for storage.
        """
        return extract_init_args(args=args, class_=cls)

    def fetch_trial_data(self, trial: core.base_trial.BaseTrial, **kwargs: Any) -> Data:
        """Fetch data for one trial.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override this method.
        """
        raise NotImplementedError(
            f"Metric {self.name} does not implement data-fetching logic."
        )  # pragma: no cover

    def fetch_experiment_data(
        self, experiment: core.experiment.Experiment, **kwargs: Any
    ) -> Data:
        """Fetch this metric's data for an experiment.

        Default behavior is to fetch data from all trials expecting data
        and concatenate the results. Trials not expecting data contribute an
        empty ``Data()``.
        """
        return Data.from_multiple_data(
            [
                self.fetch_trial_data(trial, **kwargs)
                if trial.status.expecting_data
                else Data()
                for trial in experiment.trials.values()
            ]
        )

    @classmethod
    def fetch_trial_data_multi(
        cls, trial: core.base_trial.BaseTrial, metrics: Iterable[Metric], **kwargs: Any
    ) -> Data:
        """Fetch multiple metrics data for one trial.

        Default behavior calls `fetch_trial_data` for each metric.
        Subclasses should override this to batch trial data computation for
        multiple metrics.
        """
        return Data.from_multiple_data(
            [metric.fetch_trial_data(trial, **kwargs) for metric in metrics]
        )

    @classmethod
    def fetch_experiment_data_multi(
        cls,
        experiment: core.experiment.Experiment,
        metrics: Iterable[Metric],
        trials: Optional[Iterable[core.base_trial.BaseTrial]] = None,
        **kwargs: Any,
    ) -> Data:
        """Fetch multiple metrics data for an experiment.

        Default behavior calls `fetch_trial_data_multi` for each trial.
        Subclasses should override to batch data computation across trials +
        metrics.

        Args:
            experiment: Experiment whose trials are fetched when ``trials`` is
                ``None``.
            metrics: Metrics to fetch. May be any iterable, including a
                one-shot iterator; it is materialized once up front.
            trials: Optional subset of trials to fetch; defaults to all of
                ``experiment.trials``.
            **kwargs: Forwarded to `fetch_trial_data_multi`.
        """
        # `metrics` is consumed once per trial below. A one-shot iterator (e.g.
        # a generator) would be exhausted after the first trial, silently
        # dropping every metric for the remaining trials, so materialize it.
        metric_list = list(metrics)
        trials_to_fetch = experiment.trials.values() if trials is None else trials
        return Data.from_multiple_data(
            [
                cls.fetch_trial_data_multi(trial, metric_list, **kwargs)
                if trial.status.expecting_data
                else Data()
                for trial in trials_to_fetch
            ]
        )

    def clone(self) -> "Metric":
        """Create a copy of this Metric.

        The copy carries over ``name``, ``lower_is_better`` and a shallow copy
        of ``properties``, so mutating the clone's properties dict does not
        affect the original.

        NOTE: The copy is always constructed as a base ``Metric`` (not
        ``type(self)``); subclasses with additional state should override this.
        """
        return Metric(
            name=self.name,
            lower_is_better=self.lower_is_better,
            properties=dict(self.properties),
        )

    def __repr__(self) -> str:
        return "{class_name}('{metric_name}')".format(
            class_name=self.__class__.__name__, metric_name=self.name
        )