Source code for ax.modelbridge.pairwise

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np
import torch
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.types import TCandidateMetadata
from ax.modelbridge.modelbridge_utils import detect_duplicates
from ax.modelbridge.torch import TorchModelBridge
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from torch import Tensor


[docs]class PairwiseModelBridge(TorchModelBridge): def _convert_observations( self, observation_data: List[ObservationData], observation_features: List[ObservationFeatures], outcomes: List[str], parameters: List[str], ) -> Tuple[ List[Optional[SupervisedDataset]], Optional[List[List[TCandidateMetadata]]] ]: """Converts observations to a dictionary of `Dataset` containers and (optional) candidate metadata. """ if len(observation_features) != len(observation_data): raise ValueError("Observation features and data must have the same length!") ordered_idx = np.argsort([od.trial_index for od in observation_features]) observation_features = [observation_features[i] for i in ordered_idx] observation_data = [observation_data[i] for i in ordered_idx] ( Xs, Ys, Yvars, candidate_metadata_dict, any_candidate_metadata_is_not_none, ) = self._extract_observation_data( observation_data, observation_features, parameters ) datasets: List[Optional[SupervisedDataset]] = [] candidate_metadata = [] for outcome in outcomes: X = torch.stack(Xs[outcome], dim=0) Y = torch.tensor(Ys[outcome], dtype=torch.long).unsqueeze(-1) # Update Xs and Ys shapes for PairwiseGP Y = _binary_pref_to_comp_pair(Y=Y) X, Y = _consolidate_comparisons(X=X, Y=Y) datapoints, comparisons = X, Y.long() event_shape = torch.Size([2 * datapoints.shape[-1]]) # pyre-fixme[6]: For 2nd param expected `LongTensor` but got `Tensor`. dataset_X = SliceContainer(datapoints, comparisons, event_shape=event_shape) dataset_Y = torch.tensor([[0, 1]]).expand(comparisons.shape) dataset = RankingDataset(X=dataset_X, Y=dataset_Y) datasets.append(dataset) candidate_metadata.append(candidate_metadata_dict[outcome]) if not any_candidate_metadata_is_not_none: return datasets, None return datasets, candidate_metadata
def _binary_pref_to_comp_pair(Y: Tensor) -> Tensor: """Convert Y from binary indicator pair to index pair comparisons Convert Y from binary indicator pair such as [[0, 1], [1, 0], ...] to index comparisons like [[1, 0], [2, 3], ...] """ Y_shape = Y.shape[:-2] + (-1, 2) Y = Y.reshape(Y_shape) _validate_Y_values(Y) idx_shift = (torch.arange(0, Y.shape[-2]) * 2).unsqueeze(-1).expand_as(Y) comparison_pairs = idx_shift + (1 - Y) return comparison_pairs def _consolidate_comparisons(X: Tensor, Y: Tensor) -> Tuple[Tensor, Tensor]: """Drop duplicated Xs and update the indices in Ys accordingly""" if len(X.shape) != 2: raise ValueError("X must have 2 dimensions.") if len(Y.shape) != 2: raise ValueError("Y must have 2 dimensions.") if Y.shape[-1] != 2: raise ValueError( "The last dimension of Y must contain 2 elements " "representing the pairwise comparison." ) n = X.shape[-2] dupplicates = list(detect_duplicates(X=X)) if len(dupplicates) != 0: dup_indices, kept_indices = zip(*dupplicates) unique_indices = set(range(n)) - set(dup_indices) # After dropping the duplicates, # the kept ones' indices may also change by being shifted up new_idx_map = dict(zip(unique_indices, range(len(unique_indices)))) new_indices_for_dup = (new_idx_map[idx] for idx in kept_indices) new_idx_map.update(dict(zip(dup_indices, new_indices_for_dup))) consolidated_X = X[list(unique_indices), :] consolidated_Y = torch.tensor( [(new_idx_map[y1.item()], new_idx_map[y2.item()]) for y1, y2 in Y], dtype=torch.long, ) return consolidated_X, consolidated_Y else: return X, Y def _validate_Y_values(Y: Tensor) -> None: """Check if Ys have valid values""" # Y must have even number of elements if Y.shape[-1] != 2: raise ValueError( f"Trailing dimension of `Y` should be size 2 but is {Y.shape[-1]}" ) # all adjacent pairs must have exactly a 0 and a 1 if not (Y.min(dim=-1).values.eq(0).all() and Y.max(dim=-1).values.eq(1).all()): raise ValueError("`Y` values must be `{0, 1}.`")