Source code for ax.modelbridge.transforms.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from math import isnan
from numbers import Number
from typing import Any, TYPE_CHECKING

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import Parameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.modelbridge.transforms.derelativize import Derelativize
from scipy.stats import norm


if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401


# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use `typing.Dict`
#  to avoid runtime subscripting errors.
class ClosestLookupDict(dict):
    """A dictionary with numeric keys that looks up the closest key."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # pyre-fixme[4]: Attribute must be annotated.
        self._keys = sorted(self.keys())

    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def __setitem__(self, key: Number, val: Any) -> None:
        if not isinstance(key, Number):
            raise ValueError("ClosestLookupDict only allows numerical keys.")
        super().__setitem__(key, val)
        # pyre-fixme[6]: For 2nd argument expected `Union[bytes, complex, float,
        #  int, generic, str]` but got `Number`.
        ipos = np.searchsorted(self._keys, key)
        self._keys.insert(ipos, key)

    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def __getitem__(self, key: Number) -> Any:
        try:
            return super().__getitem__(key)
        except KeyError:
            if not self.keys():
                raise RuntimeError("ClosestLookupDict is empty.")
            # pyre-fixme[6]: For 2nd argument expected `Union[bytes, complex, float,
            #  int, generic, str]` but got `Number`.
            ipos = np.searchsorted(self._keys, key)
            if ipos == 0:
                return super().__getitem__(self._keys[0])
            elif ipos == len(self._keys):
                return super().__getitem__(self._keys[-1])
            lkey, rkey = self._keys[ipos - 1 : ipos + 1]
            if np.abs(key - lkey) <= np.abs(key - rkey):  # pyre-ignore [58]
                return super().__getitem__(lkey)
            else:
                return super().__getitem__(rkey)
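

# Illustrative usage sketch (not part of the original module): on a key miss,
# `ClosestLookupDict` falls back to the value stored under the numerically
# closest key. The function name `_example_closest_lookup` is hypothetical.
def _example_closest_lookup() -> None:
    d = ClosestLookupDict({1.0: "a", 2.0: "b", 5.0: "c"})
    assert d[2.0] == "b"  # exact key hit
    assert d[1.9] == "b"  # 1.9 is closer to 2.0 than to 1.0
    assert d[10.0] == "c"  # beyond the largest key -> value of the last key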


def get_data(
    observation_data: list[ObservationData],
    metric_names: list[str] | None = None,
    raise_on_non_finite_data: bool = True,
) -> dict[str, list[float]]:
    """Extract all metrics if `metric_names` is None.

    Raises a ValueError if any data is non-finite.

    Args:
        observation_data: List of observation data.
        metric_names: List of metric names.
        raise_on_non_finite_data: If true, raises an exception on nan/inf.

    Returns:
        A dictionary mapping metric names to lists of metric values.
    """
    Ys = defaultdict(list)
    for obsd in observation_data:
        for i, m in enumerate(obsd.metric_names):
            if metric_names is None or m in metric_names:
                val = obsd.means[i]
                if raise_on_non_finite_data and (not np.isfinite(val)):
                    raise ValueError(f"Non-finite data found for metric {m}: {val}")
                Ys[m].append(val)
    return Ys
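

# Illustrative usage sketch (not part of the original module): `get_data`
# groups observed means by metric name, optionally filtered to a subset of
# metrics. The function name `_example_get_data` is hypothetical.
def _example_get_data() -> None:
    obsd = ObservationData(
        metric_names=["m1", "m2"],
        means=np.array([1.0, 2.0]),
        covariance=np.diag([0.1, 0.2]),
    )
    ys = get_data(observation_data=[obsd], metric_names=["m1"])
    assert ys == {"m1": [1.0]}  # only the requested metric is extracted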


def match_ci_width_truncated(
    mean: float,
    variance: float,
    transform: Callable[[float], float],
    level: float = 0.95,
    margin: float = 0.001,
    lower_bound: float = 0.0,
    upper_bound: float = 1.0,
    clip_mean: bool = False,
) -> tuple[float, float]:
    """Estimate a transformed variance using the match CI width method.

    See the log_y transform for the original. Here, bounds are forced to lie
    within a [lower_bound, upper_bound] interval after transformation.
    """
    fac = norm.ppf(1 - (1 - level) / 2)
    if clip_mean:
        mean = np.clip(mean, lower_bound + margin, upper_bound - margin)
    new_mean = transform(mean)
    if isnan(variance):
        new_variance = variance
    else:
        d = fac * np.sqrt(variance)
        right = min(mean + d, upper_bound - margin)
        left = max(mean - d, lower_bound + margin)
        width_asym = transform(right) - transform(left)
        new_variance = (width_asym / 2 / fac) ** 2
    return new_mean, new_variance
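

# Illustrative usage sketch (not part of the original module): push a
# mean/variance pair through a transform while matching the width of the
# (truncated) 95% confidence interval. The choice of `scipy.special.logit`
# and the name `_example_match_ci_width` are assumptions for illustration.
def _example_match_ci_width() -> None:
    from scipy.special import logit

    new_mean, new_var = match_ci_width_truncated(
        mean=0.2, variance=0.01, transform=logit
    )
    assert new_mean == logit(0.2)  # mean is transformed directly
    assert new_var > 0.0  # variance is derived from the transformed CI width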


def construct_new_search_space(
    search_space: SearchSpace,
    parameters: list[Parameter],
    parameter_constraints: list[ParameterConstraint] | None = None,
) -> SearchSpace:
    """Construct a search space with the transformed arguments.

    If the `search_space` is a `RobustSearchSpace`, this will use its
    environmental variables and distributions, and remove the environmental
    variables from `parameters` before constructing.

    Args:
        parameters: List of transformed parameter objects.
        parameter_constraints: List of parameter constraints.

    Returns:
        The new search space instance.
    """
    new_kwargs: dict[str, Any] = {
        "parameters": parameters,
        "parameter_constraints": parameter_constraints,
    }
    if isinstance(search_space, RobustSearchSpace):
        env_vars = list(search_space._environmental_variables.values())
        if env_vars:
            # Add environmental variables and remove them from parameters.
            new_kwargs["environmental_variables"] = env_vars
            new_kwargs["parameters"] = [p for p in parameters if p not in env_vars]
        new_kwargs["parameter_distributions"] = search_space.parameter_distributions
        new_kwargs["num_samples"] = search_space.num_samples
    return search_space.__class__(**new_kwargs)
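

# Illustrative usage sketch (not part of the original module): rebuild a plain
# `SearchSpace` from a list of transformed parameters. The imports of
# `RangeParameter`/`ParameterType` and the name
# `_example_construct_search_space` are for illustration only.
def _example_construct_search_space() -> None:
    from ax.core.parameter import ParameterType, RangeParameter

    old_ss = SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=10.0
            )
        ]
    )
    # E.g. a transform that rescaled "x" to the unit interval.
    new_params = [
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
    new_ss = construct_new_search_space(search_space=old_ss, parameters=new_params)
    assert new_ss.parameters["x"].upper == 1.0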


def derelativize_optimization_config_with_raw_status_quo(
    optimization_config: OptimizationConfig,
    modelbridge: modelbridge_module.base.ModelBridge,
    observations: list[Observation] | None,
) -> OptimizationConfig:
    """Derelativize optimization_config using raw status-quo values."""
    tf = Derelativize(
        search_space=modelbridge.model_space.clone(),
        observations=observations,
        config={"use_raw_status_quo": True},
    )
    return tf.transform_optimization_config(
        optimization_config=optimization_config.clone(),
        modelbridge=modelbridge,
        fixed_features=ObservationFeatures(parameters={}),
    )
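

# Illustrative call-pattern sketch (not part of the original module): given a
# fitted `ModelBridge` whose optimization config contains relative constraints,
# return a config with constraint bounds expressed as raw metric values. The
# name `_example_derelativize` and the use of `_optimization_config` /
# `get_training_data()` here are assumptions for illustration.
def _example_derelativize(
    modelbridge: modelbridge_module.base.ModelBridge,
) -> OptimizationConfig:
    opt_config = modelbridge._optimization_config
    assert opt_config is not None
    return derelativize_optimization_config_with_raw_status_quo(
        optimization_config=opt_config,
        modelbridge=modelbridge,
        observations=modelbridge.get_training_data(),
    )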