Source code for ax.models.torch.botorch_modular.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from collections import OrderedDict
from collections.abc import Sequence
from dataclasses import dataclass, field
from logging import Logger
from typing import Any

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.multi_objective.logei import (
    qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel
from botorch.models.model import Model, ModelList
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from pyre_extensions import none_throws
from torch import Tensor

MIN_OBSERVED_NOISE_LEVEL = 1e-7
logger: Logger = get_logger(__name__)


@dataclass
class ModelConfig:
    """Configuration for the BoTorch Model used in Surrogate.

    Args:
        botorch_model_class: ``Model`` class to be used as the underlying BoTorch
            model. If None is provided, a model class (either one for all
            outcomes or a ModelList with separate models for each outcome) will
            be selected automatically based off the datasets at ``construct``
            time.
        model_options: Dictionary of options / kwargs for the BoTorch ``Model``
            constructed during ``Surrogate.fit``. Note that the corresponding
            attribute will later be updated to include any additional kwargs
            passed into ``BoTorchModel.fit``.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
            Defaults to ``ExactMarginalLogLikelihood``.
        mll_options: Dictionary of options / kwargs for the MLL.
        input_transform_classes: List of BoTorch input transforms classes.
            Passed down to the BoTorch ``Model``. Multiple input transforms
            will be chained together using ``ChainedInputTransform``.
        input_transform_options: Input transform classes kwargs. The keys are
            class string names and the values are dictionaries of input
            transform kwargs. For example::

                input_transform_classes = [Normalize, Round]
                input_transform_options = {
                    "Normalize": {"d": 3},
                    "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
                }

            For more input options see ``botorch/models/transforms/input.py``.
        outcome_transform_classes: List of BoTorch outcome transforms classes.
            Passed down to the BoTorch ``Model``. Multiple outcome transforms
            can be chained together using ``ChainedOutcomeTransform``.
        outcome_transform_options: Outcome transform classes kwargs. The keys
            are class string names and the values are dictionaries of outcome
            transform kwargs. For example::

                outcome_transform_classes = [Standardize]
                outcome_transform_options = {
                    "Standardize": {"m": 1},
                }

            For more options see ``botorch/models/transforms/outcome.py``.
        covar_module_class: Covariance module class. This gets initialized after
            parsing the ``covar_module_options`` in ``covar_module_argparse``,
            and gets passed to the model constructor as ``covar_module``.
        covar_module_options: Covariance module kwargs.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_options`` and gets passed to the model constructor.
        likelihood_options: Likelihood options.
    """

    botorch_model_class: type[Model] | None = None
    model_options: dict[str, Any] = field(default_factory=dict)
    mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
    mll_options: dict[str, Any] = field(default_factory=dict)
    input_transform_classes: list[type[InputTransform]] | None = None
    input_transform_options: dict[str, dict[str, Any]] | None = field(
        default_factory=dict
    )
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None
    outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict)
    covar_module_class: type[Kernel] | None = None
    covar_module_options: dict[str, Any] = field(default_factory=dict)
    likelihood_class: type[Likelihood] | None = None
    likelihood_options: dict[str, Any] = field(default_factory=dict)


def use_model_list(
    datasets: Sequence[SupervisedDataset],
    botorch_model_class: type[Model],
    model_configs: list[ModelConfig] | None = None,
    metric_to_model_configs: dict[str, list[ModelConfig]] | None = None,
    allow_batched_models: bool = True,
) -> bool:
    """Decide whether the outcomes should be modeled with a ``ModelList``.

    Args:
        datasets: The datasets to be modeled (before any merging into a
            single batched dataset).
        botorch_model_class: Model class to use when the first entry of
            ``model_configs`` does not specify one.
        model_configs: Model configurations shared by all outcomes.
        metric_to_model_configs: Per-metric model configurations. Any
            non-empty mapping forces a model list when there are several
            outcomes.
        allow_batched_models: Whether a batched multi-output model may be used
            when all datasets share the same ``X``.

    Returns:
        True if a ``ModelList`` should be used, False if a single model
        suffices.
    """
    configs = model_configs or []
    per_metric_configs = metric_to_model_configs or {}

    if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1:
        # A single outcome can always be handled by a single model.
        return False
    if len(configs) > 1 or len(per_metric_configs) > 0:
        # Several outcomes which may need different models.
        return True

    # From here on, one model class is shared by all outcomes.
    if configs:
        model_class = configs[0].botorch_model_class or botorch_model_class
    else:
        model_class = botorch_model_class

    if issubclass(model_class, (SaasFullyBayesianSingleTaskGP, MultiTaskGP)):
        # SAAS models do not support multiple outcomes, and multi-task models
        # are wrapped into a model list when there are several outcomes.
        return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
    if len(datasets) == 1:
        # Called before datasets are merged for batched models; a single
        # dataset here (e.g. a contextual model) should be modeled jointly.
        return False
    if issubclass(model_class, BatchedMultiOutputGPyTorchModel) and all(
        torch.equal(datasets[0].X, other.X) for other in datasets[1:]
    ):
        # Shared X: a batched model works, if allowed.
        return not allow_batched_models
    # Differing X across datasets requires a `ModelListGP`.
    return True


def choose_model_class(
    datasets: Sequence[SupervisedDataset],
    search_space_digest: SearchSpaceDigest,
) -> type[Model]:
    """Choose a BoTorch ``Model`` class from the data and the search space.

    The class is picked in this order: ``MultiTaskGP`` if there is a task
    feature, ``SingleTaskMultiFidelityGP`` if there is a fidelity feature,
    ``MixedSingleTaskGP`` if the digest lists categorical features, and
    ``SingleTaskGP`` otherwise.

    Args:
        datasets: The datasets to be modeled. Either all of them or none of
            them must provide observation noise (``Yvar``).
        search_space_digest: Digest of the search space; its
            ``task_features``, ``fidelity_features`` and
            ``categorical_features`` drive the choice.

    Returns:
        A BoTorch ``Model`` class.

    Raises:
        NotImplementedError: If there is more than one fidelity feature, more
            than one task feature, or both task and fidelity features.
        ValueError: If only some of the datasets provide ``Yvar``.
    """
    if len(search_space_digest.fidelity_features) > 1:
        raise NotImplementedError(
            "Only a single fidelity feature supported "
            f"(got: {search_space_digest.fidelity_features})."
        )
    if len(search_space_digest.task_features) > 1:
        raise NotImplementedError(
            f"Only a single task feature supported "
            f"(got: {search_space_digest.task_features})."
        )
    if search_space_digest.task_features and search_space_digest.fidelity_features:
        raise NotImplementedError(
            "Multi-task multi-fidelity optimization not yet supported."
        )

    is_fixed_noise = [ds.Yvar is not None for ds in datasets]
    all_inferred = not any(is_fixed_noise)
    if not all_inferred and not all(is_fixed_noise):
        raise ValueError(
            "Mix of known and unknown variances indicates valuation function "
            "errors. Variances should all be specified, or none should be."
        )

    # Multi-task case (when `task_features` is specified).
    if search_space_digest.task_features:
        model_class = MultiTaskGP
    # Single-task multi-fidelity cases.
    elif search_space_digest.fidelity_features:
        model_class = SingleTaskMultiFidelityGP
    # Mixed optimization case. Note that presence of categorical
    # features in search space digest indicates that downstream in the
    # stack we chose not to perform continuous relaxation on those
    # features.
    elif search_space_digest.categorical_features:
        model_class = MixedSingleTaskGP
    # Single-task single-fidelity cases.
    else:
        model_class = SingleTaskGP
    logger.debug(f"Chose BoTorch model class: {model_class}.")
    return model_class


def choose_botorch_acqf_class(
    torch_opt_config: TorchOptConfig,
) -> type[AcquisitionFunction]:
    """Choose a BoTorch ``AcquisitionFunction`` class.

    The choice depends only on ``TorchOptConfig.is_moo``: qLogNEHVI for
    multi-objective optimization, qLogNEI for single-objective optimization.
    """
    acqf_class = (
        qLogNoisyExpectedHypervolumeImprovement
        if torch_opt_config.is_moo
        else qLogNoisyExpectedImprovement
    )
    logger.debug(f"Chose BoTorch acquisition function class: {acqf_class}.")
    return acqf_class


def construct_acquisition_and_optimizer_options(
    acqf_options: TConfig, model_gen_options: TConfig | None = None
) -> tuple[TConfig, TConfig]:
    """Extract acquisition and optimizer options from ``model_gen_options``.

    Args:
        acqf_options: Base acquisition function options. A shallow copy is
            updated and returned; the passed-in mapping is not modified.
        model_gen_options: Optional generation options. Entries under
            ``Keys.ACQF_KWARGS`` are merged over ``acqf_options``; entries
            under ``Keys.OPTIMIZER_KWARGS`` become the optimizer options.

    Returns:
        A tuple ``(acquisition_options, optimizer_options)``.
    """
    acq_options = acqf_options.copy()
    if not model_gen_options:
        return acq_options, {}

    extra_acqf_kwargs = checked_cast(dict, model_gen_options.get(Keys.ACQF_KWARGS, {}))
    acq_options.update(extra_acqf_kwargs)
    # TODO: Also forward `subset_model` here once all acquisition functions
    # accept that kwarg (or once kwargs get filtered).
    optimizer_kwargs = checked_cast(
        dict, model_gen_options.get(Keys.OPTIMIZER_KWARGS, {})
    )
    return acq_options, optimizer_kwargs.copy()


def convert_to_block_design(
    datasets: Sequence[SupervisedDataset],
    force: bool = False,
) -> list[SupervisedDataset]:
    """Combine ``datasets`` into a single "block design" dataset.

    In a block design all outcomes are observed at the same inputs, so the
    per-dataset ``Y`` (and ``Yvar``) tensors are concatenated along the last
    dimension against one shared ``X``.

    Args:
        datasets: The datasets to combine. They must share ``feature_names``,
            and either all or none of them must have ``Yvar``.
        force: If True and the ``X`` tensors are not identical, keep only the
            rows shared by all datasets (emitting an ``AxWarning``) instead of
            raising.

    Returns:
        A single-element list with the combined ``SupervisedDataset``, whose
        ``outcome_names`` are the input datasets' outcome names concatenated
        in order.

    Raises:
        UnsupportedError: If only some datasets have ``Yvar``, or if the ``X``
            tensors differ and ``force`` is False.
        ValueError: If the datasets' ``feature_names`` differ.
    """
    # TODO: Figure out a better solution for this using the data containers
    # (pass outcome names as properties of the data containers).
    yvar_present = [ds.Yvar is not None for ds in datasets]
    if any(yvar_present) and not all(yvar_present):
        raise UnsupportedError(
            "Cannot convert mixed data with and without variance "
            "observations to `block design`."
        )
    is_fixed = all(yvar_present)

    Xs = [ds.X for ds in datasets]
    for other in datasets[1:]:
        if other.feature_names != datasets[0].feature_names:
            raise ValueError(
                "Feature names must be the same across all datasets, "
                f"got {other.feature_names} and {datasets[0].feature_names}"
            )
    outcome_names = [name for ds in datasets for name in ds.outcome_names]

    is_block = len({X.shape for X in Xs}) == 1 and all(
        torch.equal(X, Xs[0]) for X in Xs[1:]
    )
    if is_block:
        # Identical inputs: outcomes can be concatenated with impunity.
        X_merged = Xs[0]
        Y = torch.cat([ds.Y for ds in datasets], dim=-1)
        Yvar = (
            torch.cat([none_throws(ds.Yvar) for ds in datasets], dim=-1)
            if is_fixed
            else None
        )
    else:
        if not force:
            raise UnsupportedError(
                "Cannot convert data to non-block design data. "
                "To force this and drop data not shared between "
                "outcomes use `force=True`."
            )
        warnings.warn(
            "Forcing conversion of data not complying to a block design "
            "to block design by dropping observations that are not shared "
            "between outcomes.",
            AxWarning,
            stacklevel=3,
        )
        X_merged, row_idcs = _get_shared_rows(Xs=Xs)
        Y = torch.cat([ds.Y[i] for ds, i in zip(datasets, row_idcs)], dim=-1)
        Yvar = (
            torch.cat(
                [none_throws(ds.Yvar)[i] for ds, i in zip(datasets, row_idcs)],
                dim=-1,
            )
            if is_fixed
            else None
        )

    return [
        SupervisedDataset(
            X=X_merged,
            Y=Y,
            Yvar=Yvar,
            feature_names=datasets[0].feature_names,
            outcome_names=outcome_names,
        )
    ]


def _get_shared_rows(Xs: list[Tensor]) -> tuple[Tensor, list[Tensor]]: """Extract shared rows from a list of tensors Args: Xs: A list of m two-dimensional tensors with shapes `(n_1 x d), ..., (n_m x d)`. It is not required that the `n_i` are the same. Returns: A two-tuple containing (i) a Tensor with the rows that are shared between all the Tensors in `Xs`, and (ii) a list of index tensors that indicate the location of these rows in the respective elements of `Xs`. """ idcs_shared = [] Xs_sorted = sorted(Xs, key=len) X_shared = Xs_sorted[0].clone() for X in Xs_sorted[1:]: X_shared = X_shared[(X_shared == X.unsqueeze(-2)).all(dim=-1).any(dim=-2)] # get indices for X in Xs: same = (X_shared == X.unsqueeze(-2)).all(dim=-1).any(dim=-1) idcs_shared.append(torch.arange(same.shape[-1], device=X_shared.device)[same]) return X_shared, idcs_shared
def fit_botorch_model(
    model: Model,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fit a BoTorch model, or each sub-model of a ``ModelList`` separately.

    Fully Bayesian (sub-)models are fit with ``fit_fully_bayesian_model_nuts``,
    with ``mll_options`` forwarded as keyword arguments. ``GPyTorchModel`` and
    ``PairwiseGP`` (sub-)models are fit with ``fit_gpytorch_mll`` on an MLL of
    type ``mll_class`` constructed with ``mll_options``.

    Args:
        model: The model to fit.
        mll_class: ``MarginalLogLikelihood`` class used for non-fully-Bayesian
            models.
        mll_options: Optional kwargs for the MLL constructor (or for NUTS
            fitting of fully Bayesian models).

    Raises:
        NotImplementedError: If a (sub-)model is of an unsupported type.
    """
    # Defaulted once here; `**mll_options` below never mutates it.
    mll_options = mll_options or {}
    models = model.models if isinstance(model, ModelList) else [model]
    for m in models:
        # TODO: Support deterministic models when we support `ModelList`
        if is_fully_bayesian(m):
            fit_fully_bayesian_model_nuts(
                m,
                disable_progbar=True,
                **mll_options,
            )
        elif isinstance(m, (GPyTorchModel, PairwiseGP)):
            mll = mll_class(likelihood=m.likelihood, model=m, **mll_options)
            fit_gpytorch_mll(mll)
        else:
            raise NotImplementedError(
                f"Model of type {m.__class__.__name__} is currently not supported."
            )


def _tensor_difference(A: Tensor, B: Tensor) -> Tensor: """Used to return B sans any Xs that also appear in A""" C = torch.cat((A, B), dim=0) D, inverse_ind = torch.unique(C, return_inverse=True, dim=0) n = A.shape[0] A_indices = inverse_ind[:n].tolist() B_indices = inverse_ind[n:].tolist() Bi_set = set(B_indices) - set(A_indices) return D[list(Bi_set)]
def check_outcome_dataset_match(
    outcome_names: Sequence[str],
    datasets: Sequence[SupervisedDataset],
    exact_match: bool,
) -> None:
    """Check that the given outcome names match those of the datasets.

    Based on ``exact_match`` we either require that the outcome names are a
    subset of all outcomes, or require them to be the same. Also checks that
    there are no duplicates in the outcome names or across datasets.

    Args:
        outcome_names: A list of outcome names.
        datasets: A list of ``SupervisedDataset`` objects.
        exact_match: If True, ``outcome_names`` must be the same as the union
            of outcome names of the datasets. Otherwise, we check that
            ``outcome_names`` are a subset of all outcomes.

    Raises:
        AxError: If there are duplicate outcomes (in the datasets or in
            ``outcome_names``) or if there is no match.
    """
    all_outcomes = [name for ds in datasets for name in ds.outcome_names]
    set_all_outcomes = set(all_outcomes)
    set_all_spec_outcomes = set(outcome_names)
    if len(set_all_outcomes) != len(all_outcomes):
        raise AxError("Found duplicate outcomes in the datasets.")
    if len(set_all_spec_outcomes) != len(outcome_names):
        raise AxError("Found duplicate outcome names.")
    if not exact_match:
        if not set_all_spec_outcomes.issubset(set_all_outcomes):
            raise AxError(
                "Outcome names must be a subset of the outcome names of the "
                "datasets. "
                f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
            )
    elif set_all_spec_outcomes != set_all_outcomes:
        raise AxError(
            "Each outcome name must correspond to an outcome in the datasets. "
            f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
        )


def get_subset_datasets(
    datasets: Sequence[SupervisedDataset],
    subset_outcome_names: Sequence[str],
) -> list[SupervisedDataset]:
    """Get the datasets that hold the given subset of outcome names.

    This is used to separate out the datasets that are used by one surrogate.
    A dataset with several outcomes is returned once, and only if all of its
    outcomes belong to ``subset_outcome_names``.

    Args:
        datasets: A list of ``SupervisedDataset`` objects.
        subset_outcome_names: A list of outcome names to get datasets for.

    Returns:
        The ``SupervisedDataset`` objects covering ``subset_outcome_names``,
        in order of the first of their outcomes listed there.

    Raises:
        UnsupportedError: If a multi-outcome dataset would have to be split
            between surrogates.
    """
    check_outcome_dataset_match(
        outcome_names=subset_outcome_names, datasets=datasets, exact_match=False
    )
    # The check above rules out duplicate outcomes, so each outcome name
    # belongs to exactly one dataset.
    dataset_by_outcome = {
        name: dataset for dataset in datasets for name in dataset.outcome_names
    }

    selected: list[SupervisedDataset] = []
    covered: set[str] = set()
    for name in subset_outcome_names:
        if name in covered:
            # Already included as part of a multi-outcome dataset.
            continue
        dataset = dataset_by_outcome[name]
        if len(dataset.outcome_names) > 1 and not set(
            dataset.outcome_names
        ).issubset(subset_outcome_names):
            raise UnsupportedError(
                "Breaking up a multi-outcome dataset between "
                "surrogates is not supported."
            )
        selected.append(dataset)
        covered.update(dataset.outcome_names)
    return selected


def subset_state_dict(
    state_dict: OrderedDict[str, Tensor],
    submodel_index: int,
) -> OrderedDict[str, Tensor]:
    """Extract one submodel's state dict from a model list's state dict.

    Keeps only the entries whose key starts with ``models.<submodel_index>.``
    and strips that prefix, preserving the original entry order.

    Args:
        state_dict: A state dict.
        submodel_index: The index of the submodel to extract.

    Returns:
        The state dict for the submodel.
    """
    prefix = f"models.{submodel_index}."
    return OrderedDict(
        (key.removeprefix(prefix), value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )