Source code for ax.metrics.dict_lookup
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from ax.utils.common.typeutils import not_none
class DictLookupMetric(Metric):
    """A metric defined by a dictionary mapping parameter values to the
    corresponding metric values.

    This provides an option to add normal noise with mean 0 and standard
    deviation ``noise_sd`` to the looked-up metric values.
    """

    def __init__(
        self,
        name: str,
        param_names: List[str],
        lookup_dict: Dict[Tuple[Union[str, float, int, bool], ...], float],
        noise_sd: Optional[float] = 0.0,
        lower_is_better: Optional[bool] = None,
    ) -> None:
        """Metric is computed via a dictionary look up using a tuple of
        parameter values, constructed based on the ordering of parameter
        names given in `param_names`.

        Args:
            name: Name of the metric.
            param_names: An ordered list of names of parameters to be used
                to construct the dictionary key.
            lookup_dict: A dictionary mapping a tuple of parameter values
                (ordered as in `param_names`) to the metric value.
            noise_sd: Standard deviation of zero-mean normal noise added to
                each looked-up value. ``0.0`` (the default) adds no noise and
                reports an SEM of 0. If None, no noise is added but the
                reported SEM is NaN, i.e. the metric is treated as noisy with
                an unknown noise level.
            lower_is_better: Flag for metrics which should be minimized.
        """
        self.param_names = param_names
        self.lookup_dict = lookup_dict
        self.noise_sd = noise_sd
        super().__init__(name=name, lower_is_better=lower_is_better)

    @classmethod
    def is_available_while_running(cls) -> bool:
        """Return True.

        Values are computed only from each arm's parameters (see
        `fetch_trial_data`), so they do not depend on the trial completing.
        """
        return True

    def clone(self) -> DictLookupMetric:
        """Return a new metric of the same class with the same configuration.

        Note: ``param_names`` and ``lookup_dict`` are shared with the clone
        by reference; they are not copied.
        """
        return self.__class__(
            name=self._name,
            param_names=self.param_names,
            lookup_dict=self.lookup_dict,
            noise_sd=self.noise_sd,
            lower_is_better=self.lower_is_better,
        )

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Look up the metric value for every arm of ``trial``.

        For each arm, a key is built from the arm's parameter values in
        ``param_names`` order and looked up in ``lookup_dict``. If
        ``noise_sd`` is truthy, ``noise_sd * np.random.randn()`` is added to
        each value (one draw per arm, in arm iteration order).

        Args:
            trial: The trial whose arms are evaluated.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``Ok(Data)`` on success. The dataframe has one row per arm with
            columns ``arm_name``, ``metric_name``, ``mean``, ``sem`` (equal to
            ``noise_sd``, or NaN when ``noise_sd`` is None) and
            ``trial_index``. On failure, returns ``Err(MetricFetchE)`` wrapping
            the raised exception, e.g. a ``KeyError`` when a parameterization
            is absent from ``lookup_dict``.
        """
        try:
            noise_sd = self.noise_sd
            arm_names: List[str] = []
            mean: List[float] = []
            for name, arm in trial.arms_by_name.items():
                arm_names.append(name)
                # Key order must match the tuples in `lookup_dict`.
                # `not_none` is assumed to raise on a None parameter value.
                lookup_key = tuple(
                    not_none(arm.parameters[p]) for p in self.param_names
                )
                try:
                    val = self.lookup_dict[lookup_key]
                except KeyError as err:
                    # Re-raise with the offending parameterization, keeping
                    # the original lookup error as the explicit cause.
                    raise KeyError(
                        "Got a KeyError while attempting to retrieve the "
                        f"parameterization {arm.parameters} from the lookup dict. "
                        f"This parameterization corresponds to {lookup_key=}."
                    ) from err
                if noise_sd:
                    val = val + noise_sd * np.random.randn()
                mean.append(val)
            # Indicate unknown noise level in data.
            sem = float("nan") if noise_sd is None else noise_sd
            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": mean,
                    "sem": sem,
                    "trial_index": trial.index,
                }
            )
            return Ok(value=Data(df=df))
        except Exception as e:
            # Fetch boundary: report any failure as an Err instead of raising.
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )