#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from ax.core.types import TCandidateMetadata, TConfig, TGenMetadata
from ax.models.torch.botorch import BotorchModel
from ax.models.torch_base import TorchModel
from ax.utils.common.docutils import copy_doc
from torch import Tensor


class REMBO(BotorchModel):
    """Implements REMBO (Bayesian optimization in a linear subspace).

    The (D x d) projection matrix A must be provided, and must be the one
    used for the initialization.

    In the original REMBO paper, A ~ N(0, 1). Box bounds in the low-d space
    must also be provided, which in the REMBO paper should be
    [-sqrt(d), sqrt(d)]^d.

    Function evaluations happen in the high-D space, and so the arms on the
    experiment will also be tracked in the high-D space. This class maintains
    a list of points in the low-d space that have been launched, so we can
    match arms in high-D space back to their low-d point on update.

    Args:
        A: (D x d) projection matrix.
        initial_X_d: Points in low-d space for initial data.
        bounds_d: Box bounds in the low-d space.
        kwargs: kwargs for BotorchModel init.
    """

    def __init__(
        self,
        A: Tensor,
        initial_X_d: Tensor,
        bounds_d: List[Tuple[float, float]],
        **kwargs: Any,
    ) -> None:
        self.A = A
        # Compute the pseudoinverse once and cache it
        self._pinvA = torch.pinverse(A)
        # Projected points in low-d space generated in the optimization
        self.X_d = list(initial_X_d)
        # Projected points that were generated by this model
        self.X_d_gen: List[Tensor] = []
        self.bounds_d = bounds_d
        self.num_outputs = 0
        super().__init__(**kwargs)
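
    # Illustrative sketch (not part of the original module): constructing the
    # inputs as described in the docstring above, with A ~ N(0, 1) and box
    # bounds [-sqrt(d), sqrt(d)]^d. The dimensions D, d, and n_init are made
    # up for this example.
    #
    #     import math
    #     D, d, n_init = 100, 4, 5
    #     A = torch.randn(D, d)
    #     bounds_d = [(-math.sqrt(d), math.sqrt(d))] * d
    #     initial_X_d = torch.rand(n_init, d) * 2 * math.sqrt(d) - math.sqrt(d)
    #     model = REMBO(A=A, initial_X_d=initial_X_d, bounds_d=bounds_d)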

    @copy_doc(TorchModel.fit)
    def fit(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        bounds: List[Tuple[float, float]],
        task_features: List[int],
        feature_names: List[str],
        metric_names: List[str],
        fidelity_features: List[int],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
    ) -> None:
        assert len(task_features) == 0
        assert len(fidelity_features) == 0
        for b in bounds:
            # REMBO assumes the input space is [-1, 1]^D
            assert b == (-1, 1)
        self.num_outputs = len(Xs)
        # For convenience, assume for now that X is the same for all outcomes
        X_D = _get_single_X(Xs)
        X_d = self.project_down(X_D)
        # Fit model in low-d space (adjusted to [0, 1]^d)
        super().fit(
            Xs=[self.to_01(X_d)] * self.num_outputs,
            Ys=Ys,
            Yvars=Yvars,
            bounds=[(0.0, 1.0)] * len(self.bounds_d),
            task_features=task_features,
            feature_names=[f"x{i}" for i in range(self.A.shape[1])],
            metric_names=metric_names,
            fidelity_features=fidelity_features,
            candidate_metadata=candidate_metadata,
        )

    def to_01(self, X_d: Tensor) -> Tensor:
        """Map points from bounds_d to [0, 1].

        Args:
            X_d: Tensor in bounds_d

        Returns:
            Tensor in [0, 1].
        """
        X_d01 = X_d.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d01[:, i] = (X_d01[:, i] - lb) / (ub - lb)
        return X_d01

    def from_01(self, X_d01: Tensor) -> Tensor:
        """Map points from [0, 1] to bounds_d.

        Args:
            X_d01: Tensor in [0, 1]

        Returns:
            Tensor in bounds_d.
        """
        X_d = X_d01.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d[:, i] = X_d[:, i] * (ub - lb) + lb
        return X_d
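
    # Worked example (illustrative; `model` is the hypothetical instance from
    # the sketch above): with bounds_d = [(-2.0, 2.0)], to_01 maps 0.0 to
    # (0.0 - (-2.0)) / (2.0 - (-2.0)) = 0.5, and from_01 inverts it:
    #
    #     model.to_01(torch.tensor([[0.0]]))    # -> tensor([[0.5]])
    #     model.from_01(torch.tensor([[0.5]]))  # -> tensor([[0.0]])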

    def project_down(self, X_D: Tensor) -> Tensor:
        """Map points in the high-D space to the low-d space by looking them
        up in self.X_d.

        We assume that X_D = self.project_up(self.X_d), except possibly with
        rows shuffled. If a value in X_d cannot be found for each row in X_D,
        an error will be raised.

        This is quite fast relative to model fitting, so we do it in O(n^2)
        time and don't worry about it.

        Args:
            X_D: Tensor in high-D space.

        Returns:
            X_d: Tensor in low-d space.
        """
        X_d = []
        unmatched = list(range(len(self.X_d)))
        for x_D in X_D:
            idx_match = None
            for d_idx in unmatched:
                if torch.allclose(x_D, self.project_up(self.X_d[d_idx])):
                    idx_match = d_idx
                    break
            if idx_match is not None:
                X_d.append(self.X_d[idx_match])
                unmatched.remove(idx_match)
            else:
                raise ValueError("Failed to project X down.")
        return torch.stack(X_d)
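
    # Small example of the documented assumption (names reuse the sketch
    # above): every row of X_D comes from project_up applied to a stored
    # low-d point, so the lookup recovers the stored points exactly:
    #
    #     X_D = model.project_up(torch.stack(model.X_d))
    #     torch.equal(model.project_down(X_D), torch.stack(model.X_d))  # True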

    def project_up(self, X: Tensor) -> Tensor:
        """Project to high-dimensional space."""
        Z = torch.t(self.A @ torch.t(X))
        Z = torch.clamp(Z, min=-1, max=1)
        return Z
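
    # Worked example (illustrative): with A = [[2.0], [0.5]] (D=2, d=1),
    # project_up(torch.tensor([[1.0]])) computes A @ [[1.0]]^T = [[2.0], [0.5]],
    # transposes it to [[2.0, 0.5]], and clamps to [-1, 1], returning
    # tensor([[1.0, 0.5]]).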

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        # Supports predictions in both low-d and high-D space, depending on
        # the shape of X. For high-D, predictions are restricted to within
        # the linear embedding, so we can project down with the pseudoinverse.
        if X.shape[1] == self.A.shape[1]:
            # X is in low-d space
            X_d = X  # pragma: no cover
        else:
            # Project down to low-d space
            X_d = X @ torch.t(self._pinvA)
            # Project X_d back up to verify X was within the linear embedding
            if not torch.allclose(X, X_d @ torch.t(self.A)):
                raise NotImplementedError(
                    "Predictions outside the linear embedding not supported."
                )
        return super().predict(X=self.to_01(X_d))
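
    # Illustrative note (a sketch, assuming A has full column rank, which a
    # Gaussian A with D >= d has almost surely): a point X in the embedding
    # satisfies X = X_d @ A^T for some X_d, so multiplying by pinv(A)^T
    # recovers X_d, and the round trip X_d @ A^T reproduces X:
    #
    #     X_d = torch.ones(1, model.A.shape[1])
    #     X_D = X_d @ torch.t(model.A)
    #     torch.allclose(X_D @ torch.t(model._pinvA) @ torch.t(model.A), X_D)  # True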

    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        pending_observations: Optional[List[Tensor]] = None,
        model_gen_options: Optional[TConfig] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Tuple[Tensor, Tensor, TGenMetadata, Optional[List[TCandidateMetadata]]]:
        for b in bounds:
            assert b == (-1, 1)
        # The following can be easily handled in the future when needed
        assert linear_constraints is None
        assert fixed_features is None
        assert pending_observations is None
        # Do gen in the low-dimensional space and project up
        Xopt_01, w, _gen_metadata, _candidate_metadata = super().gen(
            n=n,
            bounds=[(0.0, 1.0)] * len(self.bounds_d),
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            model_gen_options=model_gen_options,
        )
        Xopt = self.from_01(Xopt_01)
        self.X_d.extend([x.clone() for x in Xopt])
        self.X_d_gen.extend([x.clone() for x in Xopt])
        return self.project_up(Xopt), w, {}, None

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        model_gen_options: Optional[TConfig] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Optional[Tensor]:
        for b in bounds:
            assert b == (-1, 1)
        assert linear_constraints is None
        assert fixed_features is None
        # As in fit and gen, the model operates in the low-d space rescaled
        # to [0, 1]^d; the result is mapped back via from_01 below.
        x_best = super().best_point(
            bounds=[(0.0, 1.0)] * len(self.bounds_d),
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            model_gen_options=model_gen_options,
        )
        if x_best is not None:
            x_best = self.project_up(self.from_01(x_best.unsqueeze(0))).squeeze(0)
        return x_best

    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        Xs_train: List[Tensor],
        Ys_train: List[Tensor],
        Yvars_train: List[Tensor],
        X_test: Tensor,
        **kwargs: Any,
    ) -> Tuple[Tensor, Tensor]:
        X_D = _get_single_X(Xs_train)
        X_train_d = self.project_down(X_D)
        X_test_d = self.project_down(X_test)
        return super().cross_validate(
            Xs_train=[self.to_01(X_train_d)] * self.num_outputs,
            Ys_train=Ys_train,
            Yvars_train=Yvars_train,
            X_test=self.to_01(X_test_d),
        )

    @copy_doc(TorchModel.update)
    def update(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        **kwargs: Any,
    ) -> None:
        X_D = _get_single_X(Xs)
        X_d = self.project_down(X_D)
        super().update(
            Xs=[self.to_01(X_d)] * self.num_outputs,
            Ys=Ys,
            Yvars=Yvars,
            candidate_metadata=candidate_metadata,
        )


def _get_single_X(Xs: List[Tensor]) -> Tensor:
    """Verify that all X in Xs are identical, and return the first one.

    Args:
        Xs: A list of X tensors

    Returns:
        Xs[0], after verifying that all entries of Xs are identical.
    """
    X = Xs[0]
    for i in range(1, len(Xs)):
        assert torch.allclose(X, Xs[i])
    return X
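

# Minimal smoke-test sketch (added for illustration; not part of the Ax API).
# It exercises the projection utilities without fitting a GP; all dimensions
# and values are made up.
if __name__ == "__main__":
    import math

    D, d = 20, 2
    A = torch.randn(D, d)
    bounds_d = [(-math.sqrt(d), math.sqrt(d))] * d
    initial_X_d = torch.rand(3, d) * 2 * math.sqrt(d) - math.sqrt(d)
    model = REMBO(A=A, initial_X_d=initial_X_d, bounds_d=bounds_d)
    X_D = model.project_up(torch.stack(model.X_d))  # in [-1, 1]^D after clamping
    assert model.project_down(X_D).shape == (3, d)
    X = torch.stack(model.X_d)
    assert torch.allclose(model.from_01(model.to_01(X)), X)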