Source code for ax.models.torch.rembo

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from typing import Any, List, Optional, Tuple

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.models.torch.botorch import BotorchModel
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.docutils import copy_doc
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


class REMBO(BotorchModel):
    """Implements REMBO (Bayesian optimization in a linear subspace).

    The (D x d) projection matrix A must be provided, and must be that used
    for the initialization. In the original REMBO paper A ~ N(0, 1). Box
    bounds in the low-d space must also be provided, which in the REMBO paper
    should be [-sqrt(d), sqrt(d)]^d.

    Function evaluations happen in the high-D space, and so the arms on the
    experiment will also be tracked in the high-D space. This class maintains
    a list of points in the low-d space that have been launched, so we can
    match arms in high-D space back to their low-d point on update.

    Internally, the underlying BotorchModel is fit and optimized in the
    rescaled low-d space [0, 1]^d (see ``to_01`` / ``from_01``).

    Args:
        A: (D x d) projection matrix.
        initial_X_d: Points in low-d space for initial data.
        bounds_d: Box bounds in the low-d space.
        kwargs: kwargs for BotorchModel init
    """

    def __init__(
        self,
        A: Tensor,
        initial_X_d: Tensor,
        bounds_d: List[Tuple[float, float]],
        **kwargs: Any,
    ) -> None:
        self.A = A
        # Compute the pseudo-inverse once and cache it (used by ``predict``).
        self._pinvA = torch.pinverse(A)
        # Projected points in low-d space generated in the optimization:
        # the rows of ``initial_X_d`` plus every point produced by ``gen``.
        self.X_d = list(initial_X_d)
        # Projected points that were generated by this model.
        self.X_d_gen: List[Tensor] = []
        self.bounds_d = bounds_d
        self.num_outputs = 0
        super().__init__(**kwargs)

    # NOTE: methods decorated with ``copy_doc`` take their docstring from the
    # referenced TorchModel method, so they are documented with comments.
    @copy_doc(TorchModel.fit)
    def fit(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
    ) -> None:
        assert len(search_space_digest.task_features) == 0
        assert len(search_space_digest.fidelity_features) == 0
        for b in search_space_digest.bounds:
            # REMBO assumes the input space is [-1, 1]^D
            assert b == (-1, 1)
        self.num_outputs = len(datasets)
        # For convenience for now, assume X for all outcomes the same
        low_d_datasets = self._convert_and_normalize_datasets(datasets=datasets)
        # Fit the parent model in the rescaled low-d space [0, 1]^d.
        super().fit(
            datasets=low_d_datasets,
            metric_names=metric_names,
            search_space_digest=SearchSpaceDigest(
                feature_names=[f"x{i}" for i in range(self.A.shape[1])],
                bounds=[(0.0, 1.0)] * len(self.bounds_d),
                task_features=search_space_digest.task_features,
                fidelity_features=search_space_digest.fidelity_features,
            ),
            candidate_metadata=candidate_metadata,
        )

    def to_01(self, X_d: Tensor) -> Tensor:
        """Map points from bounds_d to [0, 1].

        Args:
            X_d: Tensor in bounds_d

        Returns:
            Tensor in [0, 1].
        """
        X_d01 = X_d.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d01[:, i] = (X_d01[:, i] - lb) / (ub - lb)
        return X_d01

    def from_01(self, X_d01: Tensor) -> Tensor:
        """Map points from [0, 1] to bounds_d.

        Args:
            X_d01: Tensor in [0, 1]

        Returns:
            Tensor in bounds_d.
        """
        X_d = X_d01.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d[:, i] = X_d[:, i] * (ub - lb) + lb
        return X_d

    def project_down(self, X_D: Tensor) -> Tensor:
        """Map points in the high-D space to the low-d space by looking them
        up in self.X_d.

        We assume that X_D = self.project_up(self.X_d), except possibly with
        rows shuffled. If a value in X_d cannot be found for each row in X_D,
        an error will be raised.

        This is quite fast relative to model fitting, so we do it in O(n^2)
        time and don't worry about it. Each tracked low-d point is projected
        up only once, outside the matching loop.

        Args:
            X_D: Tensor in high-D space.

        Returns:
            X_d: Tensor in low-d space.

        Raises:
            ValueError: If some row of X_D matches no (still unmatched) point
                in self.X_d.
        """
        # The projections do not depend on the row being matched, so compute
        # them once instead of once per (row, candidate) pair.
        projected = [self.project_up(x_d) for x_d in self.X_d]
        X_d = []
        # Each tracked point may be matched at most once.
        unmatched = list(range(len(self.X_d)))
        for x_D in X_D:
            idx_match = next(
                (j for j in unmatched if torch.allclose(x_D, projected[j])),
                None,
            )
            if idx_match is None:
                raise ValueError("Failed to project X down.")
            X_d.append(self.X_d[idx_match])
            unmatched.remove(idx_match)
        return torch.stack(X_d)

    def project_up(self, X: Tensor) -> Tensor:
        """Project to high-dimensional space.

        Computes ``A @ x`` for every row ``x`` of X (or for X itself when it
        is a single 1-d point) and clamps the result to [-1, 1].
        """
        Z = torch.t(self.A @ torch.t(X))
        Z = torch.clamp(Z, min=-1, max=1)
        return Z

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        # Supports predictions in both low-d and high-D space, depending on
        # the shape of X. For high-D, predictions are restricted to within the
        # linear embedding, so we can project down with the pseudoinverse.
        if X.shape[1] == self.A.shape[1]:
            # X is in low-d space
            X_d = X  # pragma: no cover
        else:
            # Project down to low-d space
            X_d = X @ torch.t(self._pinvA)
            # Project X_d back up to verify X was within linear embedding
            if not torch.allclose(X, X_d @ torch.t(self.A)):
                raise NotImplementedError(
                    "Predictions outside the linear embedding not supported."
                )
        return super().predict(X=self.to_01(X_d))

    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        for b in search_space_digest.bounds:
            assert b == (-1, 1)
        # The following can be easily handled in the future when needed
        assert torch_opt_config.linear_constraints is None
        assert torch_opt_config.fixed_features is None
        assert torch_opt_config.pending_observations is None
        # Do gen in the low-dimensional space and project up
        gen_results = super().gen(
            n=n,
            search_space_digest=dataclasses.replace(
                search_space_digest,
                bounds=[(0.0, 1.0)] * len(self.bounds_d),
            ),
            torch_opt_config=torch_opt_config,
        )
        Xopt = self.from_01(gen_results.points)
        # Remember the generated low-d points so project_down can later match
        # the corresponding high-D arms back to them.
        self.X_d.extend([x.clone() for x in Xopt])
        self.X_d_gen.extend([x.clone() for x in Xopt])
        return TorchGenResults(
            points=self.project_up(Xopt),
            weights=gen_results.weights,
        )

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> Optional[Tensor]:
        for b in search_space_digest.bounds:
            assert b == (-1, 1)
        assert torch_opt_config.linear_constraints is None
        assert torch_opt_config.fixed_features is None
        # The parent model is fit in the rescaled space [0, 1]^d (see fit and
        # gen), so it must be given those bounds rather than the raw bounds_d.
        # Its answer is mapped back to bounds_d with from_01 below.
        x_best = super().best_point(
            search_space_digest=dataclasses.replace(
                search_space_digest,
                bounds=[(0.0, 1.0)] * len(self.bounds_d),
            ),
            torch_opt_config=torch_opt_config,
        )
        if x_best is not None:
            x_best = self.project_up(self.from_01(x_best.unsqueeze(0))).squeeze(0)
        return x_best

    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        datasets: List[SupervisedDataset],
        X_test: Tensor,
        **kwargs: Any,
    ) -> Tuple[Tensor, Tensor]:
        # NOTE(review): extra ``kwargs`` are accepted but not forwarded to the
        # parent; confirm this is intended.
        low_d_datasets = self._convert_and_normalize_datasets(datasets=datasets)
        X_test_d = self.project_down(X_test)
        return super().cross_validate(
            datasets=low_d_datasets,
            X_test=self.to_01(X_test_d),
        )

    @copy_doc(TorchModel.update)
    def update(
        self,
        datasets: List[SupervisedDataset],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): extra ``kwargs`` are accepted but not forwarded to the
        # parent; confirm this is intended.
        low_d_datasets = self._convert_and_normalize_datasets(datasets=datasets)
        super().update(
            datasets=low_d_datasets,
            candidate_metadata=candidate_metadata,
        )

    def _convert_and_normalize_datasets(
        self, datasets: List[SupervisedDataset]
    ) -> List[SupervisedDataset]:
        """Replace each dataset's high-D X with its low-d [0, 1]^d version.

        All datasets must share the same X (checked by ``_get_single_X``).
        That X is matched back to the tracked low-d points via
        ``project_down`` and rescaled with ``to_01``.

        Args:
            datasets: Datasets whose X lives in the high-D space.

        Returns:
            Copies of ``datasets`` with X replaced by the rescaled low-d X.
        """
        X_D = _get_single_X([dataset.X() for dataset in datasets])
        X_d_01 = self.to_01(self.project_down(X_D))
        # Fit model in low-d space (adjusted to [0, 1]^d)
        return [dataclasses.replace(dataset, X=X_d_01) for dataset in datasets]
def _get_single_X(Xs: List[Tensor]) -> Tensor: """Verify all X are identical, and return one. Args: Xs: A list of X tensors Returns: Xs[0], after verifying they are all identical. """ X = Xs[0] for i in range(1, len(Xs)): assert torch.allclose(X, Xs[i]) return X