#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata, TGenMetadata
from ax.models.torch.botorch import BotorchModel
from ax.models.torch_base import TorchModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from torch import Tensor
class REMBO(BotorchModel):
    """Implements REMBO (Bayesian optimization in a linear subspace).

    The (D x d) projection matrix A must be provided, and must be that used for
    the initialization. In the original REMBO paper A ~ N(0, 1). Box bounds
    in the low-d space must also be provided, which in the REMBO paper should
    be [-sqrt(d), sqrt(d)]^d.

    Function evaluations happen in the high-D space, and so the arms on the
    experiment will also be tracked in the high-D space. This class maintains
    a list of points in the low-d space that have been launched, so we can
    match arms in high-D space back to their low-d point on update.

    Args:
        A: (D x d) projection matrix.
        initial_X_d: Points in low-d space for initial data.
        bounds_d: Box bounds in the low-d space.
        kwargs: kwargs for BotorchModel init
    """

    def __init__(
        self,
        A: Tensor,
        initial_X_d: Tensor,
        bounds_d: List[Tuple[float, float]],
        **kwargs: Any,
    ) -> None:
        self.A = A
        # Compute the pseudo-inverse once and cache it; predict() uses it to
        # map high-D inputs back into the low-d space.
        self._pinvA = torch.pinverse(A)
        # Low-d points launched so far: the initial points plus everything
        # produced by gen(). project_down() looks rows up in this list.
        self.X_d: List[Tensor] = list(initial_X_d)
        # Low-d points that were generated by this model.
        self.X_d_gen: List[Tensor] = []
        self.bounds_d = bounds_d
        # Number of outcomes; set in fit() from len(Xs).
        self.num_outputs = 0
        super().__init__(**kwargs)

    @copy_doc(TorchModel.fit)
    def fit(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        search_space_digest: SearchSpaceDigest,
        metric_names: List[str],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
    ) -> None:
        # No docstring here: documentation is supplied through copy_doc.
        assert len(search_space_digest.task_features) == 0
        assert len(search_space_digest.fidelity_features) == 0
        for b in search_space_digest.bounds:
            # REMBO assumes the input space is [-1, 1]^D
            assert b == (-1, 1)
        self.num_outputs = len(Xs)
        # For convenience for now, assume X for all outcomes the same
        X_D = _get_single_X(Xs)
        X_d = self.project_down(X_D)
        # Fit model in low-d space (adjusted to [0, 1]^d)
        super().fit(
            Xs=[self.to_01(X_d)] * self.num_outputs,
            Ys=Ys,
            Yvars=Yvars,
            search_space_digest=SearchSpaceDigest(
                feature_names=[f"x{i}" for i in range(self.A.shape[1])],
                bounds=[(0.0, 1.0)] * len(self.bounds_d),
                task_features=search_space_digest.task_features,
                fidelity_features=search_space_digest.fidelity_features,
            ),
            metric_names=metric_names,
            candidate_metadata=candidate_metadata,
        )

    def to_01(self, X_d: Tensor) -> Tensor:
        """Map points from bounds_d to [0, 1].

        The input is not modified; a clone is rescaled column by column.

        Args:
            X_d: Tensor in bounds_d

        Returns: Tensor in [0, 1].
        """
        X_d01 = X_d.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d01[:, i] = (X_d01[:, i] - lb) / (ub - lb)
        return X_d01

    def from_01(self, X_d01: Tensor) -> Tensor:
        """Map points from [0, 1] to bounds_d.

        The input is not modified; a clone is rescaled column by column.

        Args:
            X_d01: Tensor in [0, 1]

        Returns: Tensor in bounds_d.
        """
        X_d = X_d01.clone()
        for i, (lb, ub) in enumerate(self.bounds_d):
            X_d[:, i] = X_d[:, i] * (ub - lb) + lb
        return X_d

    def project_down(self, X_D: Tensor) -> Tensor:
        """Map points in the high-D space to the low-d space by looking them
        up in self.X_d.

        We assume that X_D = self.project_up(self.X_d), except possibly with
        rows shuffled. If a value in X_d cannot be found for each row in X_D,
        an error will be raised.

        This is quite fast relative to model fitting, so we do the matching in
        O(n^2) comparisons and don't worry about it. Each stored low-d point is
        projected up at most once per call.

        Args:
            X_D: Tensor in high-D space.

        Returns:
            X_d: Tensor in low-d space.

        Raises:
            ValueError: If some row of X_D matches no remaining point of
                self.X_d.
        """
        X_d = []
        # Indices of self.X_d not yet claimed by a row of X_D; each stored
        # point can be matched at most once.
        unmatched = list(range(len(self.X_d)))
        # High-D images of stored points, filled lazily so that each point is
        # projected up once rather than once per (row, candidate) comparison.
        projected: Dict[int, Tensor] = {}
        for x_D in X_D:
            idx_match = None
            for d_idx in unmatched:
                if d_idx not in projected:
                    projected[d_idx] = self.project_up(self.X_d[d_idx])
                if torch.allclose(x_D, projected[d_idx]):
                    idx_match = d_idx
                    break
            if idx_match is None:
                raise ValueError("Failed to project X down.")
            X_d.append(self.X_d[idx_match])
            unmatched.remove(idx_match)
        return torch.stack(X_d)

    def project_up(self, X: Tensor) -> Tensor:
        """Project to high-dimensional space.

        Applies the projection matrix A and clamps the result to [-1, 1].
        """
        Z = torch.t(self.A @ torch.t(X))
        Z = torch.clamp(Z, min=-1, max=1)
        return Z

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        # No docstring here: documentation is supplied through copy_doc.
        # Supports predictions in both low-d and high-D space, depending on
        # shape of X. For high-D, predictions are restricted to within the
        # linear embedding, so can project down with pseudoinverse.
        if X.shape[1] == self.A.shape[1]:
            # X is in low-d space
            X_d = X  # pragma: no cover
        else:
            # Project down to low-d space
            X_d = X @ torch.t(self._pinvA)
            # Project X_d back up to verify X was within linear embedding
            if not torch.allclose(X, X_d @ torch.t(self.A)):
                raise NotImplementedError(
                    "Predictions outside the linear embedding not supported."
                )
        return super().predict(X=self.to_01(X_d))

    @copy_doc(TorchModel.gen)
    def gen(
        self,
        n: int,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        pending_observations: Optional[List[Tensor]] = None,
        model_gen_options: Optional[TConfig] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Tuple[Tensor, Tensor, TGenMetadata, Optional[List[TCandidateMetadata]]]:
        # No docstring here: documentation is supplied through copy_doc.
        for b in bounds:
            assert b == (-1, 1)
        # The following can be easily handled in the future when needed
        assert linear_constraints is None
        assert fixed_features is None
        assert pending_observations is None
        # Do gen in the low-dimensional space and project up
        Xopt_01, w, _gen_metadata, _candidate_metadata = super().gen(
            n=n,
            bounds=[(0.0, 1.0)] * len(self.bounds_d),
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            model_gen_options=model_gen_options,
        )
        Xopt = self.from_01(Xopt_01)
        # Record the new low-d points so project_down() can match their
        # high-D images later.
        self.X_d.extend([x.clone() for x in Xopt])
        self.X_d_gen.extend([x.clone() for x in Xopt])
        # The parent's gen/candidate metadata is discarded.
        return self.project_up(Xopt), w, {}, None

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        bounds: List[Tuple[float, float]],
        objective_weights: Tensor,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        model_gen_options: Optional[TConfig] = None,
        target_fidelities: Optional[Dict[int, float]] = None,
    ) -> Optional[Tensor]:
        # No docstring here: documentation is supplied through copy_doc.
        for b in bounds:
            assert b == (-1, 1)
        assert linear_constraints is None
        assert fixed_features is None
        # NOTE(review): fit() trains on points mapped to [0, 1]^d and gen()
        # passes unit bounds, yet bounds_d is passed here while the result is
        # mapped back with from_01 -- confirm whether unit bounds are intended.
        x_best = super().best_point(
            bounds=self.bounds_d,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            model_gen_options=model_gen_options,
        )
        if x_best is not None:
            # from_01 and project_up index columns, so add and then drop a
            # leading batch dimension around the single point.
            x_best = self.project_up(self.from_01(x_best.unsqueeze(0))).squeeze(0)
        return x_best

    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        Xs_train: List[Tensor],
        Ys_train: List[Tensor],
        Yvars_train: List[Tensor],
        X_test: Tensor,
        **kwargs: Any,
    ) -> Tuple[Tensor, Tensor]:
        # No docstring here: documentation is supplied through copy_doc.
        # Train and test points are both matched back to low-d space and
        # rescaled to [0, 1]^d. Extra kwargs are not forwarded to the parent.
        X_D = _get_single_X(Xs_train)
        X_train_d = self.project_down(X_D)
        X_test_d = self.project_down(X_test)
        return super().cross_validate(
            Xs_train=[self.to_01(X_train_d)] * self.num_outputs,
            Ys_train=Ys_train,
            Yvars_train=Yvars_train,
            X_test=self.to_01(X_test_d),
        )

    @copy_doc(TorchModel.update)
    def update(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        **kwargs: Any,
    ) -> None:
        # No docstring here: documentation is supplied through copy_doc.
        # High-D inputs are matched back to low-d space and rescaled to
        # [0, 1]^d. Extra kwargs are not forwarded to the parent.
        X_D = _get_single_X(Xs)
        X_d = self.project_down(X_D)
        super().update(
            Xs=[self.to_01(X_d)] * self.num_outputs,
            Ys=Ys,
            Yvars=Yvars,
            candidate_metadata=candidate_metadata,
        )
def _get_single_X(Xs: List[Tensor]) -> Tensor:
    """Verify all X are identical, and return one.

    Args:
        Xs: A non-empty list of X tensors.

    Returns: Xs[0], after verifying they are all identical.

    Raises:
        AssertionError: If any tensor differs from Xs[0] in shape, or in
            values beyond the default torch.allclose tolerances.
        IndexError: If Xs is empty.
    """
    X = Xs[0]
    for other in Xs[1:]:
        # torch.allclose broadcasts its arguments, so tensors of different
        # shapes could otherwise compare as "identical" (or raise a
        # RuntimeError when not broadcastable). Compare shapes first.
        assert X.shape == other.shape and torch.allclose(X, other)
    return X