Source code for ax.models.torch_base

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from ax.core.metric import Metric
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.models.base import Model as BaseModel
from ax.models.types import TConfig
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


[docs]@dataclass class TorchOptConfig: """Container for lightweight representation of optimization arguments. This is used for communicating between modelbridge and models. This is an ephemeral object and not meant to be stored / serialized. Attributes: objective_weights: If doing multi-objective optimization, these denote which objectives should be maximized and which should be minimized. Otherwise, the objective is to maximize a weighted sum of the columns of f(x). These are the weights. outcome_constraints: A tuple of (A, b). For k outcome constraints and m outputs at f(x), A is (k x m) and b is (k x 1) such that A f(x) <= b. objective_thresholds: A tensor containing thresholds forming a reference point from which to calculate pareto frontier hypervolume. Points that do not dominate the objective_thresholds contribute nothing to hypervolume. linear_constraints: A tuple of (A, b). For k linear constraints on d-dimensional x, A is (k x d) and b is (k x 1) such that A x <= b for feasible x. fixed_features: A map {feature_index: value} for features that should be fixed to a particular value during generation. pending_observations: A list of m (k_i x d) feature tensors X for m outcomes and k_i pending observations for outcome i. model_gen_options: A config dictionary that can contain model-specific options. This commonly includes `optimizer_kwargs`, which often specifies the optimizer options to be passed to the optimizer while optimizing the acquisition function. These are generally expected to mimic the signature of `optimize_acqf`, though not all models may support all possible arguments and some models may support additional arguments that are not passed to the optimizer. While constructing a generation strategy, these options can be passed in as follows: >>> model_gen_kwargs = { >>> "model_gen_options": { >>> "optimizer_kwargs": { >>> "num_restarts": 20, >>> "sequential": False, >>> "options": { >>> "batch_limit: 5, >>> "maxiter": 200, >>> }, >>> }, >>> }, >>> } rounding_func: A function that rounds an optimization result appropriately (i.e., according to `round-trip` transformations). opt_config_metrics: A dictionary of metrics that are included in the optimization config. is_moo: A boolean denoting whether this is for an MOO problem. risk_measure: An optional risk measure, used for robust optimization. """ objective_weights: Tensor outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None objective_thresholds: Optional[Tensor] = None linear_constraints: Optional[Tuple[Tensor, Tensor]] = None fixed_features: Optional[Dict[int, float]] = None pending_observations: Optional[List[Tensor]] = None model_gen_options: TConfig = field(default_factory=dict) rounding_func: Optional[Callable[[Tensor], Tensor]] = None opt_config_metrics: Optional[Dict[str, Metric]] = None is_moo: bool = False risk_measure: Optional[RiskMeasureMCObjective] = None
[docs]@dataclass(frozen=True) class TorchGenResults: """ points: (n x d) Tensor of generated points. weights: n-tensor of weights for each point. gen_metadata: Generation metadata Dictionary of model-specific metadata for the given generation candidates """ points: Tensor # (n x d)-dim weights: Tensor # n-dim gen_metadata: Dict[str, Any] = field(default_factory=dict) candidate_metadata: Optional[List[TCandidateMetadata]] = None
[docs]class TorchModel(BaseModel): """This class specifies the interface for a torch-based model. These methods should be implemented to have access to all of the features of Ax. """ dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None _supports_robust_optimization: bool = False
[docs] def fit( self, datasets: List[SupervisedDataset], search_space_digest: SearchSpaceDigest, candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, ) -> None: """Fit model to m outcomes. Args: datasets: A list of ``SupervisedDataset`` containers, each corresponding to the data of one metric (outcome). search_space_digest: A ``SearchSpaceDigest`` object containing metadata on the features in the datasets. candidate_metadata: Model-produced metadata for candidates, in the order corresponding to the Xs. """ pass
[docs] def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: """Predict Args: X: (j x d) tensor of the j points at which to make predictions. Returns: 2-element tuple containing - (j x m) tensor of outcome predictions at X. - (j x m x m) tensor of predictive covariances at X. cov[j, m1, m2] is Cov[m1@j, m2@j]. """ raise NotImplementedError
[docs] def gen( self, n: int, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, ) -> TorchGenResults: """ Generate new candidates. Args: n: Number of candidates to generate. search_space_digest: A SearchSpaceDigest object containing metadata about the search space (e.g. bounds, parameter types). torch_opt_config: A TorchOptConfig object containing optimization arguments (e.g., objective weights, constraints). Returns: A TorchGenResult container. """ raise NotImplementedError
[docs] def best_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, ) -> Optional[Tensor]: """ Identify the current best point, satisfying the constraints in the same format as to gen. Return None if no such point can be identified. Args: search_space_digest: A SearchSpaceDigest object containing metadata about the search space (e.g. bounds, parameter types). torch_opt_config: A TorchOptConfig object containing optimization arguments (e.g., objective weights, constraints). Returns: d-tensor of the best point. """ return None
[docs] def cross_validate( self, datasets: List[SupervisedDataset], X_test: Tensor, search_space_digest: SearchSpaceDigest, ) -> Tuple[Tensor, Tensor]: """Do cross validation with the given training and test sets. Training set is given in the same format as to fit. Test set is given in the same format as to predict. Args: datasets: A list of ``SupervisedDataset`` containers, each corresponding to the data of one metric (outcome). X_test: (j x d) tensor of the j points at which to make predictions. search_space_digest: A SearchSpaceDigest object containing metadata on the features in X. Returns: 2-element tuple containing - (j x m) tensor of outcome predictions at X. - (j x m x m) tensor of predictive covariances at X. cov[j, m1, m2] is Cov[m1@j, m2@j]. """ raise NotImplementedError
[docs] def update( self, datasets: List[SupervisedDataset], metric_names: List[str], search_space_digest: SearchSpaceDigest, candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, ) -> None: """Update the model. Updating the model requires both existing and additional data. The data passed into this method will become the new training data. Args: datasets: A list of ``SupervisedDataset`` containers, each corresponding to the data of one metric (outcome). `None` means that there is no additional data for the corresponding outcome. metric_names: A list of metric names, with the i-th metric corresponding to the i-th dataset. search_space_digest: A SearchSpaceDigest object containing metadata on the features in X. candidate_metadata: Model-produced metadata for candidates, in the order corresponding to the Xs. """ raise DeprecationWarning("Model.update has been deprecated. Use `fit` instead.")
[docs] def evaluate_acquisition_function( self, X: Tensor, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, acq_options: Optional[Dict[str, Any]] = None, ) -> Tensor: """Evaluate the acquisition function on the candidate set `X`. Args: X: (j x d) tensor of the j points at which to evaluate the acquisition function. search_space_digest: A dataclass used to compactly represent a search space. torch_opt_config: A TorchOptConfig object containing optimization arguments (e.g., objective weights, constraints). acq_options: Keyword arguments used to contruct the acquisition function. Returns: A single-element tensor with the acquisition value for these points. """ raise NotImplementedError # pragma: nocover