#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Tuple, Union
import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax.modelbridge import base as base_modelbridge # noqa F401 # pragma: no cover
# Module-level logger; `compute_standardization_parameters` uses it to report
# outcomes whose standard deviation is small enough to be treated as constant.
logger = get_logger("StandardizeY")
class StandardizeY(Transform):
    """Standardize Y, separately for each metric.

    The per-metric mean and standard deviation of the observed means are
    computed once, at construction time, and stored on the instance as
    ``Ymean`` and ``Ystd`` (dicts keyed by metric name).

    Transform is done in-place.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        """Compute per-metric standardization parameters.

        Args:
            search_space: The search space (not used by this constructor).
            observation_features: Observation features (not used by this
                constructor).
            observation_data: Observations whose ``means`` are grouped by
                the corresponding ``metric_names`` entry.
            config: Optional transform config (not used by this constructor).

        Raises:
            ValueError: If ``observation_data`` is empty.
        """
        if not observation_data:
            raise ValueError(
                "StandardizeY transform requires non-empty observation data."
            )
        # Group observed means by metric name. The key type is annotated with
        # the full union accepted by `compute_standardization_parameters`:
        # DefaultDict is invariant in its key type, so a `DefaultDict[str, ...]`
        # would not type-check against that parameter.
        Ys: DefaultDict[
            Union[str, Tuple[str, TParamValue]], List[float]
        ] = defaultdict(list)
        for obsd in observation_data:
            # `means[i]` is the observed mean for `metric_names[i]`.
            for i, m in enumerate(obsd.metric_names):
                Ys[m].append(obsd.means[i])
        self.Ymean, self.Ystd = compute_standardization_parameters(Ys)
def compute_standardization_parameters(
    Ys: DefaultDict[Union[str, Tuple[str, TParamValue]], List[float]],
    *,
    tolerance: float = 1e-8,
) -> Tuple[
    Dict[Union[str, Tuple[str, TParamValue]], float],
    Dict[Union[str, Tuple[str, TParamValue]], float],
]:
    """Compute the mean and population std. dev of each entry of Ys.

    Args:
        Ys: Mapping from a key (e.g. a metric name) to the list of observed
            values for that key.
        tolerance: Standard deviations strictly below this value are treated
            as zero. The key is then considered constant, its std is replaced
            by 1.0, and an info message is logged. Defaults to 1e-8.

    Returns:
        A ``(Ymean, Ystd)`` tuple of dicts with the same keys (in the same
        order) as ``Ys``. ``Ystd`` uses ``np.std``'s default ``ddof=0``.
    """
    Ymean = {k: np.mean(y) for k, y in Ys.items()}
    Ystd = {}
    for k, y in Ys.items():
        s = np.std(y)
        # Don't rescale by a standard deviation that is (numerically) zero;
        # fall back to 1.0 instead.
        if s < tolerance:
            s = 1.0
            logger.info("Outcome %s is constant, within tolerance.", k)
        Ystd[k] = s
    return Ymean, Ystd