Source code for ax.utils.sensitivity.derivative_measures

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from typing import Callable, List, Optional, Union

import torch
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.sensitivity.derivative_gp import posterior_derivative
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import unnormalize
from gpytorch.distributions import MultivariateNormal


[docs]class GpDGSMGpMean(object): mean_gradients: Optional[torch.Tensor] = None bootstrap_indices: Optional[torch.Tensor] = None mean_gradients_btsp: Optional[List[torch.Tensor]] = None def __init__( self, model: Model, bounds: torch.Tensor, derivative_gp: bool = False, kernel_type: Optional[str] = None, Y_scale: float = 1.0, num_mc_samples: int = 10**4, input_qmc: bool = False, dtype: torch.dtype = torch.double, num_bootstrap_samples: int = 1, ) -> None: r"""Computes three types of derivative based measures: the gradient, the gradient square and the gradient absolute measures. Args: model: A BoTorch model. bounds: Parameter bounds over which to evaluate model sensitivity. derivative_gp: If true, the derivative of the GP is used to compute the gradient instead of backward. If `kernel_type` is matern_l1, only the mean function of derivative GP can be used, and the variance is not defined. kernel_type: Takes "rbf" or "matern_l1" or "matern_l2", set only if `derivative_gp` is true. Y_scale: Scale the derivatives by this amount, to undo scaling done on the training data. num_mc_samples: The number of MonteCarlo grid samples input_qmc: If True, a qmc Sobol grid is use instead of uniformly random. dtype: Can be provided if the GP is fit to data of type `torch.float`. num_bootstrap_samples: If higher than 1, the method will compute the dgsm measure `num_bootstrap_samples` times by selecting subsamples from the `input_mc_samples` and return the variance and standard error across all computed measures. """ # pyre-fixme[4]: Attribute must be annotated. self.dim = checked_cast(tuple, model.train_inputs)[0].shape[-1] self.derivative_gp = derivative_gp self.kernel_type = kernel_type # pyre-fixme[4]: Attribute must be annotated. self.bootstrap = num_bootstrap_samples > 1 # pyre-fixme[4]: Attribute must be annotated. self.num_bootstrap_samples = ( num_bootstrap_samples - 1 ) # deduct 1 because the first is meant to be the full grid if self.derivative_gp and (self.kernel_type is None): raise ValueError("Kernel type has to be specified to use derivative GP") self.num_mc_samples = num_mc_samples if input_qmc: # pyre-fixme[4]: Attribute must be annotated. self.input_mc_samples = ( draw_sobol_samples(bounds=bounds, n=num_mc_samples, q=1) .squeeze(1) .to(dtype) ) else: self.input_mc_samples = unnormalize( torch.rand(num_mc_samples, self.dim, dtype=dtype), bounds=bounds, ) if self.derivative_gp: posterior = posterior_derivative( model, self.input_mc_samples, not_none(self.kernel_type) ) else: self.input_mc_samples.requires_grad = True posterior = checked_cast( GPyTorchPosterior, model.posterior(self.input_mc_samples) ) self._compute_gradient_quantities(posterior, Y_scale) def _compute_gradient_quantities( self, posterior: Union[GPyTorchPosterior, MultivariateNormal], Y_scale: float ) -> None: if self.derivative_gp: self.mean_gradients = checked_cast(torch.Tensor, posterior.mean) * Y_scale else: predictive_mean = posterior.mean torch.sum(predictive_mean).backward() self.mean_gradients = ( checked_cast(torch.Tensor, self.input_mc_samples.grad) * Y_scale ) if self.bootstrap: subset_size = 2 self.bootstrap_indices = torch.randint( 0, self.num_mc_samples, (self.num_bootstrap_samples, subset_size) ) self.mean_gradients_btsp = [ torch.index_select( checked_cast(torch.Tensor, self.mean_gradients), 0, indices ) for indices in self.bootstrap_indices ]
[docs] def aggregation( self, transform_fun: Callable[[torch.Tensor], torch.Tensor] ) -> torch.Tensor: gradients_measure = torch.tensor( [ torch.mean(transform_fun(not_none(self.mean_gradients)[:, i])) for i in range(self.dim) ] ) if not (self.bootstrap): return gradients_measure else: gradients_measures_btsp = [gradients_measure.unsqueeze(0)] for b in range(self.num_bootstrap_samples): gradients_measures_btsp.append( torch.tensor( [ torch.mean( transform_fun( not_none(self.mean_gradients_btsp)[b][:, i] ) ) for i in range(self.dim) ] ).unsqueeze(0) ) gradients_measures_btsp = torch.cat(gradients_measures_btsp, dim=0) return ( torch.cat( [ gradients_measures_btsp.mean(dim=0).unsqueeze(0), gradients_measures_btsp.var(dim=0).unsqueeze(0), torch.sqrt( gradients_measures_btsp.var(dim=0) / (self.num_bootstrap_samples + 1) ).unsqueeze(0), ], dim=0, ) .t() .detach() )
[docs] def gradient_measure(self) -> torch.Tensor: r"""Computes the gradient measure: Returns: if `self.num_bootstrap_samples > 1` Tensor: (values, var_mc, stderr_mc) x dim else Tensor: (values) x dim """ return self.aggregation(torch.tensor)
[docs] def gradient_absolute_measure(self) -> torch.Tensor: r"""Computes the gradient absolute measure: Returns: if `self.num_bootstrap_samples > 1` Tensor: (values, var_mc, stderr_mc) x dim else Tensor: (values) x dim """ return self.aggregation(torch.abs)
[docs] def gradients_square_measure(self) -> torch.Tensor: r"""Computes the gradient square measure: Returns: if `num_bootstrap_samples > 1` Tensor: (values, var_mc, stderr_mc) x dim else Tensor: (values) x dim """ return self.aggregation(torch.square)
[docs]class GpDGSMGpSampling(GpDGSMGpMean): samples_gradients: Optional[torch.Tensor] = None samples_gradients_btsp: Optional[List[torch.Tensor]] = None def __init__( self, model: Model, bounds: torch.Tensor, num_gp_samples: int, derivative_gp: bool = False, kernel_type: Optional[str] = None, Y_scale: float = 1.0, num_mc_samples: int = 10**4, input_qmc: bool = False, gp_sample_qmc: bool = False, dtype: torch.dtype = torch.double, num_bootstrap_samples: int = 1, ) -> None: r"""Computes three types of derivative based measures: the gradient, the gradient square and the gradient absolute measures. Args: model: A BoTorch model. bounds: Parameter bounds over which to evaluate model sensitivity. num_gp_samples: If method is "GP samples", the number of GP samples has to be set. derivative_gp: If true, the derivative of the GP is used to compute the gradient instead of backward. If `kernel_type` is matern_l1, `derivative_gp` should be False because the variance is not defined. kernel_type: Takes "rbf" or "matern_l1" or "matern_l2", set only if `derivative_gp` is true. Y_scale: Scale the derivatives by this amount, to undo scaling done on the training data. num_mc_samples: The number of Monte Carlo grid samples. input_qmc: If True, a qmc Sobol grid is used instead of uniformly random. gp_sample_qmc: If True, the posterior sampling is done using `SobolQMCNormalSampler`. dtype: Can be provided if the GP is fit to data of type `torch.float`. num_bootstrap_samples: If higher than 1, the method will compute the dgsm measure `num_bootstrap_samples` times by selecting subsamples from the `input_mc_samples` and return the variance and standard error across all computed measures. Returns values of gradient_measure, gradient_absolute_measure and gradients_square_measure change to the following: if `num_bootstrap_samples > 1`: Tensor: (values, var_gp, stderr_gp, var_mc, stderr_mc) x dim else Tensor: (values, var_gp, stderr_gp) x dim """ self.num_gp_samples = num_gp_samples self.gp_sample_qmc = gp_sample_qmc self.num_mc_samples = num_mc_samples super().__init__( model=model, bounds=bounds, derivative_gp=derivative_gp, kernel_type=kernel_type, Y_scale=Y_scale, num_mc_samples=num_mc_samples, input_qmc=input_qmc, dtype=dtype, num_bootstrap_samples=num_bootstrap_samples, ) def _compute_gradient_quantities( self, posterior: Union[Posterior, MultivariateNormal], Y_scale: float ) -> None: if self.gp_sample_qmc: sampler = SobolQMCNormalSampler( sample_shape=torch.Size([self.num_gp_samples]), seed=0 ) samples = sampler(posterior) else: samples = posterior.rsample(torch.Size([self.num_gp_samples])) if self.derivative_gp: self.samples_gradients = samples * Y_scale else: samples_gradients = [] for j in range(self.num_gp_samples): torch.sum(samples[j]).backward(retain_graph=True) samples_gradients.append( deepcopy(self.input_mc_samples.grad).unsqueeze(0) ) self.input_mc_samples.grad.data.zero_() self.samples_gradients = torch.cat(samples_gradients, dim=0) * Y_scale if self.bootstrap: subset_size = 2 self.bootstrap_indices = torch.randint( 0, self.num_mc_samples, (self.num_bootstrap_samples, subset_size) ) self.samples_gradients_btsp = [] for j in range(self.num_gp_samples): not_none(self.samples_gradients_btsp).append( torch.cat( [ torch.index_select( not_none(self.samples_gradients)[j], 0, indices ).unsqueeze(0) for indices in not_none(self.bootstrap_indices) ], dim=0, ) )
[docs] def aggregation( self, transform_fun: Callable[[torch.Tensor], torch.Tensor] ) -> torch.Tensor: gradients_measure_list = [] for j in range(self.num_gp_samples): gradients_measure_list.append( torch.tensor( [ torch.mean( transform_fun(not_none(self.samples_gradients)[j][:, i]) ) for i in range(self.dim) ] ).unsqueeze(0) ) gradients_measure_list = torch.cat(gradients_measure_list, dim=0) if not (self.bootstrap): gradients_measure_mean_var = [] for i in range(self.dim): gradients_measure_mean_var.append( torch.tensor( [ torch.mean(gradients_measure_list[:, i]), torch.var(gradients_measure_list[:, i]), torch.sqrt( torch.var(gradients_measure_list[:, i]) / self.num_gp_samples ), ] ).unsqueeze(0) ) gradients_measure_mean_var = torch.cat(gradients_measure_mean_var, dim=0) return gradients_measure_mean_var else: gradients_measure_list_btsp = [] for j in range(self.num_gp_samples): gradients_measure_btsp = [gradients_measure_list[j].unsqueeze(0)] + [ torch.tensor( [ torch.mean( transform_fun( not_none(self.samples_gradients_btsp)[j][b][:, i] ) ) for i in range(self.dim) ] ).unsqueeze(0) for b in range(self.num_bootstrap_samples) ] gradients_measure_list_btsp.append( torch.cat(gradients_measure_btsp, dim=0).unsqueeze(0) ) gradients_measure_list_btsp = torch.cat(gradients_measure_list_btsp, dim=0) var_per_bootstrap = torch.var(gradients_measure_list_btsp, dim=0) gp_var = torch.mean(var_per_bootstrap, dim=0) gp_se = torch.sqrt(gp_var / self.num_gp_samples) var_per_gp_sample = torch.var(gradients_measure_list_btsp, dim=1) mc_var = torch.mean(var_per_gp_sample, dim=0) mc_se = torch.sqrt(mc_var / (self.num_bootstrap_samples + 1)) total_mean = gradients_measure_list_btsp.reshape(-1, self.dim).mean(dim=0) gradients_measure_mean_vargp_segp_varmc_segp = torch.cat( [ torch.tensor( [total_mean[i], gp_var[i], gp_se[i], mc_var[i], mc_se[i]] ).unsqueeze(0) for i in range(self.dim) ], dim=0, ) return gradients_measure_mean_vargp_segp_varmc_segp