Source code for ax.modelbridge.prediction_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Dict, Set, Tuple

import numpy as np
from ax.core.observation import ObservationFeatures
from ax.modelbridge import ModelBridge

[docs]def predict_at_point( model: ModelBridge, obsf: ObservationFeatures, metric_names: Set[str] ) -> Tuple[Dict[str, float], Dict[str, float]]: """Make a prediction at a point. Returns mean and standard deviation in format expected by plotting. Args: model: ModelBridge obsf: ObservationFeatures for which to predict metric_names: Limit predictions to these metrics. Returns: A tuple containing - Map from metric name to prediction. - Map from metric name to standard error. """ y_hat = {} se_hat = {} f_pred, cov_pred = model.predict([obsf]) for metric_name in f_pred: if metric_name in metric_names: y_hat[metric_name] = f_pred[metric_name][0] se_hat[metric_name] = np.sqrt(cov_pred[metric_name][metric_name][0]) return y_hat, se_hat
[docs]def predict_by_features( model: ModelBridge, label_to_feature_dict: Dict[int, ObservationFeatures], metric_names: Set[str], ) -> Dict[int, Dict[str, Tuple[float, float]]]: """Predict for given data points and model. Args: model: Model to be used for the prediction metric_names: Names of the metrics, for which to retrieve predictions. label_to_feature_dict: Mapping from an int label to a Parameterization. These data points are predicted. Returns: A mapping from an int label to a mapping of metric names to tuples of predicted metric mean and SEM, of form: { trial_index -> { metric_name: ( mean, SEM ) } }. """ predictions_dict = {} # Store predictions to return for label in label_to_feature_dict: try: y_hat, se_hat = predict_at_point( model=model, obsf=label_to_feature_dict[label], metric_names=metric_names, ) except NotImplementedError: raise NotImplementedError( "The model associated with the current generation strategy " "step is not one that can be used for predicting values. " "For example, this may be the Sobol generator associated with the " "initialization step where quasi-random points are generated. " "Try again by calling the `AxClient.create_experiment()` " "method with the `choose_generation_strategy_kwargs=" '{"num_initialization_trials": 0}` parameter if you are looking ' "to use a generation strategy without an initialization step that " "proceeds straight to the Bayesian optimization step, but note " "that performance of Bayesian optimization can be suboptimal if " "search space is not sampled well in the initialization phase." ) predictions_dict[label] = { metric: ( y_hat[metric], se_hat[metric], ) for metric in metric_names } return predictions_dict