Source code for ax.utils.sensitivity.derivative_gp

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings

import torch
from botorch.models.model import Model
from gpytorch.distributions import MultivariateNormal
from torch import Tensor


def get_KXX_inv(gp: Model) -> Tensor:
    r"""Get the inverse matrix of K(X,X).

    The model's ``prediction_strategy.covar_cache`` is treated as a factor ``R``
    with ``R @ R^T = K(X,X)^{-1}``; the product is formed explicitly.

    Args:
        gp: Botorch model. Its ``prediction_strategy`` must already be populated
            (e.g. by a prior call to ``gp.posterior``).

    Returns:
        The inverse of K(X,X). A 2-D cache of shape ``N x r`` yields ``N x N``;
        a batched cache of shape ``batch x N x r`` yields ``batch x N x N``.
    """
    L_inv_upper = gp.prediction_strategy.covar_cache.detach()  # pyre-ignore
    # Transpose the last two dims (not dims 0/1) so batched caches also work.
    return L_inv_upper @ L_inv_upper.transpose(-2, -1)


def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
    """Computes the analytic derivative of the kernel K(x,X) w.r.t. x.

    Args:
        gp: Botorch model.
        x: (n x D) Test points.
        kernel_type: Takes "rbf" or "matern_l1" or "matern_l2". Any value other
            than "rbf" or "matern_l2" is handled as "matern_l1".

    Returns:
        Tensor (n x D x N) The derivative of the kernel K(x,X) w.r.t. x, where N
        is the number of training points: entry ``[i, d, j]`` is
        ``d k(x_i, X_j) / d x_{i,d}``.
    """
    X = gp.train_inputs[0]  # pyre-ignore
    D = X.shape[1]
    N = X.shape[0]
    n = x.shape[0]
    lengthscale = gp.covar_module.base_kernel.lengthscale.detach()  # pyre-ignore
    # Pairwise differences x_i - X_j, shape (n x N x D).
    diff = x.view(n, 1, D) - X.view(1, N, D)
    if kernel_type == "rbf":
        K_xX = gp.covar_module(x, X).evaluate()  # pyre-ignore
        part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2
        return part1 @ (diff * K_xX.view(n, N, 1)).transpose(1, 2)
    # Else, we have a Matern-5/2 kernel, either L1 or L2
    mean = x.reshape(-1, x.size(-1)).mean(0)[(None,) * (x.dim() - 1)]
    x1_ = (x - mean).div(lengthscale)
    x2_ = (X - mean).div(lengthscale)
    matern_norml2 = kernel_type == "matern_l2"
    distance = gp.covar_module.covar_dist(  # pyre-ignore
        x1_, x2_, square_dist=matern_norml2
    )
    sqrt5 = math.sqrt(5.0)
    exp_component = torch.exp(-sqrt5 * distance)  # pyre-ignore
    sigma_f = gp.covar_module.outputscale.detach()  # pyre-ignore
    # diag(1 / l^2): left-multiplying scales input dimension d by 1 / l_d^2.
    part1 = (
        torch.eye(D, device=lengthscale.device, dtype=lengthscale.dtype)
        / lengthscale**2
    )
    if matern_norml2:
        # Profile evaluated at the squared scaled distance d:
        # dk/dd = -(5/3) d (1 + sqrt(5) d) exp(-sqrt(5) d), dd/dx = 2 (x - X) / l^2.
        constant_component = (-5.0 / 3.0) * distance - (5.0 * sqrt5 / 3.0) * (
            distance**2
        )
        part2 = 2 * diff
    else:
        # Profile evaluated at the scaled Euclidean distance r:
        # dk/dr = -(5/3) r (1 + sqrt(5) r) exp(-sqrt(5) r), dr/dx = (x - X) / (l^2 r).
        # The factor r cancels analytically; dividing by r numerically instead
        # would produce 0/0 whenever covar_dist returns exactly zero.
        constant_component = (-5.0 / 3.0) * (1.0 + sqrt5 * distance)
        part2 = diff
    total_k = sigma_f * constant_component * exp_component
    return part1 @ (part2 * total_k.view(n, N, 1)).transpose(1, 2)


def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor:
    r"""Computes the analytic second derivative of the kernel w.r.t. the training data

    This is the mixed derivative :math:`\partial^2 k(x, x') / \partial x \partial x'`
    evaluated at :math:`x = x'`, i.e. the prior covariance of the GP gradient.

    Args:
        gp: Botorch model.
        kernel_type: Takes "rbf" or "matern_l1" or "matern_l2". Any value other
            than "rbf" or "matern_l2" is handled as "matern_l1".

    Returns:
        Tensor (D x D) The second derivative of the kernel w.r.t. the training
        data, where D is the input dimension.
    """
    X = gp.train_inputs[0]  # pyre-ignore
    D = X.shape[1]
    lengthscale = gp.covar_module.base_kernel.lengthscale.detach()  # pyre-ignore
    if kernel_type == "matern_l2":
        # The "L2" variant evaluates the Matern-5/2 profile at the *squared*
        # scaled distance d = r^2, so k = sigma_f * (1 + O(r^4)) near x = x' and
        # the mixed second derivative vanishes there.
        return torch.zeros(D, D, device=lengthscale.device, dtype=lengthscale.dtype)
    sigma_f = gp.covar_module.outputscale.detach()  # pyre-ignore
    eye = torch.eye(D, device=lengthscale.device, dtype=lengthscale.dtype)
    if kernel_type == "rbf":
        return (eye / lengthscale**2) * sigma_f
    # Matern-5/2 on the scaled Euclidean distance r (the sqrt(5) form used by
    # get_KxX_dx) expands as k = sigma_f * (1 - 5/6 r^2 + O(r^4)); it is twice
    # differentiable and the mixed second derivative at x = x' is
    # 5 sigma_f / (3 l^2) on the diagonal.
    return (5.0 / 3.0) * (eye / lengthscale**2) * sigma_f


def posterior_derivative(
    gp: Model, x: Tensor, kernel_type: str = "rbf"
) -> MultivariateNormal:
    r"""Computes the posterior of the derivative of the GP w.r.t. the given test
    points x.

    This follows the derivation used by GIBO in Sarah Muller, Alexander
    von Rohr, Sebastian Trimpe. "Local policy search with Bayesian optimization",
    Advances in Neural Information Processing Systems 34, NeurIPS 2021.

    Args:
        gp: Botorch model
        x: (n x D) Test points.
        kernel_type: Takes "rbf" or "matern_l1" or "matern_l2"

    Returns:
        A ``MultivariateNormal`` over the gradient at each test point, with mean
        of shape (n x D) and covariance of shape (n x D x D). If constructing it
        with the full covariance raises a ``RuntimeError``, only the diagonal of
        the covariance is kept.

    Raises:
        ValueError: If ``kernel_type`` is not one of the supported values.
    """
    # Validate before touching the model so bad input has no side effects.
    if kernel_type not in ("rbf", "matern_l1", "matern_l2"):
        raise ValueError("only matern and rbf kernels are supported")
    if gp.prediction_strategy is None:
        gp.posterior(x)  # Call this to update prediction strategy of GPyTorch.
    K_xX_dx = get_KxX_dx(gp, x, kernel_type=kernel_type)
    Kxx_dx2 = get_Kxx_dx2(gp, kernel_type=kernel_type)
    KXX_inv = get_KXX_inv(gp)  # computed once, reused for mean and covariance
    mean_d = K_xX_dx @ (KXX_inv @ gp.train_targets)
    variance_d = Kxx_dx2 - K_xX_dx @ KXX_inv @ K_xX_dx.transpose(1, 2)
    # Floor only the variances on the diagonal: off-diagonal covariances can be
    # legitimately negative and must not be clamped to a positive value.
    on_diag = torch.eye(
        variance_d.shape[-1], dtype=torch.bool, device=variance_d.device
    )
    variance_d = torch.where(on_diag, variance_d.clamp_min(1e-9), variance_d)
    try:
        return MultivariateNormal(mean_d, variance_d)
    except RuntimeError:
        # Fall back to independent per-dimension variances.
        variance_d_diag = torch.diagonal(variance_d, offset=0, dim1=1, dim2=2)
        variance_d_new = torch.diag_embed(variance_d_diag)
        return MultivariateNormal(mean_d, variance_d_new)