Source code for ax.modelbridge.transforms.relativize

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
import warnings
from math import sqrt
from typing import Dict, List, Optional, TYPE_CHECKING

import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import (
    MultiObjectiveOptimizationConfig,
    OptimizationConfig,
)
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import not_none
from ax.utils.stats.statstools import relativize


if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401  # pragma: no cover


[docs]class Relativize(Transform): """ Change the relative flag of the given relative optimization configuration to False. This is needed in order for the new opt config to pass ModelBridge that requires non-relativized opt config. Also transforms absolute data and opt configs to relative. Requires a modelbridge with a status quo set to work. """ MISSING_STATUS_QUO_ERROR = "Cannot relativize data without status quo data" def __init__( self, search_space: Optional[SearchSpace], observation_features: List[ObservationFeatures], observation_data: List[ObservationData], modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, config: Optional[TConfig] = None, ) -> None: super().__init__( search_space=search_space, observation_features=observation_features, observation_data=observation_data, modelbridge=modelbridge, config=config, ) # self.modelbridge should NOT be modified self.modelbridge = not_none( modelbridge, "Relativize transform requires a modelbridge" ) self.status_quo_by_trial = self._get_status_quo_by_trial( observation_data=observation_data, observation_features=observation_features, status_quo_feature=not_none( self.modelbridge.status_quo, self.MISSING_STATUS_QUO_ERROR ).features, )
[docs] def transform_optimization_config( self, optimization_config: OptimizationConfig, modelbridge: Optional[modelbridge_module.base.ModelBridge], fixed_features: ObservationFeatures, ) -> OptimizationConfig: r""" Change the relative flag of the given relative optimization configuration to False. This is needed in order for the new opt config to pass ModelBridge that requires non-relativized opt config. Args: opt_config: Optimization configuaration relative to status quo. Returns: Optimization configuration relative to status quo with relative flag equal to false. """ # Getting constraints constraints = [ constraint.clone() for constraint in optimization_config.outcome_constraints ] if not all( constraint.relative for constraint in optimization_config.outcome_constraints ): raise ValueError( "All constraints must be relative to use the Relativize transform." ) for constraint in constraints: constraint.relative = False if isinstance(optimization_config, MultiObjectiveOptimizationConfig): # Getting objective thresholds obj_thresholds = [ obj_threshold.clone() for obj_threshold in optimization_config.objective_thresholds ] for obj_threshold in obj_thresholds: if not obj_threshold.relative: raise ValueError( "All objective thresholds must be relative to use " "the Relativize transform." ) obj_threshold.relative = False new_optimization_config = MultiObjectiveOptimizationConfig( objective=optimization_config.objective, outcome_constraints=constraints, objective_thresholds=obj_thresholds, ) else: new_optimization_config = OptimizationConfig( objective=optimization_config.objective, outcome_constraints=constraints, ) return new_optimization_config
[docs] def transform_observation_data( self, observation_data: List[ObservationData], observation_features: List[ObservationFeatures], ) -> List[ObservationData]: return [ self._get_relative_data( data=datum, status_quo_data=not_none( self.status_quo_by_trial.get(obs_features.trial_index, None), self.MISSING_STATUS_QUO_ERROR, ), ) for datum, obs_features in zip(observation_data, observation_features) ]
[docs] def untransform_observation_data( self, observation_data: List[ObservationData], observation_features: List[ObservationFeatures], ) -> List[ObservationData]: warnings.warn( "`Relativize.untransform_observation_data()` not yet implemented. " "Returning relative data." ) return observation_data
@staticmethod def _get_relative_data( data: ObservationData, status_quo_data: ObservationData ) -> ObservationData: L = len(data.metric_names) result = ObservationData( metric_names=data.metric_names, # zeros are just to create the shape so values can be set by index means=np.zeros(L), covariance=np.zeros((L, L)), ) for i, metric in enumerate(data.metric_names): try: j = next( k for k in range(L) if status_quo_data.metric_names[k] == metric ) except (IndexError, StopIteration): raise ValueError( "Relativization cannot be performed because " "ObservationData for status quo is missing metrics" ) means_t = data.means[i] sems_t = sqrt(data.covariance[i][i]) mean_c = status_quo_data.means[j] sem_c = sqrt(status_quo_data.covariance[j][j]) # if the is the status quo if means_t == mean_c and sems_t == sem_c: means_rel, sems_rel = 0, 0 else: means_rel, sems_rel = relativize( means_t=means_t, sems_t=sems_t, mean_c=mean_c, sem_c=sem_c, as_percent=True, ) result.means[i] = means_rel result.covariance[i][i] = sems_rel**2 return result @staticmethod def _get_status_quo_by_trial( observation_data: List[ObservationData], observation_features: List[ObservationFeatures], status_quo_feature: ObservationFeatures, ) -> Dict[int, ObservationData]: status_quo_signature = json.dumps(status_quo_feature.parameters, sort_keys=True) return { obs_f.trial_index: obs_data for obs_data, obs_f in zip(observation_data, observation_features) if json.dumps(obs_f.parameters, sort_keys=True) == status_quo_signature }