# Source code for ax.analysis.healthcheck.constraints_feasibility
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import json
from typing import Tuple
import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
HealthcheckStatus,
)
from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm
from ax.analysis.plotly.utils import is_predictive
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws
class ConstraintsFeasibilityAnalysis(HealthcheckAnalysis):
    """
    Analysis for checking the feasibility of the constraints for the experiment.
    A constraint is considered feasible if the probability of constraints violation
    is below the threshold for at least one arm.
    """

    def __init__(self, prob_threshold: float = 0.95) -> None:
        r"""
        Args:
            prob_threshold: The threshold for the probability of constraint
                violation. It is stored on the instance and forwarded to
                ``constraints_feasibility`` by ``compute``.

        Returns None
        """
        self.prob_threshold = prob_threshold

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategyInterface | None = None,
    ) -> HealthcheckAnalysisCard:
        r"""
        Compute the feasibility of the constraints for the experiment.

        The probability threshold is not an argument of this method; the value
        given to the constructor (``self.prob_threshold``) is used. Constraints
        are considered feasible if the probability of constraint violation is
        below the threshold for at least one arm.

        Args:
            experiment: Ax experiment. Required, and must have an optimization
                config.
            generation_strategy: Ax generation strategy. Required only when the
                optimization config has at least one outcome constraint; it
                must be a ``GenerationStrategy``.

        Returns:
            A HealthcheckAnalysisCard object with the information on infeasible
            metrics, i.e., metrics for which the constraints are infeasible for
            all test groups (arms). When there are no outcome constraints, a
            passing card with the subtitle "No constraints are specified." is
            returned without consulting the generation strategy.

        Raises:
            UserInputError: If the experiment is missing, has no optimization
                config, the generation strategy is missing or is not a
                ``GenerationStrategy``, or the current model does not support
                prediction.
        """
        # Defaults describe the passing case; they are overwritten further down
        # only if the constraints turn out to be infeasible.
        status = HealthcheckStatus.PASS
        subtitle = "All constraints are feasible."
        title_status = "Success"
        level = AnalysisCardLevel.LOW
        df = pd.DataFrame({"status": [status]})

        if experiment is None:
            raise UserInputError(
                "ConstraintsFeasibilityAnalysis requires an Experiment."
            )
        if experiment.optimization_config is None:
            raise UserInputError(
                "ConstraintsFeasibilityAnalysis requires an Experiment with an "
                "optimization config."
            )

        # Nothing to check: report success early, before a generation strategy
        # is required.
        if (
            experiment.optimization_config.outcome_constraints is None
            or len(experiment.optimization_config.outcome_constraints) == 0
        ):
            subtitle = "No constraints are specified."
            return HealthcheckAnalysisCard(
                name="ConstraintsFeasibility",
                title=f"Ax Constraints Feasibility {title_status}",
                blob=json.dumps({"status": status}),
                subtitle=subtitle,
                df=df,
                level=level,
            )

        if generation_strategy is None:
            raise UserInputError(
                "ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
            )
        generation_strategy = checked_cast(
            GenerationStrategy,
            generation_strategy,
            exception=UserInputError(
                "ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
            ),
        )

        # Fit the current model on the experiment's data if none is fitted yet.
        if generation_strategy.model is None:
            generation_strategy._fit_current_model(data=experiment.lookup_data())
        model = none_throws(generation_strategy.model)
        if not is_predictive(model=model):
            # The third fragment must be an f-string, otherwise the message
            # shows the literal text "{model._model_key}".
            raise UserInputError(
                "ConstraintsFeasibility requires a GenerationStrategy which is "
                "in a state where the current model supports prediction. "
                f"The current model is {model._model_key} and does not support "
                "prediction."
            )

        optimization_config = checked_cast(
            OptimizationConfig, experiment.optimization_config
        )
        constraints_feasible, df = constraints_feasibility(
            optimization_config=optimization_config,
            model=model,
            prob_threshold=self.prob_threshold,
        )

        # Every row starts as PASS; rows exceeding the threshold are flagged
        # below only when the overall check fails.
        df["status"] = status
        if not constraints_feasible:
            status = HealthcheckStatus.WARNING
            subtitle = (
                "Constraints are infeasible for all test groups (arms) with respect "
                f"to the probability threshold {self.prob_threshold}. "
                "We suggest relaxing the constraint bounds for the constraints."
            )
            title_status = "Warning"
            df.loc[
                df["overall_probability_constraints_violated"] > self.prob_threshold,
                "status",
            ] = status

        return HealthcheckAnalysisCard(
            name="ConstraintsFeasibility",
            title=f"Ax Constraints Feasibility {title_status}",
            blob=json.dumps({"status": status}),
            subtitle=subtitle,
            df=df,
            level=level,
        )
def constraints_feasibility(
    optimization_config: OptimizationConfig,
    model: ModelBridge,
    prob_threshold: float = 0.99,
) -> Tuple[bool, pd.DataFrame]:
    r"""
    Decide whether the outcome constraints can be satisfied by at least one arm.

    Args:
        optimization_config: Ax optimization config holding the outcome
            constraints to check.
        model: Ax model to use for predictions (and for derelativizing
            relative constraints).
        prob_threshold: Threshold for the probability of constraint violation.

    Returns:
        A tuple ``(feasible, df)``. ``feasible`` is False when every arm other
        than the model's status quo has an
        ``overall_probability_constraints_violated`` strictly above
        ``prob_threshold``, and True otherwise. ``df`` is a dataframe built
        from the per-arm predictions, including the violation probabilities.
        Note that if there are no arms besides the status quo, the check is
        vacuously satisfied and ``feasible`` is False.

    Raises:
        UserInputError: If the optimization config has no outcome constraints.
    """
    raw_constraints = optimization_config.outcome_constraints
    if raw_constraints is None or len(raw_constraints) == 0:
        raise UserInputError("No constraints are specified.")

    # Relative constraints must be converted to absolute bounds before the
    # predictions can be compared against them.
    if any(oc.relative for oc in raw_constraints):
        resolved_config = Derelativize().transform_optimization_config(
            optimization_config=optimization_config,
            modelbridge=model,
        )
    else:
        resolved_config = optimization_config
    resolved_constraints = resolved_config.outcome_constraints

    # The name of the first constrained metric is passed as the metric to
    # predict, alongside the full list of constraints.
    first_metric_name = resolved_constraints[0].metric.name
    predictions = get_predictions_by_arm(
        model=model,
        metric_name=first_metric_name,
        outcome_constraints=resolved_constraints,
    )
    df = pd.DataFrame(predictions)

    # The status quo arm is excluded from the check.
    every_arm_violates = all(
        row["overall_probability_constraints_violated"] > prob_threshold
        for row in predictions
        if row["arm_name"] != model.status_quo_name
    )
    return (not every_arm_violates), df