Source code for ax.modelbridge.transforms.ivw
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from logging import Logger
from typing import Dict, List
import numpy as np
from ax.core.observation import ObservationData
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
# Module-level logger; used below to warn about conflicting noiseless
# measurements of the same metric.
logger: Logger = get_logger(__name__)
def ivw_metric_merge(
    obsd: ObservationData, conflicting_noiseless: str = "warn"
) -> ObservationData:
    """Collapse repeated observations of each metric into one IVW estimate.

    Every metric that appears more than once in ``obsd`` is replaced by a
    single inverse-variance-weighted (IVW) estimate, and the covariance
    between merged estimates is propagated as

        ybar1 = Sum_i w_i * y_i
        ybar2 = Sum_j w_j * y_j
        cov[ybar1, ybar2] = Sum_i Sum_j w_i * w_j * cov[y_i, y_j]

    A variance of exactly 0 would give an infinite weight, so noiseless
    observations are handled separately: the first noiseless observation of
    a metric receives weight 1 and all other observations of that metric
    receive weight 0. When a metric has several noiseless observations, their
    means are checked for disagreement; the reaction to a disagreement is
    controlled by ``conflicting_noiseless``: "ignore" and "warn" keep the
    first noiseless measurement ("warn" also logs a warning), while "raise"
    raises an exception.

    Args:
        obsd: An ObservationData object.
        conflicting_noiseless: "warn", "ignore", or "raise".

    Returns:
        ``obsd`` itself when no metric name is repeated; otherwise a new
        ObservationData with one entry per metric, ordered by sorted metric
        name.

    Raises:
        ValueError: If some metric is repeated and ``conflicting_noiseless``
            is not one of the accepted values, or if it is "raise" and
            conflicting noiseless measurements are found.
    """
    unique_names = set(obsd.metric_names)
    if len(unique_names) == len(obsd.metric_names):
        # Nothing is duplicated, so there is nothing to merge.
        return obsd
    if conflicting_noiseless not in {"warn", "ignore", "raise"}:
        raise ValueError(
            'conflicting_noiseless should be "warn", "ignore", or "raise".'
        )

    # positions[name] lists, in increasing order, the indices into obsd.means
    # (and rows/columns of obsd.covariance) that belong to that metric.
    positions: Dict[str, List[int]] = {}
    for pos, name in enumerate(obsd.metric_names):
        positions.setdefault(name, []).append(pos)

    # weights[name][k] is the weight given to the k-th measurement of name.
    weights: Dict[str, np.ndarray] = {}
    for name in unique_names:
        rows = positions[name]
        # Paired index lists pick the diagonal entries, i.e. the variances.
        variances = obsd.covariance[rows, rows]
        exact = np.where(variances == 0.0)[0]
        if len(exact) > 0:
            # At least one noiseless measurement: verify they agree, then
            # let the first one carry all the weight.
            _check_conflicting_means(
                obsd.means[rows][exact], name, conflicting_noiseless
            )
            w = np.zeros_like(variances)
            w[exact[0]] = 1.0
        else:
            # Inverse variances, normalized so the weights sum to one.
            inverse = 1.0 / variances
            w = inverse / np.sum(inverse)
        weights[name] = w

    # Assemble the merged means and covariance, one slot per unique metric.
    merged_names = sorted(unique_names)
    n_merged = len(merged_names)
    merged_means = np.zeros(n_merged)
    merged_cov = np.zeros((n_merged, n_merged))
    for a, name_a in enumerate(merged_names):
        rows_a = positions[name_a]
        w_a = weights[name_a]
        merged_means[a] = np.sum(w_a * obsd.means[rows_a])
        # Only the upper triangle is computed; it is mirrored afterwards.
        for b in range(a, n_merged):
            name_b = merged_names[b]
            rows_b = positions[name_b]
            w_b = weights[name_b]
            for weight_a, row_a in zip(w_a, rows_a):
                for weight_b, row_b in zip(w_b, rows_b):
                    merged_cov[a, b] += (
                        weight_a * weight_b * obsd.covariance[row_a, row_b]
                    )
            merged_cov[b, a] = merged_cov[a, b]
    return ObservationData(
        metric_names=merged_names, means=merged_means, covariance=merged_cov
    )
def _check_conflicting_means(
    means_noiseless: np.ndarray, metric_name: str, conflicting_noiseless: str
) -> None:
    """Report noiseless measurements of one metric whose means disagree.

    Args:
        means_noiseless: Means of the zero-variance measurements of the
            metric.
        metric_name: Name of the metric, used in the message.
        conflicting_noiseless: "warn" logs a warning on disagreement, "raise"
            raises; any other value (e.g. "ignore") does nothing.

    Raises:
        ValueError: If the means disagree and ``conflicting_noiseless`` is
            "raise".
    """
    values = np.asarray(means_noiseless)
    if values.size < 2:
        # Zero or one measurement cannot conflict with anything.
        return
    # Compare the spread (max - min) rather than the variance: the variance
    # of identical floats can come out as a tiny positive number because the
    # computed mean is rounded (e.g. np.var([0.1, 0.1, 0.1]) > 0), which
    # would flag equal measurements as conflicting. The spread of identical
    # values is exactly 0.
    if np.ptp(values) > 0:
        message = f"Conflicting noiseless measurements for {metric_name}."
        if conflicting_noiseless == "warn":
            logger.warning(message)
        elif conflicting_noiseless == "raise":
            raise ValueError(message)
class IVW(Transform):
    """Transform that merges repeated observations of a metric.

    Whenever a single observation data holds several observations of the same
    metric, they are combined into one via inverse variance weighting.
    """

    def _transform_observation_data(
        self,
        observation_data: List[ObservationData],
    ) -> List[ObservationData]:
        """Apply ``ivw_metric_merge`` to every observation data in the list.

        The handling of conflicting noiseless measurements is read from the
        "conflicting_noiseless" key of ``self.config`` and defaults to "warn".
        """
        # pyre: on_conflict is declared to have type `str` but is
        # pyre-fixme[9]: used as type `typing.Union[float, int, str]`.
        on_conflict: str = self.config.get("conflicting_noiseless", "warn")
        return [
            ivw_metric_merge(obsd=single_obsd, conflicting_noiseless=on_conflict)
            for single_obsd in observation_data
        ]