#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import dataclasses
import re
from collections import OrderedDict
from logging import Logger
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union
import gpytorch
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.models.random.alebo_initializer import ALEBOInitializer
from ax.models.torch.botorch import BotorchModel
from ax.models.torch.botorch_defaults import get_NEI
from ax.models.torch.utils import _datasets_to_legacy_inputs
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.fit import fit_gpytorch_scipy
from botorch.optim.initializers import initialize_q_batch_nonneg
from botorch.optim.numpy_converter import module_to_array
from botorch.optim.optimize import optimize_acqf
from botorch.optim.utils import _scipy_objective_and_grad
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.utils.datasets import SupervisedDataset
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.rbf_kernel import postprocess_rbf
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from scipy.optimize import approx_fprime
from torch import Tensor
# Module-level logger, created through the `get_logger` helper and keyed by
# this module's name; used below for debug/info messages during fitting and
# candidate generation.
logger: Logger = get_logger(__name__)
class ALEBOKernel(Kernel):
    """The kernel for ALEBO.

    Suppose there exists an ARD RBF GP on an (unknown) linear embedding with
    projection matrix A. We make function evaluations in a different linear
    embedding with projection matrix B (known). This is the appropriate kernel
    for fitting those data.

    This kernel computes a Mahalanobis distance, and the (d x d) PD distance
    matrix Gamma is a parameter that must be fit. This is done by fitting its
    upper Cholesky decomposition, U.

    Args:
        B: (d x D) Projection matrix.
        batch_shape: Batch shape as usual for gpytorch kernels.
    """

    def __init__(self, B: Tensor, batch_shape: torch.Size) -> None:
        super().__init__(
            has_lengthscale=False, ard_num_dims=None, eps=0.0, batch_shape=batch_shape
        )
        # pyre-fixme[4]: Attribute must be annotated.
        self.d, D = B.shape
        assert self.d < D
        self.B = B
        # Initialize U from a random orthonormal matrix: take the first d rows
        # of the Q factor of a random (D x D) matrix and map them through the
        # pseudo-inverse of B.
        Arnd = torch.randn(D, D, dtype=B.dtype, device=B.device)
        Arnd = torch.linalg.qr(Arnd)[0]
        ABinv = Arnd[: self.d, :] @ torch.pinverse(B)
        # U is the upper Cholesky decomposition of Gamma, the Mahalanobis
        # matrix. Uvec is the upper triangular portion of U squeezed out into
        # a vector.
        U = torch.linalg.cholesky(torch.mm(ABinv.t(), ABinv)).t()
        # pyre-fixme[4]: Attribute must be annotated.
        self.triu_indx = torch.triu_indices(self.d, self.d, device=B.device)
        # Index with explicit (row, col) index tensors. Indexing with a list
        # of lists relies on the deprecated "non-tuple sequence" behaviour and
        # would be interpreted as a single tensor index in newer PyTorch.
        row_idx, col_idx = self.triu_indx[0], self.triu_indx[1]
        Uvec = U[row_idx, col_idx].repeat(*batch_shape, 1)
        self.register_parameter(name="Uvec", parameter=torch.nn.Parameter(Uvec))

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params: Any,
    ) -> Tensor:
        """Compute kernel distance.

        Both inputs are mapped through the transpose of U (rebuilt from
        ``Uvec``), and the squared distance between the mapped points is
        passed through ``postprocess_rbf`` via ``covar_dist``.

        Args:
            x1: First set of inputs.
            x2: Second set of inputs.
            diag: Forwarded to ``covar_dist``.
            last_dim_is_batch: Accepted for interface compatibility; not used
                by this implementation.
            **params: Extra keyword arguments forwarded to ``covar_dist``.

        Returns: The kernel values produced by ``covar_dist``.
        """
        # Unpack Uvec into U^T: the entries are written at (col, row), i.e.
        # the transposed positions of the upper-triangular indices.
        shapeU = self.Uvec.shape[:-1] + torch.Size([self.d, self.d])
        # Allocate with the parameter's own dtype/device: `self.B` is a plain
        # attribute and does not follow the module through `.to(...)`, whereas
        # `Uvec` does.
        U_t = torch.zeros(shapeU, dtype=self.Uvec.dtype, device=self.Uvec.device)
        U_t[..., self.triu_indx[1], self.triu_indx[0]] = self.Uvec
        # Compute kernel distance
        z1 = torch.matmul(x1, U_t)
        z2 = torch.matmul(x2, U_t)
        return self.covar_dist(
            z1,
            z2,
            square_dist=True,
            diag=diag,
            dist_postprocess_func=postprocess_rbf,
            postprocess=True,
            **params,
        )
class ALEBOGP(FixedNoiseGP):
    """The GP for ALEBO.

    Uses the Mahalanobis kernel defined in ALEBOKernel, along with a
    ScaleKernel to add a kernel variance and a fitted constant mean.

    In non-batch mode, there is a single kernel that produces MVN predictions
    as usual for a GP.

    With b batches, each batch has its own set of kernel hyperparameters and
    each batch represents a sample from the hyperparameter posterior
    distribution. When making a prediction (with `__call__`), these samples are
    integrated over using moment matching. So, the predictions are an MVN as
    usual with the same shape as in non-batch mode.

    Args:
        B: (d x D) Projection matrix.
        train_X: (n x d) X training data.
        train_Y: (n x 1) Y training data.
        train_Yvar: (n x 1) Noise variances of each training Y.
    """

    def __init__(
        self, B: Tensor, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor
    ) -> None:
        super().__init__(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
        # The same batch shape is used for the scale kernel and the base kernel.
        kernel_batch_shape = self._aug_batch_shape
        self.covar_module = ScaleKernel(
            base_kernel=ALEBOKernel(B=B, batch_shape=kernel_batch_shape),
            batch_shape=kernel_batch_shape,
        )
        # Match the dtype/device of the training inputs.
        self.to(train_X)

    def __call__(self, x: Tensor) -> MultivariateNormal:
        """
        If model is non-batch, then just make a prediction. If model has
        multiple batches, then these are samples from the kernel hyperparameter
        posterior and we integrate over them with moment matching.

        The shape of the MVN that this outputs will be the same regardless of
        whether the model is batched or not.

        Args:
            x: Point to be predicted.

        Returns: MultivariateNormal distribution of prediction.

        Raises:
            ValueError: If the model is batched and ``x`` has more than 3
                dimensions.
        """
        if len(self._aug_batch_shape) == 0:
            return super().__call__(x)
        # Else, approximately integrate over batches with moment matching.
        # Take X as (b) x q x d, and expand to (b) x ns x q x d
        if x.ndim > 3:
            raise ValueError("Don't know how to predict this shape")  # pragma: no cover
        # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int,
        #  ...]` and `Size`.
        expanded_shape = (
            x.shape[:-2]
            + torch.Size([self._aug_batch_shape[0]])  # pyre-ignore
            + x.shape[-2:]
        )
        per_sample_mvn = super().__call__(x.unsqueeze(-3).expand(expanded_shape))
        sample_means = per_sample_mvn.mean
        mu = sample_means.mean(dim=-2)
        # Law of Total Covariance: mean of the per-sample covariances plus the
        # covariance of the per-sample means.
        mean_of_covariances = per_sample_mvn.covariance_matrix.mean(dim=-3)
        second_moment_of_means = (
            torch.matmul(sample_means.transpose(-2, -1), sample_means)
            / sample_means.shape[-2]
        )
        outer_of_mu = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2))
        C = mean_of_covariances + second_moment_of_means - outer_of_mu
        return MultivariateNormal(mu, C)

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        """Posterior at ``X``, built from this model's ``__call__``.

        Args:
            X: Points at which to evaluate the posterior.
            output_indices: Must be None (asserted).
            observation_noise: Must be falsy (asserted).
            posterior_transform: Optional transform applied to the posterior
                before it is returned.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns: The (optionally transformed) posterior.
        """
        assert output_indices is None
        assert not observation_noise
        posterior = GPyTorchPosterior(mvn=self(X))
        if posterior_transform is None:
            return posterior
        return posterior_transform(posterior)
def get_fitted_model(
    B: Tensor,
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor,
    restarts: int,
    nsamp: int,
    init_state_dict: Optional[Dict[str, Tensor]],
) -> ALEBOGP:
    """Get a fitted ALEBO GP.

    Three stages: random-restart optimization for a MAP model, a Laplace
    approximation to draw posterior samples of kernel hyperparameters, and
    construction of a batch-mode model in which each batch holds one of those
    sampled sets of kernel hyperparameters.

    Args:
        B: Projection matrix.
        train_X: X training data.
        train_Y: Y training data.
        train_Yvar: Noise variances of each training Y.
        restarts: Number of restarts for MAP estimation.
        nsamp: Number of samples to draw from kernel hyperparameter posterior.
        init_state_dict: Optionally begin MAP estimation with this state dict.

    Returns: Batch-mode (nsamp batches) fitted ALEBO GP.
    """
    # Stage 1: MAP estimate.
    map_mll = get_map_model(
        B=B,
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        restarts=restarts,
        init_state_dict=init_state_dict,
    )
    # Stage 2: Laplace approximation of the hyperparameter posterior.
    sampled_Uvec, sampled_mean_constant, sampled_output_scale = laplace_sample_U(
        mll=map_mll, nsamp=nsamp
    )
    # Stage 3: batch model holding the samples.
    return get_batch_model(
        B=B,
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        Uvec_batch=sampled_Uvec,
        mean_constant_batch=sampled_mean_constant,
        output_scale_batch=sampled_output_scale,
    )
def get_map_model(
    B: Tensor,
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor,
    restarts: int,
    init_state_dict: Optional[Dict[str, Tensor]],
) -> ExactMarginalLogLikelihood:
    """Do random-restart optimization for MAP fitting of an ALEBO GP model.

    Each restart builds a fresh ALEBOGP (optionally initialized from
    ``init_state_dict``), fits it with ``fit_gpytorch_scipy`` using the "tnc"
    method, and the state dict with the lowest reported ``fopt`` is kept.

    Args:
        B: Projection matrix.
        train_X: X training data.
        train_Y: Y training data.
        train_Yvar: Noise variances of each training Y.
        restarts: Number of restarts for MAP estimation.
        init_state_dict: Optionally begin MAP estimation with this state dict.

    Returns: Marginal log likelihood object wrapping a non-batch ALEBO GP that
        has been loaded with the best (MAP) hyperparameters found.

    Raises:
        RuntimeError: If no restart produced a usable objective value (for
            example ``restarts`` is 0, or every fit returned inf/NaN).
    """
    # Start from +inf rather than a magic finite sentinel so that any finite
    # objective value, however large, is accepted as an improvement.
    f_best = float("inf")
    sd_best = None
    # Fit with random restarts
    for _ in range(restarts):
        m = ALEBOGP(B=B, train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
        if init_state_dict is not None:
            m.load_state_dict(init_state_dict)
        mll = ExactMarginalLogLikelihood(m.likelihood, m)
        mll.train()
        mll, info_dict = fit_gpytorch_scipy(mll, track_iterations=False, method="tnc")
        logger.debug(info_dict)
        fopt = float(info_dict["fopt"])  # pyre-ignore
        if fopt < f_best:
            f_best = fopt
            sd_best = m.state_dict()
    if sd_best is None:
        # Fail with an explicit message instead of letting `load_state_dict`
        # choke on an empty state dict further down.
        raise RuntimeError(
            "MAP fitting of the ALEBO GP failed: none of the "
            f"{restarts} restart(s) produced a finite objective value."
        )
    # Set the final value
    m = ALEBOGP(B=B, train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
    m.load_state_dict(sd_best)
    return ExactMarginalLogLikelihood(m.likelihood, m)
def laplace_sample_U(
    mll: ExactMarginalLogLikelihood, nsamp: int
) -> Tuple[Tensor, Tensor, Tensor]:
    """Draw posterior samples of kernel hyperparameters using Laplace
    approximation.

    Only the Mahalanobis distance matrix is sampled.

    The diagonal of the Hessian is estimated using finite differences of the
    autograd gradients. The Laplace approximation is then N(p_map, inv(-H)).
    We construct a set of nsamp kernel hyperparameters by drawing nsamp-1
    values from this distribution, and prepending as the first sample the MAP
    parameters.

    Args:
        mll: MLL object of MAP ALEBO GP.
        nsamp: Number of samples to return.

    Returns: Batch tensors of the kernel hyperparameters Uvec, mean constant,
        and output scale.
    """
    # Estimate diagonal of the Hessian
    mll.train()
    x0, property_dict, _ = module_to_array(module=mll)
    x0 = x0.astype(np.float64)  # This is the MAP parameters
    H = np.zeros((len(x0), len(x0)))
    # Per-parameter finite-difference step: absolute floor plus relative part.
    epsilon = 1e-4 + 1e-3 * np.abs(x0)
    for i, _ in enumerate(x0):
        # Compute gradient of df/dx_i wrt x_i
        # pyre-fixme[53]: Captured variable `property_dict` is not annotated.
        # pyre-fixme[53]: Captured variable `x0` is not annotated.
        # pyre-fixme[53]: Captured variable `i` is not annotated.
        # pyre-fixme[3]: Return type must be annotated.
        # pyre-fixme[2]: Parameter must be annotated.
        def f(x):
            x_all = x0.copy()
            x_all[i] = x[0]
            return -_scipy_objective_and_grad(x_all, mll, property_dict)[1][i]

        d2f = approx_fprime(np.array([x0[i]]), f, epsilon=epsilon[i])  # pyre-ignore
        # `approx_fprime` is given a length-1 point, so pull the single value
        # out explicitly; assigning a size-1 array to a scalar slot relies on
        # an implicit array-to-scalar conversion that NumPy has deprecated.
        H[i, i] = np.ravel(d2f)[0]

    # Sample only Uvec; leave mean and output scale fixed.
    assert list(property_dict.keys()) == [
        "model.mean_module.raw_constant",
        "model.covar_module.raw_outputscale",
        "model.covar_module.base_kernel.Uvec",
    ]
    # Drop the first two entries (mean constant and output scale, per the
    # ordering asserted above), keeping only the Uvec part.
    H = H[2:, 2:]
    H += np.diag(-1e-3 * np.ones(H.shape[0]))  # Add a nugget for inverse stability
    Sigma = np.linalg.inv(-H)
    samples = np.random.multivariate_normal(mean=x0[2:], cov=Sigma, size=(nsamp - 1))
    # Include the MAP estimate
    samples = np.vstack((x0[2:], samples))
    # Reshape
    attrs = property_dict["model.covar_module.base_kernel.Uvec"]
    Uvec_batch = torch.tensor(samples, dtype=attrs.dtype, device=attrs.device).reshape(
        nsamp, *attrs.shape
    )
    # Get the other properties into batch mode
    mean_constant_batch = mll.model.mean_module.constant.repeat(nsamp)
    output_scale_batch = mll.model.covar_module.raw_outputscale.repeat(nsamp)
    return Uvec_batch, mean_constant_batch, output_scale_batch
def get_batch_model(
    B: Tensor,
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor,
    Uvec_batch: Tensor,
    mean_constant_batch: Tensor,
    output_scale_batch: Tensor,
) -> ALEBOGP:
    """Construct a batch-mode ALEBO GP using batch tensors of hyperparameters.

    The training data are tiled once per batch (the batch count is the
    leading size of ``Uvec_batch``), and the supplied hyperparameter values
    are copied into the new model's parameters. The model is returned in eval
    mode.

    Args:
        B: Projection matrix.
        train_X: X training data.
        train_Y: Y training data.
        train_Yvar: Noise variances of each training Y.
        Uvec_batch: Batch tensor of Uvec hyperparameters.
        mean_constant_batch: Batch tensor of mean constant hyperparameter.
        output_scale_batch: Batch tensor of output scale hyperparameter.

    Returns: Batch-mode ALEBO GP.
    """

    def _overwrite(param: Tensor, values: Tensor) -> None:
        # Gradient tracking is switched off around the in-place copy and
        # switched back on afterwards.
        param.requires_grad_(False)
        param.copy_(values)
        param.requires_grad_(True)

    num_batches = Uvec_batch.size(0)
    batch_model = ALEBOGP(
        B=B,
        train_X=train_X.repeat(num_batches, 1, 1),
        train_Y=train_Y.repeat(num_batches, 1, 1),
        train_Yvar=train_Yvar.repeat(num_batches, 1, 1),
    )
    batch_model.train()
    # Mean constant, then output scale, then Uvec.
    # pyre-fixme[16]: `Optional` has no attribute `raw_constant`.
    _overwrite(batch_model.mean_module.raw_constant, mean_constant_batch)
    _overwrite(batch_model.covar_module.raw_outputscale, output_scale_batch)
    _overwrite(batch_model.covar_module.base_kernel.Uvec, Uvec_batch)
    batch_model.eval()
    return batch_model
def ei_or_nei(
    model: Union[ALEBOGP, ModelListGP],
    objective_weights: Tensor,
    outcome_constraints: Optional[Tuple[Tensor, Tensor]],
    X_observed: Tensor,
    X_pending: Optional[Tensor],
    q: int,
    noiseless: bool,
) -> AcquisitionFunction:
    """Use analytic EI if appropriate, otherwise Monte Carlo NEI.

    Analytic EI can be used if: Single outcome, no constraints, no pending
    points, not batch, and no noise.

    Args:
        model: GP.
        objective_weights: Weights on each outcome for the objective.
        outcome_constraints: Outcome constraints.
        X_observed: Observed points for NEI.
        X_pending: Pending points.
        q: Batch size. Only used to decide between EI and NEI.
        noiseless: True if evaluations are noiseless.

    Returns: An AcquisitionFunction, either analytic EI or MC NEI.
    """
    use_analytic_ei = (
        len(objective_weights) == 1
        and outcome_constraints is None
        and X_pending is None
        and q == 1
        and noiseless
    )
    if not use_analytic_ei:
        with gpytorch.settings.max_cholesky_size(2000):
            return get_NEI(
                model=model,
                objective_weights=objective_weights,
                outcome_constraints=outcome_constraints,
                X_observed=X_observed,
                X_pending=X_pending,
            )
    # The sign of the single objective weight sets the optimization direction.
    # Convert the 0-dim comparison result to a plain bool so that
    # ExpectedImprovement receives the `bool` it is annotated to take.
    maximize = bool(objective_weights[0] > 0)
    # Incumbent: best observed training target in the chosen direction.
    if maximize:
        best_f = model.train_targets.max()
    else:
        best_f = model.train_targets.min()
    return ExpectedImprovement(model=model, best_f=best_f, maximize=maximize)
def alebo_acqf_optimizer(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    n: int,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
    fixed_features: Optional[Dict[int, float]],
    rounding_func: Optional[Callable[[Tensor], Tensor]],
    raw_samples: int,
    num_restarts: int,
    B: Tensor,
) -> Tuple[Tensor, Tensor]:
    """
    Optimize the acquisition function for ALEBO.

    We are optimizing over a polytope within the subspace, and so begin each
    random restart of the acquisition function optimization with points that
    lie within that polytope.

    Args:
        acq_function: Acquisition function to optimize.
        bounds: Accepted for interface compatibility; not used here (infinite
            bounds are passed to the optimizer instead).
        n: Number of candidates; must be 1 (asserted).
        inequality_constraints: Forwarded to ``optimize_acqf``.
        fixed_features: Accepted for interface compatibility; not used here.
        rounding_func: Accepted for interface compatibility; not used here.
        raw_samples: Number of random starting points to generate and score.
        num_restarts: Number of restarts for the acquisition optimization.
        B: Projection matrix used to generate and project the starting points.

    Returns: The generated candidates and the stacked acquisition values.
    """
    candidate_list, acq_value_list = [], []
    candidates = torch.tensor([], device=B.device, dtype=B.dtype)
    # Remember the acquisition function's pending points (if it has the
    # attribute at all) so they can be put back at the end.
    try:
        base_X_pending = acq_function.X_pending
    except AttributeError:
        base_X_pending = None
        acq_has_X_pend = False
    else:
        acq_has_X_pend = True
    assert n == 1
    for i in range(n):
        # Generate initial points for optimization inside embedding
        initializer = ALEBOInitializer(B.cpu().numpy(), nsamp=10 * raw_samples)
        high_d_box = [(-1.0, 1.0)] * B.shape[1]
        Xrnd_npy, _ = initializer.gen(n=raw_samples, bounds=high_d_box)
        Xrnd = torch.tensor(Xrnd_npy, dtype=B.dtype, device=B.device).unsqueeze(1)
        Yrnd = torch.matmul(Xrnd, B.t())  # Project down to the embedding
        with gpytorch.settings.max_cholesky_size(2000):
            # Score the raw samples without tracking gradients.
            with torch.no_grad():
                alpha = acq_function(Yrnd)
            Yinit = initialize_q_batch_nonneg(X=Yrnd, Y=alpha, n=num_restarts)
            # all constraints are encoded via inequality_constraints
            unbounded = torch.tensor([[-float("inf")], [float("inf")]])
            inf_bounds = unbounded.expand(2, Yrnd.shape[-1]).to(Yrnd)
            # Optimize the acquisition function, separately for each random restart.
            candidate, acq_value = optimize_acqf(
                acq_function=acq_function,
                bounds=inf_bounds,
                q=1,
                num_restarts=num_restarts,
                raw_samples=0,
                options={"method": "SLSQP", "batch_limit": 1},
                inequality_constraints=inequality_constraints,
                batch_initial_conditions=Yinit,
                sequential=False,
            )
            candidate_list.append(candidate)
            acq_value_list.append(acq_value)
            candidates = torch.cat(candidate_list, dim=-2)
            if acq_has_X_pend:
                # Treat everything generated so far as pending.
                if base_X_pending is None:
                    pending_now = candidates
                else:
                    # pyre-fixme[6]: Expected `Union[List[Tensor],
                    #  typing.Tuple[Tensor, ...]]` for 1st param but got
                    #  `List[Union[Tensor, torch.nn.Module]]`.
                    pending_now = torch.cat([base_X_pending, candidates], dim=-2)
                acq_function.set_X_pending(pending_now)
        logger.info(f"Generated sequential candidate {i+1} of {n}")
    if acq_has_X_pend:
        # Restore the pending points the acquisition function started with.
        # pyre-fixme[6]: Expected `Optional[Tensor]` for 1st param but got
        #  `Union[None, Tensor, torch.nn.Module]`.
        acq_function.set_X_pending(base_X_pending)
    return candidates, torch.stack(acq_value_list)
class ALEBO(BotorchModel):
    """Does Bayesian optimization in a linear subspace with ALEBO.

    The (d x D) projection down matrix B must be provided, and must be that
    used for the initialization.

    Function evaluations happen in the high-D space. We only evaluate points
    such that x = pinverse(B) @ B @ x (that is, points inside the subspace).
    Under that constraint, the projection is invertible.

    Args:
        B: (d x D) projection matrix (projects down).
        laplace_nsamp: Number of samples for posterior sampling of kernel
            hyperparameters.
        fit_restarts: Number of random restarts for MAP estimation.
    """

    def __init__(
        self, B: Tensor, laplace_nsamp: int = 25, fit_restarts: int = 10
    ) -> None:
        self.B = B
        # pyre-fixme[4]: Attribute must be annotated.
        self.Binv = torch.pinverse(B)
        self.laplace_nsamp = laplace_nsamp
        self.fit_restarts = fit_restarts
        super().__init__(
            refit_on_update=True,  # Important to not get stuck in local opt.
            refit_on_cv=False,
            warm_start_refitting=False,
            acqf_constructor=ei_or_nei,  # pyre-ignore
            # pyre-fixme[6]: Expected `(AcquisitionFunction, Tensor, int, Optional[Li...
            acqf_optimizer=alebo_acqf_optimizer,
        )

    # NOTE: methods decorated with `copy_doc` take their docstring from the
    # referenced TorchModel method, so they are documented with comments only.

    @copy_doc(TorchModel.fit)
    def fit(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
    ) -> None:
        Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
        # Task/fidelity features are not supported, and every parameter must
        # be bounded to [-1, 1].
        assert len(search_space_digest.task_features) == 0
        assert len(search_space_digest.fidelity_features) == 0
        for bound in search_space_digest.bounds:
            assert bound == (-1, 1)
        # GP is fit in the low-d space, so project Xs down.
        self.Xs = [(self.B @ X.t()).t() for X in Xs]
        self.Ys = Ys
        self.Yvars = Yvars
        self.device = self.B.device
        self.dtype = self.B.dtype
        self.model = self.get_and_fit_model(Xs=self.Xs, Ys=self.Ys, Yvars=self.Yvars)

    @copy_doc(TorchModel.predict)
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        Xd = (self.B @ X.t()).t()  # Project down
        with gpytorch.settings.max_cholesky_size(2000):
            return super().predict(X=Xd)

    @copy_doc(TorchModel.best_point)
    def best_point(
        self,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> Optional[Tensor]:
        raise NotImplementedError

    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
    ) -> TorchGenResults:
        """Generate candidates.

        Candidates are generated in the linear embedding with the polytope
        constraints described in the paper.

        model_gen_options can contain 'raw_samples' (number of samples used for
        initializing the acquisition function optimization) and 'num_restarts'
        (number of restarts for acquisition function optimization).

        Args:
            n: Number of candidates to generate.
            search_space_digest: Search space; every bound must be (-1, 1).
            torch_opt_config: Optimization config. Its linear constraints,
                fixed features and pending observations must all be None.

        Returns: Generation results whose points have been projected back up
            to the high-D space and clipped to [-1, 1].
        """
        for bound in search_space_digest.bounds:
            assert bound == (-1, 1)
        # The following can be easily handled in the future when needed
        assert torch_opt_config.linear_constraints is None
        assert torch_opt_config.fixed_features is None
        assert torch_opt_config.pending_observations is None
        # Setup constraints: stack Binv and -Binv against a vector of ones.
        # NOTE(review): presumably consumed as A @ y <= b, i.e. the projected-up
        # point stays inside [-1, 1]; confirm against the parent class.
        A = torch.cat((self.Binv, -self.Binv))
        b = torch.ones(2 * self.Binv.shape[0], 1, dtype=self.dtype, device=self.device)
        linear_constraints = (A, b)
        noiseless = max(Yvar.min().item() for Yvar in self.Yvars) < 1e-5
        model_gen_options = {
            "acquisition_function_kwargs": {"q": n, "noiseless": noiseless},
            "optimizer_kwargs": {
                "raw_samples": torch_opt_config.model_gen_options.get(
                    "raw_samples", 1000
                ),
                "num_restarts": torch_opt_config.model_gen_options.get(
                    "num_restarts", 10
                ),
                "B": self.B,
            },
        }
        gen_results = super().gen(
            n=n,
            search_space_digest=dataclasses.replace(
                search_space_digest,
                # Effectively unbounded box in the embedding; the real
                # restriction comes from the linear constraints above.
                bounds=[(-1e8, 1e8)] * self.B.shape[0],
            ),
            torch_opt_config=dataclasses.replace(
                torch_opt_config,
                linear_constraints=linear_constraints,
                model_gen_options=model_gen_options,
            ),
        )
        # Project up
        Xopt = (self.Binv @ gen_results.points.t()).t()
        # Sometimes numerical tolerance can have Xopt epsilon outside [-1, 1],
        # so clip it back.
        if Xopt.min() < -1 or Xopt.max() > 1:
            logger.debug(f"Clipping from [{Xopt.min()}, {Xopt.max()}]")
            Xopt = torch.clamp(Xopt, min=-1.0, max=1.0)
        return dataclasses.replace(gen_results, points=Xopt)

    @copy_doc(TorchModel.update)
    def update(
        self,
        datasets: List[SupervisedDataset],
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        **kwargs: Any,
    ) -> None:
        if self.model is None:
            raise RuntimeError(
                "Cannot update model that has not been fit"
            )  # pragma: no cover
        Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
        self.Xs = [(self.B @ X.t()).t() for X in Xs]  # Project down.
        self.Ys = Ys
        self.Yvars = Yvars
        if self.refit_on_update:
            state_dicts = None
        else:
            # Warm start from the MAP parameters of the current model.
            state_dicts = extract_map_statedict(
                m_b=self.model, num_outputs=len(Xs)  # pyre-ignore
            )
        self.model = self.get_and_fit_model(
            Xs=self.Xs, Ys=self.Ys, Yvars=self.Yvars, state_dicts=state_dicts
        )

    @copy_doc(TorchModel.cross_validate)
    def cross_validate(
        self,
        datasets: List[SupervisedDataset],
        X_test: Tensor,
        **kwargs: Any,
    ) -> Tuple[Tensor, Tensor]:
        if self.model is None:
            raise RuntimeError(
                "Cannot cross-validate model that has not been fit"
            )  # pragma: no cover
        if self.refit_on_cv:
            state_dicts = None
        else:
            # Warm start from the MAP parameters of the current model.
            state_dicts = extract_map_statedict(
                m_b=self.model, num_outputs=len(self.Xs)  # pyre-ignore
            )
        Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
        Xs = [X @ self.B.t() for X in Xs]  # Project down.
        X_test = X_test @ self.B.t()
        model = self.get_and_fit_model(
            Xs=Xs, Ys=Ys, Yvars=Yvars, state_dicts=state_dicts
        )
        return self.model_predictor(model=model, X=X_test)  # pyre-ignore: [28]

    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        state_dicts: Optional[List[MutableMapping[str, Tensor]]] = None,
    ) -> GPyTorchModel:
        """Get a fitted ALEBO model for each outcome.

        Noise variances are floored at 1e-7 before fitting. The caller's
        ``Yvars`` tensors are left untouched.

        Args:
            Xs: X for each outcome, already projected down.
            Ys: Y for each outcome.
            Yvars: Noise variance of Y for each outcome.
            state_dicts: State dicts to initialize model fitting. When given,
                a single (warm-started) restart is used per outcome instead of
                ``self.fit_restarts``.

        Returns: Fitted ALEBO model: a single model for one outcome, otherwise
            a ModelListGP over the per-outcome models.
        """
        if state_dicts is None:
            state_dicts = [None] * len(Xs)
            fit_restarts = self.fit_restarts
        else:
            fit_restarts = 1  # Warm-started
        # Out-of-place clamp: the in-place variant would silently modify the
        # tensors owned by the caller (e.g. the datasets handed to
        # `cross_validate`).
        Yvars = [Yvar.clamp_min(1e-7) for Yvar in Yvars]
        models = [
            get_fitted_model(
                B=self.B,
                train_X=X,
                train_Y=Ys[i],
                train_Yvar=Yvars[i],
                restarts=fit_restarts,
                nsamp=self.laplace_nsamp,
                # pyre-fixme[6]: Expected `Optional[Dict[str, Tensor]]` for 7th
                #  param but got `Optional[MutableMapping[str, Tensor]]`.
                init_state_dict=state_dicts[i],
            )
            for i, X in enumerate(Xs)
        ]
        if len(models) == 1:
            model = models[0]
        else:
            model = ModelListGP(*models)
        model.to(Xs[0])
        return model