# Source code for ax.utils.sensitivity.derivative_gp

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math

import torch
from botorch.models.model import Model
from gpytorch.distributions import MultivariateNormal
from torch import Tensor


def get_KXX_inv(gp: Model) -> Tensor:
    r"""Return the inverse of the training kernel matrix K(X, X).

    The cached ``covar_cache`` of the model's prediction strategy is treated as
    a root factor ``R`` of K(X, X)^{-1}, and the inverse is rebuilt as
    ``R @ R^T``. The cache is detached first, so the result carries no autograd
    history.

    Args:
        gp: Botorch model.

    Returns:
        The inverse of K(X,X).
    """
    root = gp.prediction_strategy.covar_cache.detach()
    return torch.matmul(root, torch.transpose(root, 0, 1))
def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
    """Computes the analytic derivative of the kernel K(x,X) w.r.t. x.

    Args:
        gp: Botorch model.
        x: (n x D) Test points.
        kernel_type: Takes "rbf" or "matern". Any value other than "rbf" is
            treated as a Matern kernel, using the nu=5/2 closed form.

    Returns:
        Tensor (n x D x N), where N is the number of training points. Entry
        ``[i, d, j]`` is the derivative of K(x_i, X_j) w.r.t. the d-th
        coordinate of x_i.
    """
    X = gp.train_inputs[0]
    D = X.shape[1]
    N = X.shape[0]
    n = x.shape[0]
    if hasattr(gp.covar_module, "outputscale"):
        # Scaled kernel: the lengthscale lives on the wrapped base kernel.
        lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
        sigma_f = gp.covar_module.outputscale.detach()
    else:
        lengthscale = gp.covar_module.lengthscale.detach()
        sigma_f = 1.0
    if kernel_type == "rbf":
        # d k(x, X) / dx = -diag(1 / lengthscale^2) (x - X) k(x, X). K_xX is
        # evaluated with the full covar_module, so any outputscale is already
        # included in it.
        K_xX = gp.covar_module(x, X).evaluate()
        part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2
        part2 = x.view(n, 1, D) - X.view(1, N, D)
        return part1 @ (part2 * K_xX.view(n, N, 1)).transpose(1, 2)
    # Else, we have a Matern-5/2 kernel in scaled coordinates
    # r = ||(x - X) / lengthscale||:
    #   k(r)    = sigma_f * (1 + sqrt(5) r + 5/3 r^2) * exp(-sqrt(5) r)
    #   dk/dr   = -5/3 * sigma_f * r * (1 + sqrt(5) r) * exp(-sqrt(5) r)
    #   dr/dx_d = (x1_d - x2_d) / (lengthscale_d * r)
    # The factor r in dk/dr cancels the 1/r in dr/dx analytically. We therefore
    # never divide by the distance; doing so would give 0/0 = NaN when a test
    # point coincides with a training point.
    # Centering (presumably mirroring GPyTorch's MaternKernel) does not change
    # the pairwise differences x1_ - x2_.
    mean = x.reshape(-1, x.size(-1)).mean(0)[(None,) * (x.dim() - 1)]
    x1_ = (x - mean).div(lengthscale)
    x2_ = (X - mean).div(lengthscale)
    # Assumed to be the (n x N) Euclidean distance between the scaled points.
    distance = gp.covar_module.covar_dist(x1_, x2_)
    sqrt5 = math.sqrt(5.0)
    exp_component = torch.exp(-sqrt5 * distance)
    # This is dk/dr divided by r (and by sigma_f and the exp factor).
    constant_component = (-5.0 / 3.0) * (1.0 + sqrt5 * distance)
    part1 = torch.eye(D, device=x1_.device, dtype=x1_.dtype) / lengthscale
    part2 = x1_.view(n, 1, D) - x2_.view(1, N, D)
    total_k = sigma_f * constant_component * exp_component
    return part1 @ (part2 * total_k.view(n, N, 1)).transpose(1, 2)
def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor:
    r"""Computes the analytic second derivative of the kernel at coincident points.

    This is the prior covariance of the GP gradient, d^2 k(x, x') / (dx dx')
    evaluated at x = x'. The result depends only on the kernel hyperparameters
    (lengthscale and outputscale), so no test points are required. The
    training inputs are read only to get the input dimension D.

    Args:
        gp: Botorch model.
        kernel_type: Takes "rbf" or "matern". Any value other than "rbf" is
            treated as a Matern-5/2 kernel.

    Returns:
        Tensor (D x D). For "rbf" this is ``outputscale * diag(1 / lengthscale^2)``;
        for the Matern kernel the same matrix is multiplied by 5/3.
    """
    X = gp.train_inputs[0]
    D = X.shape[1]
    if hasattr(gp.covar_module, "outputscale"):
        # Scaled kernel: the lengthscale lives on the wrapped base kernel.
        lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
        sigma_f = gp.covar_module.outputscale.detach()
    else:
        lengthscale = gp.covar_module.lengthscale.detach()
        sigma_f = 1.0
    eye = torch.eye(D, device=lengthscale.device, dtype=lengthscale.dtype)
    # Dividing the identity by the (1 x D) row of squared lengthscales gives
    # diag(1 / lengthscale^2).
    res = (eye / lengthscale**2) * sigma_f
    if kernel_type == "rbf":
        return res
    return res * (5 / 3)
def posterior_derivative(
    gp: Model, x: Tensor, kernel_type: str = "rbf"
) -> MultivariateNormal:
    r"""Computes the posterior of the derivative of the GP w.r.t. the given test
    points x.

    This follows the derivation used by GIBO in Sarah Muller, Alexander
    von Rohr, Sebastian Trimpe. "Local policy search with Bayesian optimization",
    Advances in Neural Information Processing Systems 34, NeurIPS 2021.

    Args:
        gp: Botorch model
        x: (n x D) Test points.
        kernel_type: Takes "rbf" or "matern"

    Returns:
        A GPyTorch ``MultivariateNormal`` with an (n x D) mean and an
        (n x D x D) covariance, i.e. one D-dimensional gradient distribution
        per test point. If the full covariance cannot be used to build the
        distribution, only its diagonal is kept.

    Raises:
        ValueError: If ``kernel_type`` is not "rbf" or "matern".
    """
    # Validate first so an unsupported kernel fails before any model work.
    if kernel_type not in ["rbf", "matern"]:
        raise ValueError("only matern and rbf kernels are supported")
    if gp.prediction_strategy is None:
        gp.posterior(x)  # Call this to update prediction strategy of GPyTorch.
    K_xX_dx = get_KxX_dx(gp, x, kernel_type=kernel_type)
    Kxx_dx2 = get_Kxx_dx2(gp, kernel_type=kernel_type)
    # K(X, X)^{-1} is needed for both the mean and the covariance; build it once.
    KXX_inv = get_KXX_inv(gp)
    residual = gp.train_targets - gp.mean_module(gp.train_inputs[0])
    mean_d = K_xX_dx @ KXX_inv @ residual
    variance_d = Kxx_dx2 - K_xX_dx @ KXX_inv @ K_xX_dx.transpose(1, 2)
    # Floor only the variances on the diagonal. Off-diagonal covariances between
    # gradient components can legitimately be negative, and clamping them would
    # corrupt the correlation structure. torch.where keeps this out-of-place,
    # so it is safe under autograd.
    on_diag = torch.eye(
        variance_d.shape[-1], dtype=torch.bool, device=variance_d.device
    )
    variance_d = torch.where(on_diag, variance_d.clamp_min(1e-9), variance_d)
    try:
        return MultivariateNormal(mean_d, variance_d)
    except RuntimeError:
        # Presumably raised when the covariance is not numerically positive
        # definite; fall back to independent (diagonal) gradient components.
        variance_d_diag = torch.diagonal(variance_d, offset=0, dim1=1, dim2=2)
        variance_d_new = torch.diag_embed(variance_d_diag)
        return MultivariateNormal(mean_d, variance_d_new)