Source code for ax.models.torch.cbo_lcem

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

import torch
from ax.models.torch.botorch import BotorchModel
from botorch.fit import fit_gpytorch_model
from botorch.models.contextual_multioutput import LCEMGP, FixedNoiseLCEMGP
from botorch.models.model_list_gp_regression import ModelListGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from torch import Tensor


MIN_OBSERVED_NOISE_LEVEL = 1e-7


[docs]class LCEMBO(BotorchModel): r"""Does Bayesian optimization with LCE-M GP.""" def __init__( self, context_cat_feature: Optional[Tensor] = None, context_emb_feature: Optional[Tensor] = None, embs_dim_list: Optional[List[int]] = None, ) -> None: self.context_cat_feature = context_cat_feature self.context_emb_feature = context_emb_feature self.embs_dim_list = embs_dim_list super().__init__(model_constructor=self.get_and_fit_model)
[docs] def get_and_fit_model( self, Xs: List[Tensor], Ys: List[Tensor], Yvars: List[Tensor], task_features: List[int], fidelity_features: List[int], metric_names: List[str], state_dict: Optional[Dict[str, Tensor]] = None, fidelity_model_id: Optional[int] = None, **kwargs: Any, ) -> ModelListGP: """Get a fitted multi-task contextual GP model for each outcome. Args: Xs: List of X data, one tensor per outcome. Ys: List of Y data, one tensor per outcome. Yvars:List of Noise variance of Yvar data, one tensor per outcome. task_features: List of columns of X that are tasks. Returns: ModeListGP that each model is a fitted LCEM GP model. """ if len(task_features) == 1: task_feature = task_features[0] elif len(task_features) > 1: raise NotImplementedError( f"LCEMBO only supports 1 task feature (got {task_features})" ) else: raise ValueError("LCEMBO requires context input as task features") models = [] for i, X in enumerate(Xs): # validate input Yvars Yvar = Yvars[i].clamp_min_(MIN_OBSERVED_NOISE_LEVEL) is_nan = torch.isnan(Yvar) all_nan_Yvar = torch.all(is_nan) if all_nan_Yvar: gp_m = LCEMGP( train_X=X, train_Y=Ys[i], task_feature=task_feature, context_cat_feature=self.context_cat_feature, context_emb_feature=self.context_emb_feature, embs_dim_list=self.embs_dim_list, ) else: gp_m = FixedNoiseLCEMGP( train_X=X, train_Y=Ys[i], train_Yvar=Yvar, task_feature=task_feature, context_cat_feature=self.context_cat_feature, context_emb_feature=self.context_emb_feature, embs_dim_list=self.embs_dim_list, ) models.append(gp_m) # Use a ModelListGP model = ModelListGP(*models) model.to(Xs[0]) mll = SumMarginalLogLikelihood(model.likelihood, model) fit_gpytorch_model(mll) return model