Source code for ax.modelbridge.transforms.transform_to_new_sq

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from collections.abc import Callable

from math import sqrt
from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.relativize import BaseRelativize, get_metric_index
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
from ax.utils.stats.statstools import relativize, unrelativize
from pyre_extensions import none_throws

if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401


[docs] class TransformToNewSQ(BaseRelativize): """Map relative values of one batch to SQ of another. Will compute the relative metrics for each arm in each batch, and will then turn those back into raw metrics but using the status quo values set on the Modelbridge. This is useful if batches are comparable on a relative scale, but have offset in their status quo. This is often approximately true for online experiments run in separate batches. Note that relativization is done using the delta method, so it will not simply be the ratio of the means.""" def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, modelbridge: modelbridge_module.base.ModelBridge | None = None, config: TConfig | None = None, ) -> None: super().__init__( search_space=search_space, observations=observations, modelbridge=modelbridge, config=config, ) self._status_quo_name: str = none_throws( none_throws(modelbridge).status_quo_name ) if config is not None: target_trial_index = config.get("target_trial_index") if target_trial_index is not None: self.default_trial_idx: int = checked_cast(int, target_trial_index) trial_indices = {} if observations is not None: trial_indices = { obs.features.trial_index for obs in observations if obs.features.trial_index is not None } # in case no target trial index is provided or the provided target # trial index is not a part of any trial from the observations, # use the smallest trial index from the observations if len(trial_indices) > 0 and (target_trial_index not in trial_indices): self.default_trial_idx = min(trial_indices) @property def control_as_constant(self) -> bool: """Whether or not the control is treated as a constant in the model.""" return True
[docs] def transform_optimization_config( self, optimization_config: OptimizationConfig, modelbridge: modelbridge_module.base.ModelBridge | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: return optimization_config
[docs] def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: return outcome_constraints
def _get_relative_data_from_obs( self, obs: Observation, rel_op: Callable[..., tuple[npt.NDArray, npt.NDArray]], ) -> ObservationData: idx = ( int(obs.features.trial_index) if obs.features.trial_index is not None else self.default_trial_idx ) if idx == self.default_trial_idx: # don't transform data from target batch return obs.data return super()._get_relative_data_from_obs( obs=obs, rel_op=rel_op, ) def _rel_op_on_observations( self, observations: list[Observation], rel_op: Callable[..., tuple[npt.NDArray, npt.NDArray]], ) -> list[Observation]: rel_observations = super()._rel_op_on_observations( observations=observations, rel_op=rel_op, ) return [ obs for obs in rel_observations # drop SQ observations if ( obs.arm_name != self._status_quo_name or obs.features.trial_index == self.default_trial_idx ) ] def _get_relative_data( self, data: ObservationData, status_quo_data: ObservationData, rel_op: Callable[..., tuple[npt.NDArray, npt.NDArray]], ) -> ObservationData: r""" Transform or untransform `data` based on `status_quo_data` based on `rel_op`. Args: data: ObservationData object to relativize status_quo_data: The status quo data associated with the specific trial that `data` belongs to. rel_op: relativize or unrelativize operator. control_as_constant: if treating the control metric as constant Returns: (un)transformed ObservationData """ L = len(data.metric_names) result = ObservationData( metric_names=data.metric_names, # zeros are just to create the shape so values can be set by index means=np.zeros(L), covariance=np.zeros((L, L)), ) for i, metric in enumerate(data.metric_names): j = get_metric_index(data=status_quo_data, metric_name=metric) means_t = data.means[i] sems_t = sqrt(data.covariance[i][i]) mean_c = status_quo_data.means[j] sem_c = sqrt(status_quo_data.covariance[j][j]) means_rel, sems_rel = self._get_rel_mean_sem( means_t=means_t, sems_t=sems_t, mean_c=mean_c, sem_c=sem_c, metric=metric, rel_op=rel_op, ) result.means[i] = means_rel result.covariance[i][i] = sems_rel**2 return result def _get_rel_mean_sem( self, means_t: float, sems_t: float, mean_c: float, sem_c: float, metric: str, rel_op: Callable[..., tuple[npt.NDArray, npt.NDArray]], ) -> tuple[float, float]: """Compute (un)transformed mean and sem for a single metric.""" target_status_quo_data = self.status_quo_data_by_trial[self.default_trial_idx] j = get_metric_index(data=target_status_quo_data, metric_name=metric) target_mean_c = target_status_quo_data.means[j] abs_target_mean_c = np.abs(target_mean_c) if rel_op == unrelativize: means_t = (means_t - target_mean_c) / abs_target_mean_c sems_t = sems_t / abs_target_mean_c means_rel, sems_rel = rel_op( means_t=means_t, sems_t=sems_t, mean_c=mean_c, sem_c=sem_c, as_percent=False, control_as_constant=self.control_as_constant, ) if rel_op == relativize: means_rel = means_rel * abs_target_mean_c + target_mean_c sems_rel = sems_rel * abs_target_mean_c # pyre-fixme[7]: Expected `Tuple[float, float]` but got # `Tuple[ndarray[typing.Any, typing.Any], ndarray[typing.Any, typing.Any]]`. return means_rel, sems_rel