Source code for ax.utils.testing.metrics.branin_backend_map
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, List, Optional
import numpy as np
from ax.core.map_data import MapKeyInfo
from ax.metrics.branin_map import BraninTimestampMapMetric
from ax.utils.testing.metrics.backend_simulator_map import (
BackendSimulatorTimestampMapMetric,
)
class BraninBackendMapMetric(
    BackendSimulatorTimestampMapMetric, BraninTimestampMapMetric
):
    """A Branin ``BackendSimulatorTimestampMapMetric`` with a multiplicative
    factor of ``1 - exp(-rate * t)`` where ``t`` is the runtime of the trial."""

    def __init__(
        self,
        name: str,
        param_names: List[str],
        # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter.
        map_key_infos: Optional[Iterable[MapKeyInfo]] = None,
        noise_sd: float = 0.0,
        lower_is_better: Optional[bool] = True,
        cache_evaluations: bool = True,
        rate: float = 0.5,
        delta_t: float = 1.0,
    ) -> None:
        """The ``BraninTimestampMapMetric`` integrated with the backend simulator.

        Args:
            name: Name of the metric.
            param_names: An ordered list of names of parameters to be passed
                to the deterministic function.
            map_key_infos: Map key descriptors forwarded to
                ``BackendSimulatorTimestampMapMetric.__init__``. Defaults to a
                single ``"timestamp"`` key with default value ``0.0``.
            noise_sd: Scale of normal noise added to the function result.
            lower_is_better: Flag for metrics which should be minimized.
            cache_evaluations: Forwarded unchanged to
                ``BackendSimulatorTimestampMapMetric.__init__``.
            rate: Parameter of the multiplicative factor.
            delta_t: The time delta between intermediate results, used in
                ``convert_to_timestamps``. Must be strictly positive.

        Raises:
            ValueError: If ``delta_t`` is not strictly positive.
        """
        # Fail fast: ``convert_to_timestamps`` floor-divides by ``delta_t``,
        # so zero would raise ``ZeroDivisionError`` only at fetch time, and a
        # negative (or NaN) value would silently yield no observations at all.
        if not delta_t > 0:
            raise ValueError(f"`delta_t` must be positive, got {delta_t}.")
        BackendSimulatorTimestampMapMetric.__init__(
            self,
            name=name,
            param_names=param_names,
            map_key_infos=map_key_infos
            if map_key_infos is not None
            else [MapKeyInfo(key="timestamp", default_value=0.0)],
            noise_sd=noise_sd,
            lower_is_better=lower_is_better,
            cache_evaluations=cache_evaluations,
        )
        # Only ``BackendSimulatorTimestampMapMetric.__init__`` is invoked above,
        # so the remaining attributes are assigned here directly.
        # NOTE(review): ``rate`` and ``_timestamp`` are presumably read by
        # ``BraninTimestampMapMetric``'s evaluation logic — confirm there.
        self.rate = rate
        self.delta_t = delta_t
        self._timestamp = -1

    def convert_to_timestamps(
        self, start_time: Optional[float], end_time: float
    ) -> List[float]:
        """Given a starting and current time, get the list of intermediate
        timestamps at which we have observations.

        Timestamps are relative to ``start_time``. The result is
        ``[0, delta_t, 2 * delta_t, ...]``, with one entry per full
        ``delta_t`` period elapsed between ``start_time`` and ``end_time``.

        Args:
            start_time: Time at which the trial started, or ``None`` if it
                has not started yet.
            end_time: Current time for the trial.

        Returns:
            The relative observation timestamps. The list is empty if
            ``start_time`` is ``None`` or if less than one full period has
            elapsed.
        """
        if start_time is None:
            # NOTE: This can be the case for trials on backend simulator
            # that are queued.
            return []
        num_periods_running = (end_time - start_time) // self.delta_t
        # Scaling integer steps by ``delta_t`` keeps the number of entries
        # exact; ``np.arange`` with a non-integer step can produce an
        # inconsistent length due to floating-point rounding.
        return list(np.arange(num_periods_running) * self.delta_t)