Source code for ax.analysis.plotly.utils
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import numpy as np
import torch
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.base import ModelBridge
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
from numpy.typing import NDArray

# Because normal distributions have long tails, every arm has a non-zero
# probability of violating the constraint. But below a certain threshold, we
# consider the probability of violation to be negligible.
MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01
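
# Illustrative sketch (not part of the original module): one way downstream code
# might apply the threshold above when deciding whether to flag an arm. The
# helper name ``_example_is_violation_negligible`` is hypothetical.
def _example_is_violation_negligible(violation_probability: float) -> bool:
    # Probabilities below the module-level threshold are treated as negligible.
    return violation_probability < MINIMUM_CONTRAINT_VIOLATION_THRESHOLD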

def get_constraint_violated_probabilities(
    predictions: list[tuple[dict[str, float], dict[str, float]]],
    outcome_constraints: list[OutcomeConstraint],
) -> dict[str, list[float]]:
    """Get the probability that each arm violates the outcome constraints.

    Args:
        predictions: List of predictions for each observation feature
            generated by predict_at_point. It should include predictions
            for all outcome constraint metrics.
        outcome_constraints: List of outcome constraints to check.

    Returns:
        A dict mapping each constraint's metric name to the probabilities that
        each arm violates that constraint, plus an "any_constraint_violated"
        entry with the probability that each arm violates *any* of the
        provided constraints.
    """
    if len(outcome_constraints) == 0:
        return {"any_constraint_violated": [0.0] * len(predictions)}
    if any(constraint.relative for constraint in outcome_constraints):
        raise UserInputError(
            "`get_constraint_violated_probabilities()` does not support relative "
            "outcome constraints. Use `Derelativize().transform_optimization_config()` "
            "before passing constraints to this method."
        )

    metrics = [constraint.metric.name for constraint in outcome_constraints]
    means = torch.as_tensor(
        [
            [prediction[0][metric_name] for metric_name in metrics]
            for prediction in predictions
        ]
    )
    sigmas = torch.as_tensor(
        [
            [prediction[1][metric_name] for metric_name in metrics]
            for prediction in predictions
        ]
    )

    feasibility_probabilities: dict[str, NDArray] = {}
    for constraint in outcome_constraints:
        if constraint.op == ComparisonOp.GEQ:
            con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
            con_lower = torch.tensor([constraint.bound])
            con_upper_inds = torch.as_tensor([])
            con_upper = torch.as_tensor([])
        else:
            con_lower_inds = torch.as_tensor([])
            con_lower = torch.as_tensor([])
            con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
            con_upper = torch.tensor([constraint.bound])

        feasibility_probabilities[constraint.metric.name] = (
            compute_log_prob_feas_from_bounds(
                means=means,
                sigmas=sigmas,
                con_lower_inds=con_lower_inds,
                con_upper_inds=con_upper_inds,
                con_lower=con_lower,
                con_upper=con_upper,
                # "both" can also be expressed by 2 separate constraints...
                con_both_inds=torch.as_tensor([]),
                con_both=torch.as_tensor([]),
            )
            .exp()
            .numpy()
        )

    # Assuming independence across constraints, the probability that all
    # constraints are satisfied is the product of the per-constraint
    # feasibility probabilities.
    feasibility_probabilities["any_constraint_violated"] = np.prod(
        list(feasibility_probabilities.values()), axis=0
    )

    return {
        metric_name: (1 - feasibility_probabilities[metric_name]).tolist()
        for metric_name in feasibility_probabilities
    }

def is_predictive(model: ModelBridge) -> bool:
    """Check if a model is predictive, i.e. whether ``predict()`` is implemented.

    NOTE: This does not mean it is capable of out-of-sample prediction.
    """
    try:
        model.predict(observation_features=[])
    except NotImplementedError:
        return False
    except Exception:
        # Any other exception means predict() is implemented, even if the call
        # itself fails for unrelated reasons (e.g. empty observation features).
        return True
    return True
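
# Illustrative sketch (not part of the original module): ``is_predictive`` can be
# used to guard analyses that need model predictions. The helper name
# ``_example_guarded_predict`` and its ``features`` argument are hypothetical.
def _example_guarded_predict(model: ModelBridge, features: list) -> object:
    if not is_predictive(model=model):
        # The model cannot produce predictions at all; let the caller fall back.
        return None
    return model.predict(observation_features=features)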

def select_metric(experiment: Experiment) -> str:
    """Select the most relevant metric to plot from an Experiment."""
    if experiment.optimization_config is None:
        raise ValueError(
            "Cannot infer metric to plot from Experiment without OptimizationConfig"
        )
    objective = experiment.optimization_config.objective
    if isinstance(objective, MultiObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from MultiObjective, please "
            "specify a metric"
        )
    if isinstance(objective, ScalarizedObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from ScalarizedObjective, please "
            "specify a metric"
        )
    return experiment.optimization_config.objective.metric.name
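
# Illustrative sketch (not part of the original module): ``select_metric`` is
# typically used to choose a default metric for a plot when the caller does not
# specify one. ``_example_metric_for_plot`` and ``requested_metric`` are
# hypothetical names.
def _example_metric_for_plot(
    experiment: Experiment, requested_metric: str | None = None
) -> str:
    # Prefer an explicitly requested metric; otherwise infer one from the
    # experiment's single-objective optimization config.
    return requested_metric if requested_metric is not None else select_metric(experiment)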