# Source code for ax.modelbridge.map_torch

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

import numpy as np
import numpy.typing as npt

import torch
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.observation import (
    Observation,
    ObservationData,
    ObservationFeatures,
    observations_from_map_data,
    separate_observations,
)
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TCandidateMetadata
from ax.modelbridge.base import GenResults
from ax.modelbridge.modelbridge_utils import (
    array_to_observation_data,
    observation_features_to_array,
    parse_observation_features,
)
from ax.modelbridge.torch import FIT_MODEL_ERROR, TorchModelBridge
from ax.modelbridge.transforms.base import Transform
from ax.models.torch_base import TorchModel
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
from pyre_extensions import none_throws


# Default target (i.e. final) value for each map_key, keyed by map_key name.
# Map values are assumed to be normalized to [0, 1], so the target is 1.0.
# Consulted both when generating candidates and when predicting.
DEFAULT_TARGET_MAP_VALUES = dict(steps=1.0)


class MapTorchModelBridge(TorchModelBridge):
    """A model bridge for using torch-based models that fit on MapData.

    Most of the `TorchModelBridge` functionality is retained, except that this
    class should be used in the case where `model` makes use of map_key values.
    For example, the use case of fitting a joint surrogate model on
    `(parameters, map_key)`, while candidate generation is only for
    `parameters`.
    """

    def __init__(
        self,
        experiment: Experiment,
        search_space: SearchSpace,
        data: Data,
        model: TorchModel,
        transforms: list[type[Transform]],
        transform_configs: dict[str, TConfig] | None = None,
        torch_dtype: torch.dtype | None = None,
        torch_device: torch.device | None = None,
        status_quo_name: str | None = None,
        status_quo_features: ObservationFeatures | None = None,
        optimization_config: OptimizationConfig | None = None,
        fit_out_of_design: bool = False,
        fit_on_init: bool = True,
        fit_abandoned: bool = False,
        default_model_gen_options: TConfig | None = None,
        map_data_limit_rows_per_metric: int | None = None,
        map_data_limit_rows_per_group: int | None = None,
    ) -> None:
        """Applies transforms and fits model.

        Args:
            experiment: Is used to get arm parameters. Is not mutated.
            search_space: Search space for fitting the model. Constraints need
                not be the same ones used in gen.
            data: Ax Data. Must be a `MapData` instance.
            model: Interface will be specified in subclass. If model requires
                initialization, that should be done prior to its use here.
            transforms: List of uninitialized transform classes. Forward
                transforms will be applied in this order, and untransforms in
                the reverse order.
            transform_configs: A dictionary from transform name to the
                transform config dictionary.
            torch_dtype: Torch data type.
            torch_device: Torch device.
            status_quo_name: Name of the status quo arm. Can only be used if
                Data has a single set of ObservationFeatures corresponding to
                that arm.
            status_quo_features: ObservationFeatures to use as status quo.
                Either this or status_quo_name should be specified, not both.
            optimization_config: Optimization config defining how to optimize
                the model.
            fit_out_of_design: If specified, all training data is returned.
                Otherwise, only in design points are returned.
            fit_on_init: Whether to fit the model on initialization. This can
                be used to skip model fitting when a fitted model is not needed.
                To fit the model afterwards, use `_process_and_transform_data`
                to get the transformed inputs and call `_fit_if_implemented`
                with the transformed inputs.
            fit_abandoned: Whether data for abandoned arms or trials should be
                included in model training data. If ``False``, only
                non-abandoned points are returned.
            default_model_gen_options: Options passed down to
                `model.gen(...)`. May contain `target_map_values`, a dict from
                map_key name to its target value.
            map_data_limit_rows_per_metric: Subsample the map data so that the
                total number of rows per metric is limited by this value.
            map_data_limit_rows_per_group: Subsample the map data so that the
                number of rows in the `map_key` column for each (arm, metric)
                is limited by this value.

        Raises:
            ValueError: If `data` is not a `MapData`, or if `experiment`
                contains any `BatchTrial`.
        """
        if not isinstance(data, MapData):
            raise ValueError(
                "`MapTorchModelBridge expects `MapData` instead of `Data`."
            )

        if any(isinstance(t, BatchTrial) for t in experiment.trials.values()):
            raise ValueError("MapTorchModelBridge does not support batch trials.")
        # Set before `super().__init__`: the overridden hooks below
        # (`_prepare_observations`, `_fit`, ...) read these attributes and are
        # presumably invoked from the parent constructor.
        # pyre-fixme[4]: Attribute must be annotated.
        self._map_key_features = data.map_keys
        self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric
        self._map_data_limit_rows_per_group = map_data_limit_rows_per_group

        super().__init__(
            experiment=experiment,
            search_space=search_space,
            data=data,
            model=model,
            transforms=transforms,
            transform_configs=transform_configs,
            torch_dtype=torch_dtype,
            torch_device=torch_device,
            status_quo_name=status_quo_name,
            status_quo_features=status_quo_features,
            optimization_config=optimization_config,
            fit_out_of_design=fit_out_of_design,
            # Must be forwarded; otherwise the documented option is ignored.
            fit_abandoned=fit_abandoned,
            fit_on_init=fit_on_init,
            default_model_gen_options=default_model_gen_options,
        )

    @property
    def statuses_to_fit_map_metric(self) -> set[TrialStatus]:
        """Trial statuses whose map-metric data is included when fitting.

        Currently identical to `statuses_to_fit`.
        """
        return self.statuses_to_fit

    @property
    def parameters_with_map_keys(self) -> list[str]:
        """The parameters used for fitting the model, including map_keys."""
        # NOTE: This list determines the order of feature columns in the training
        # data. Learning-curve-based modeling methods assume that the last
        # columns are map_keys, so we place self._map_key_features on the end.
        # TODO: Plumb down the `map_key` feature indices to the model, so that we
        # don't have to make the assumption in the above note.
        return self.parameters + self._map_key_features

    def _predict(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationData]:
        """Predict outcomes for observation features with or without map_keys.

        This method is updated from `TorchModelBridge._predict(...)` in that it
        accepts observation features with or without map_keys. Any map_key
        missing from an observation feature is filled in from
        `target_map_values` in the default model gen options, falling back to
        `DEFAULT_TARGET_MAP_VALUES`. NOTE: the missing map_keys are written
        into the passed `observation_features` in place.

        Raises:
            ValueError: If no model is set.
        """
        if not self.model:
            raise ValueError(FIT_MODEL_ERROR.format(action="_model_predict"))
        # The fitted model expects map_keys. If they do not exist, we use the
        # target values. An explicit `None` is treated as "not provided",
        # matching `_add_map_dim_to_target`.
        target_map_values = self._default_model_gen_options.get("target_map_values")
        if target_map_values is None:
            target_map_values = DEFAULT_TARGET_MAP_VALUES
        for p in self._map_key_features:
            for obsf in observation_features:
                if p not in obsf.parameters:
                    obsf.parameters[p] = target_map_values[p]  # pyre-ignore[16]
        # Convert observation features to array
        X = observation_features_to_array(
            self.parameters_with_map_keys, observation_features
        )
        f, cov = none_throws(self.model).predict(X=self._array_to_tensor(X))
        f = f.detach().cpu().clone().numpy()
        cov = cov.detach().cpu().clone().numpy()
        # Convert resulting arrays to observations
        return array_to_observation_data(f=f, cov=cov, outcomes=self.outcomes)

    def _fit(
        self,
        model: TorchModel,
        search_space: SearchSpace,
        observations: list[Observation],
        parameters: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Fit the model, using the parameters plus map_keys as features.

        The difference from `TorchModelBridge._fit(...)` is that we use
        `self.parameters_with_map_keys` instead of `self.parameters` when no
        explicit `parameters` are given.
        """
        self.parameters = list(search_space.parameters.keys())
        if parameters is None:
            parameters = self.parameters_with_map_keys
        super()._fit(
            model=model,
            search_space=search_space,
            observations=observations,
            parameters=parameters,
            **kwargs,
        )

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: dict[str, list[ObservationFeatures]],
        fixed_features: ObservationFeatures | None,
        model_gen_options: TConfig | None = None,
        optimization_config: OptimizationConfig | None = None,
    ) -> GenResults:
        """Generate candidates with the map_key targets made known to the model.

        An updated version of `TorchModelBridge._gen(...)` that first injects
        `map_dim_to_target` (e.g., `{d: 1.0}`, where `d` is the column index of
        the map_key in `parameters_with_map_keys`) into the acquisition
        function kwargs of `model_gen_options`, so that the target values of
        the map_keys are known during candidate generation. The caller's
        `model_gen_options` dict is not modified.
        """
        model_gen_options = self._add_map_dim_to_target(
            # Shallow-copy: `_add_map_dim_to_target` writes into the dict.
            options=dict(model_gen_options or {})
        )
        return super()._gen(
            n=n,
            search_space=search_space,
            pending_observations=pending_observations,
            fixed_features=fixed_features,
            model_gen_options=model_gen_options,
            optimization_config=optimization_config,
        )

    def _array_to_observation_features(
        self,
        X: npt.NDArray,
        candidate_metadata: list[TCandidateMetadata] | None,
    ) -> list[ObservationFeatures]:
        """Convert an array of points into observation features.

        The difference b/t this method and
        `TorchModelBridge._array_to_observation_features(...)` is that this one
        makes use of `self.parameters_with_map_keys` as the column names.
        """
        return parse_observation_features(
            X=X,
            param_names=self.parameters_with_map_keys,
            candidate_metadata=candidate_metadata,
        )

    def _prepare_observations(
        self, experiment: Experiment | None, data: Data | None
    ) -> list[Observation]:
        """Build observations from map data, treating map_keys as parameters.

        The difference b/t this method and
        `ModelBridge._prepare_observations(...)` is that this one uses
        `observations_from_map_data`. Returns an empty list if either
        `experiment` or `data` is None.
        """
        if experiment is None or data is None:
            return []
        return observations_from_map_data(
            experiment=experiment,
            map_data=data,  # pyre-ignore[6]: Checked in __init__.
            map_keys_as_parameters=True,
            limit_rows_per_metric=self._map_data_limit_rows_per_metric,
            limit_rows_per_group=self._map_data_limit_rows_per_group,
            statuses_to_include=self.statuses_to_fit,
            statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
        )

    def _compute_in_design(
        self, search_space: SearchSpace, observations: list[Observation]
    ) -> list[bool]:
        """Return, per observation, whether it lies in `search_space`.

        The difference b/t this method and `ModelBridge._compute_in_design(...)`
        is that this one excludes map_keys when checking membership in the
        search space (as map_keys are not explicitly in the search space).
        """
        return [
            search_space.check_membership(
                # Exclude map key features when checking
                {
                    p: v
                    for p, v in obs.features.parameters.items()
                    if p not in self._map_key_features
                }
            )
            for obs in observations
        ]

    def _cross_validate(
        self,
        search_space: SearchSpace,
        cv_training_data: list[Observation],
        cv_test_points: list[ObservationFeatures],
        parameters: list[str] | None = None,
        use_posterior_predictive: bool = False,
        **kwargs: Any,
    ) -> list[ObservationData]:
        """Make predictions at cv_test_points using only the data in
        cv_training_data.

        The difference from `TorchModelBridge._cross_validate` is that here we
        do cross validation on the parameters + map_keys. There is some extra
        logic to filter out out-of-design points in the map_key dimension.
        """
        if parameters is None:
            parameters = self.parameters_with_map_keys
        cv_test_data = super()._cross_validate(
            search_space=search_space,
            cv_training_data=cv_training_data,
            cv_test_points=cv_test_points,
            parameters=parameters,  # we pass the map_keys too by default
            use_posterior_predictive=use_posterior_predictive,
            **kwargs,
        )
        observation_features, observation_data = separate_observations(
            cv_training_data
        )
        # Since map_keys are used as features, there can be the possibility that
        # models for different outcomes were fit on different ranges of map_key
        # values; for example, this is the case if we (1) mix learning curve data
        # with standard data (taking default map value), or (2) are in a
        # situation where some learning curves start later than others. These
        # prediction results are "out-of-design" in the map_key dimension, so we
        # should filter them out.
        map_key_ranges = self._get_map_key_ranges(
            observation_features=observation_features,
            observation_data=observation_data,
        )
        cv_test_data = self._filter_outcomes_out_of_map_range(
            observation_features=cv_test_points,
            observation_data=cv_test_data,
            map_key_ranges=map_key_ranges,
        )
        return cv_test_data

    def _filter_outcomes_out_of_map_range(
        self,
        observation_features: list[ObservationFeatures],
        observation_data: list[ObservationData],
        # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
        map_key_ranges: dict[str, dict[str, tuple | None]],
    ) -> list[ObservationData]:
        """Drop outcomes whose map_key values fall outside their fitted range.

        Uses `map_key_ranges` to detect which `observation_features` have
        out-of-range map_keys and filters out the corresponding outcomes
        (name, mean, and covariance row/column) in `observation_data`. The
        input `observation_data` objects are not modified; new
        `ObservationData` objects are returned.
        """
        filtered_observation_data = []
        for obsf, obsd in zip(observation_features, observation_data):
            # Copy: the `.pop` below must not mutate `obsd.metric_names`
            # (`np.delete` already returns new arrays for means/covariance).
            metric_names = list(obsd.metric_names)
            means = obsd.means
            covariance = obsd.covariance
            for o in self.outcomes:
                if o not in metric_names:
                    continue
                for p in self._map_key_features:
                    map_key_value = obsf.parameters[p]
                    map_key_range = map_key_ranges[o][p]
                    if map_key_range is None:
                        continue
                    range_min, range_max = map_key_range
                    # pyre-fixme[58]: `<` is not supported for operand types
                    # `Union[None, bool, float, int, str]` and `Any`.
                    # pyre-fixme[58]: `>` is not supported for operand types
                    # `Union[None, bool, float, int, str]` and `Any`.
                    if map_key_value < range_min or map_key_value > range_max:
                        o_idx = metric_names.index(o)
                        metric_names.pop(o_idx)
                        means = np.delete(means, o_idx, axis=0)
                        covariance = np.delete(covariance, o_idx, axis=0)
                        covariance = np.delete(covariance, o_idx, axis=1)
                        # Outcome already removed; skip remaining map_keys.
                        break
            filtered_observation_data.append(
                ObservationData(
                    metric_names=metric_names, means=means, covariance=covariance
                )
            )
        return filtered_observation_data

    def _get_map_key_ranges(
        self,
        observation_features: list[ObservationFeatures],
        observation_data: list[ObservationData],
        # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    ) -> dict[str, dict[str, tuple | None]]:
        """Get ranges of map_key values in observation features, per outcome.

        Returns a dict of the form: {"outcome": {"map_key": (min_val, max_val)}},
        where the range is None if the outcome has no observations.
        """
        map_values: dict[str, dict[str, list[Any]]] = {
            o: {p: [] for p in self._map_key_features} for o in self.outcomes
        }
        for obsd, obsf in zip(observation_data, observation_features):
            for p in self._map_key_features:
                param_value = obsf.parameters[p]
                for o in obsd.metric_names:
                    map_values[o][p].append(param_value)

        return {
            o: {
                p: (min(values), max(values)) if values else None
                for p, values in map_values[o].items()
            }
            for o in self.outcomes
        }

    def _add_map_dim_to_target(self, options: TConfig) -> TConfig:
        """Insert `map_dim_to_target` into the acquisition function kwargs.

        Converts `target_map_values` (from the default model gen options,
        falling back to `DEFAULT_TARGET_MAP_VALUES` when absent or None) into
        `map_dim_to_target`: a dict from the column index of each map_key in
        `parameters_with_map_keys` to its target value, a form usable by the
        acquisition function. It is stored under `options[Keys.ACQF_KWARGS]`
        alongside any existing kwargs, which are copied into a new dict.
        `options` itself is modified in place and also returned.
        """
        target_map_values = self._default_model_gen_options.get("target_map_values")
        if target_map_values is None:
            target_map_values = DEFAULT_TARGET_MAP_VALUES
        param_and_map = self.parameters_with_map_keys
        map_dim_to_target = {
            param_and_map.index(p): target_map_values[p]  # pyre-ignore[16]
            for p in self._map_key_features
        }
        options[Keys.ACQF_KWARGS] = {  # pyre-ignore[32]
            **options.get(Keys.ACQF_KWARGS, {}),
            "map_dim_to_target": map_dim_to_target,
        }
        return options