Source code for ax.models.torch.cbo_lcem

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Optional

import torch
from ax.models.torch.botorch import BotorchModel
from botorch.fit import fit_gpytorch_mll
from botorch.models.contextual_multioutput import LCEMGP
from botorch.models.model_list_gp_regression import ModelListGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from torch import Tensor


# Lower bound applied (via clamping) to observed noise variances before they
# are passed to the per-outcome GP models in LCEMBO.get_and_fit_model.
MIN_OBSERVED_NOISE_LEVEL: float = 1.0e-7


class LCEMBO(BotorchModel):
    r"""Does Bayesian optimization with LCE-M GP.

    Each outcome is modeled by its own ``LCEMGP``, using the single task
    feature as the context input. The per-outcome models are combined into a
    ``ModelListGP``, which is fit with ``fit_gpytorch_mll`` on a
    ``SumMarginalLogLikelihood``.

    Args:
        context_cat_feature: Optional tensor forwarded as-is to the
            ``context_cat_feature`` argument of every ``LCEMGP``.
        context_emb_feature: Optional tensor forwarded as-is to the
            ``context_emb_feature`` argument of every ``LCEMGP``.
        embs_dim_list: Optional list of ints forwarded as-is to the
            ``embs_dim_list`` argument of every ``LCEMGP``.
    """

    def __init__(
        self,
        context_cat_feature: Optional[Tensor] = None,
        context_emb_feature: Optional[Tensor] = None,
        embs_dim_list: Optional[list[int]] = None,
    ) -> None:
        # These are read by ``get_and_fit_model``, which is registered below as
        # the model constructor, so set them before delegating to the base.
        self.context_cat_feature = context_cat_feature
        self.context_emb_feature = context_emb_feature
        self.embs_dim_list = embs_dim_list
        super().__init__(model_constructor=self.get_and_fit_model)

    def get_and_fit_model(
        self,
        Xs: list[Tensor],
        Ys: list[Tensor],
        Yvars: list[Tensor],
        task_features: list[int],
        fidelity_features: list[int],
        metric_names: list[str],
        state_dict: Optional[dict[str, Tensor]] = None,
        fidelity_model_id: Optional[int] = None,
        **kwargs: Any,
    ) -> ModelListGP:
        """Get a fitted multi-task contextual GP model for each outcome.

        Args:
            Xs: List of X data, one tensor per outcome.
            Ys: List of Y data, one tensor per outcome.
            Yvars: List of observed noise variances, one tensor per outcome.
                Values are clamped from below to ``MIN_OBSERVED_NOISE_LEVEL``.
                A tensor that is entirely NaN is passed to the model as
                ``train_Yvar=None``. The input tensors are not modified.
            task_features: List of columns of X that are tasks. Exactly one
                is required; it is used as the context (task) feature.
            fidelity_features: Not used by this method.
            metric_names: Not used by this method.
            state_dict: Not used by this method.
            fidelity_model_id: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            ModelListGP in which each model is a fitted LCEM GP model.

        Raises:
            NotImplementedError: If more than one task feature is given.
            ValueError: If no task feature is given.
        """
        if len(task_features) == 1:
            task_feature = task_features[0]
        elif len(task_features) > 1:
            raise NotImplementedError(
                f"LCEMBO only supports 1 task feature (got {task_features})"
            )
        else:
            raise ValueError("LCEMBO requires context input as task features")

        models = []
        for i, X in enumerate(Xs):
            # Validate input Yvars. Use the out-of-place ``clamp_min`` so the
            # caller's ``Yvars`` tensors are left untouched (``clamp_min_``
            # would overwrite them in place).
            Yvar = Yvars[i].clamp_min(MIN_OBSERVED_NOISE_LEVEL)
            # An all-NaN Yvar is treated as "no observed noise" and passed as
            # ``train_Yvar=None``.
            # NOTE(review): a partially-NaN Yvar is passed through unchanged;
            # confirm LCEMGP fitting tolerates that or reject it upstream.
            is_nan = torch.isnan(Yvar)
            all_nan_Yvar = torch.all(is_nan)
            all_tasks, _, _ = LCEMGP.get_all_tasks(
                train_X=X, task_feature=task_feature
            )
            gp_m = LCEMGP(
                train_X=X,
                train_Y=Ys[i],
                train_Yvar=None if all_nan_Yvar else Yvar,
                task_feature=task_feature,
                context_cat_feature=self.context_cat_feature,
                context_emb_feature=self.context_emb_feature,
                embs_dim_list=self.embs_dim_list,
                # Specify output tasks so that model.num_outputs = 1,
                # since the model only models a single outcome.
                output_tasks=all_tasks[:1],
            )
            models.append(gp_m)

        # Combine the per-outcome models into a ModelListGP.
        model = ModelListGP(*models)
        # ``Module.to(tensor)`` matches the dtype and device of the first X.
        model.to(Xs[0])
        mll = SumMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_mll(mll)
        return model