#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import math
from collections import defaultdict
from collections.abc import Iterable, Mapping
from random import random
from typing import Any
import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetricFetchResult
from ax.core.metric import MetricFetchE
from ax.metrics.noisy_function_map import NoisyFunctionMapMetric
from ax.utils.common.result import Err, Ok
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import branin
from pyre_extensions import none_throws
# Ordered fidelity levels handed out by ``BraninFidelityMapMetric.f``; a
# fidelity of 1.0 makes its ``(1 - fidelity) ** 2`` penalty vanish.
FIDELITY: list[float] = [0.1, 0.4, 0.7, 1.0]
class BraninTimestampMapMetric(NoisyFunctionMapMetric):
    """Branin map metric whose curve grows by one timestamp per fetch.

    The first fetch of any trial, and every fetch of a running trial, advances
    that trial's timestamp counter by one. The returned ``MapData`` holds one
    row per arm for every timestamp in ``range(counter)``.
    """

    def __init__(
        self,
        name: str,
        param_names: Iterable[str],
        noise_sd: float = 0.0,
        lower_is_better: bool | None = None,
        rate: float | None = None,
        cache_evaluations: bool = True,
    ) -> None:
        """A Branin map metric with an optional multiplicative factor
        of `1 + exp(-rate * t)` where `t` is the timestamp (fetch step) at
        which the value is reported.

        If the multiplicative factor is used, then at `t = 0`, the function
        is twice the usual value, while as `t` becomes large (for a positive
        `rate`), the values approach the standard Branin values.

        Args:
            name: Name of the metric.
            param_names: An ordered list of names of parameters to be passed
                to the deterministic function.
            noise_sd: Noise scale. `fetch_trial_data` reports it as the SEM
                when `noisy` is True; it does not perturb the means.
            lower_is_better: Flag for metrics which should be minimized.
            rate: Parameter of the multiplicative factor. If None, no factor
                is applied.
            cache_evaluations: Passed through to the parent
                `NoisyFunctionMapMetric` constructor.
        """
        self.rate = rate
        # Number of timestamps produced so far for each trial index; unseen
        # trials default to 0.
        self._trial_index_to_timestamp: defaultdict[int, int] = defaultdict(int)
        super().__init__(
            name=name,
            param_names=param_names,
            noise_sd=noise_sd,
            lower_is_better=lower_is_better,
            cache_evaluations=cache_evaluations,
        )

    def __eq__(self, o: object) -> bool:
        """Compare configuration, ignoring the per-trial timestamp state.

        `_trial_index_to_timestamp` is mutable fetch bookkeeping and is
        excluded. Returns `NotImplemented` for objects that are not
        `BraninTimestampMapMetric` instances so Python can fall back to the
        other operand's comparison instead of raising AttributeError.
        """
        if not isinstance(o, BraninTimestampMapMetric):
            return NotImplemented
        return (
            self.name == o.name
            and self.param_names == o.param_names
            and self.noise_sd == o.noise_sd
            and self.lower_is_better == o.lower_is_better
            and self.rate == o.rate
        )

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapMetricFetchResult:
        """Fetch Branin values for every timestamp the trial has reached.

        Args:
            trial: The trial whose arms are evaluated.
            noisy: If True, `noise_sd` is reported as the SEM; otherwise the
                SEM is 0.0. No noise is added to the means here.
            kwargs: Unused; accepted for interface compatibility.

        Returns:
            `Ok` wrapping `MapData` with one row per (arm, timestamp), or
            `Err` wrapping a `MetricFetchE` if anything raised.
        """
        try:
            trial_index = trial.index
            if (
                self._trial_index_to_timestamp[trial_index] == 0
                or trial.status.is_running
            ):
                self._trial_index_to_timestamp[trial_index] += 1
            # Inputs do not depend on the timestamp, so build them once.
            # Parameters are ordered by `param_names` (as documented in
            # `__init__`) rather than by the arm's dict insertion order.
            arm_names = [arm.name for arm in trial.arms]
            xs = [
                np.array(
                    [arm.parameters[p] for p in self.param_names], dtype=float
                )
                for arm in trial.arms
            ]
            sem = self.noise_sd if noisy else 0.0
            key = self.map_key_info.key
            datas = []
            for timestamp in range(self._trial_index_to_timestamp[trial_index]):
                res = [self.f(x, timestamp=timestamp) for x in xs]
                df = pd.DataFrame(
                    {
                        "arm_name": arm_names,
                        "metric_name": self.name,
                        "sem": sem,
                        "trial_index": trial_index,
                        "mean": [item["mean"] for item in res],
                        key: [item[key] for item in res],
                    }
                )
                datas.append(MapData(df=df, map_key_infos=[self.map_key_info]))
            return Ok(value=MapData.from_multiple_map_data(datas))
        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )

    # pyre-fixme[14]: `f` overrides method defined in `NoisyFunctionMapMetric`
    # inconsistently.
    def f(self, x: npt.NDArray, timestamp: int) -> Mapping[str, Any]:
        """Evaluate Branin at `x`, optionally scaled by a decaying factor.

        Args:
            x: A length-2 array `(x1, x2)`.
            timestamp: The step `t` at which the value is reported.

        Returns:
            A mapping with the (possibly scaled) `"mean"` and the
            `"timestamp"` it was evaluated at.
        """
        x1, x2 = x
        # With a rate, the value is 2x Branin at t=0 (exp(0) == 1) and, for a
        # positive rate, decays toward plain Branin as t grows.
        weight = (
            1.0 + np.exp(-none_throws(self.rate) * timestamp)
            if self.rate is not None
            else 1.0
        )
        mean = checked_cast(float, branin(x1=x1, x2=x2)) * weight
        return {"mean": mean, "timestamp": timestamp}
class BraninFidelityMapMetric(NoisyFunctionMapMetric):
    """Branin map metric that reports a stepwise-increasing fidelity.

    Each call to `f` takes the next entry of the module-level `FIDELITY`
    list, saturating at its last entry. The sequence restarts on every
    `fetch_trial_data` call. The reported mean is Branin minus a random
    penalty `U[0, 1) * (1 - fidelity) ** 2`, which vanishes at fidelity 1.0.
    """

    map_key_info: MapKeyInfo[float] = MapKeyInfo(key="fidelity", default_value=0.0)

    def __init__(
        self,
        name: str,
        param_names: Iterable[str],
        noise_sd: float = 0.0,
        lower_is_better: bool | None = None,
    ) -> None:
        """
        Args:
            name: Name of the metric.
            param_names: Names of parameters, passed to the parent
                `NoisyFunctionMapMetric`.
            noise_sd: Passed to the parent `NoisyFunctionMapMetric`.
            lower_is_better: Flag for metrics which should be minimized.
        """
        super().__init__(
            name=name,
            param_names=param_names,
            noise_sd=noise_sd,
            lower_is_better=lower_is_better,
        )
        # Index into `FIDELITY` used by the most recent `f` call; -1 means
        # no evaluation has happened since the last reset.
        self.index = -1

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapMetricFetchResult:
        """Restart the fidelity sequence, then delegate to the parent fetch."""
        self.index = -1
        return super().fetch_trial_data(
            trial=trial,
            noisy=noisy,
            **kwargs,
        )

    def f(self, x: npt.NDArray) -> Mapping[str, Any]:
        """Evaluate Branin at `x` with a fidelity-dependent random penalty.

        Args:
            x: A length-2 array `(x1, x2)`.

        Returns:
            A mapping with the penalized `"mean"` and the `"fidelity"` used.
        """
        # Advance through FIDELITY but stop at its last entry. The bound must
        # be `len - 1`: with `len` the index could reach len(FIDELITY) and
        # the lookup below would raise IndexError on the fifth call.
        if self.index < len(FIDELITY) - 1:
            self.index += 1
        x1, x2 = x
        fidelity = FIDELITY[self.index]
        fidelity_penalty = random() * math.pow(1.0 - fidelity, 2.0)
        mean = checked_cast(float, branin(x1=x1, x2=x2)) - fidelity_penalty
        return {"mean": mean, "fidelity": fidelity}