Source code for ax.analysis.plotly.arm_effects.predicted_effects

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from itertools import chain
from typing import Any

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.arm_effects.utils import (
    get_predictions_by_arm,
    prepare_arm_effects_plot,
)

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import is_predictive
from ax.core import OutcomeConstraint
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws


class PredictedEffectsPlot(PlotlyAnalysis):
    """
    Plotly Predicted Effects plot for a single metric, with one point per unique arm
    across all trials. It plots all observed points, as well as predictions for the
    most recently generated trial.

    This plot is useful for understanding how arms in a candidate trial can be
    expected to perform.

    The DataFrame computed will contain one row per arm and the following columns:
        - source: In-sample or model key that generated the candidate
        - arm_name: The name of the arm
        - mean: The observed or predicted mean of the metric specified
        - sem: The observed or predicted sem of the metric specified
        - error_margin: The 95% CI of the metric specified for the arm
        - size_column: The size of the circle in the plot, which represents the
            probability that the arm is feasible (does not violate any constraints).
        - parameters: A string representation of the parameters for the arm to be
            viewed in the tooltip.
        - constraints_violated: A string representation of the probability each
            constraint is violated for the arm, to be viewed in the tooltip.
    """

    CARD_NAME = "PredictedEffectsPlot"
    trial_index: int | None = None

    def __init__(self, metric_name: str) -> None:
        """
        Args:
            metric_name: The name of the metric to plot. If not specified the
                objective will be used. Note that the metric cannot be inferred for
                multi-objective or scalarized-objective experiments.
        """
        self.metric_name = metric_name

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategyInterface | None = None,
    ) -> PlotlyAnalysisCard:
        if experiment is None:
            raise UserInputError("PredictedEffectsPlot requires an Experiment.")

        generation_strategy = checked_cast(
            GenerationStrategy,
            generation_strategy,
            exception=UserInputError(
                "PredictedEffectsPlot requires a GenerationStrategy."
            ),
        )

        try:
            trial_indices = [
                t.index
                for t in experiment.trials.values()
                if t.status != TrialStatus.ABANDONED
            ]
            candidate_trial = experiment.trials[max(trial_indices)]
            # This is so the card will have a trial_index attribute
            self.trial_index = candidate_trial.index
        except ValueError:
            raise UserInputError(
                f"PredictedEffectsPlot cannot be used for {experiment} "
                "because it has no trials."
            )

        if generation_strategy.model is None:
            generation_strategy._fit_current_model(data=experiment.lookup_data())

        model = none_throws(generation_strategy.model)
        if not is_predictive(model=model):
            raise UserInputError(
                "PredictedEffectsPlot requires a GenerationStrategy which is "
                "in a state where the current model supports prediction. The current "
                f"model is {model._model_key} and does not support prediction."
            )

        outcome_constraints = (
            []
            if experiment.optimization_config is None
            else Derelativize()
            .transform_optimization_config(
                # TODO[T203521207]: move cloning into transform_optimization_config
                optimization_config=none_throws(
                    experiment.optimization_config
                ).clone(),
                modelbridge=model,
            )
            .outcome_constraints
        )
        df = _prepare_data(
            model=model,
            metric_name=self.metric_name,
            candidate_trial=candidate_trial,
            outcome_constraints=outcome_constraints,
        )
        fig = prepare_arm_effects_plot(
            df=df,
            metric_name=self.metric_name,
            outcome_constraints=outcome_constraints,
        )

        level = AnalysisCardLevel.HIGH
        nudge = -2
        if experiment.optimization_config is not None:
            if (
                self.metric_name
                in experiment.optimization_config.objective.metric_names
            ):
                nudge = 0
            elif self.metric_name in experiment.optimization_config.metrics:
                nudge = -1

        return self._create_plotly_analysis_card(
            title=f"Predicted Effects for {self.metric_name}",
            subtitle="View a candidate trial and its arms' predicted metric values",
            level=level + nudge,
            df=df,
            fig=fig,
        )


def _prepare_data(
    model: ModelBridge,
    metric_name: str,
    candidate_trial: BaseTrial,
    outcome_constraints: list[OutcomeConstraint],
) -> pd.DataFrame:
    """Prepare data for plotting.

    Data should include columns for:
        - source: In-sample or model key that generated the candidate
        - arm_name: Name of the arm
        - mean: Predicted metric value
        - error_margin: 1.96 * predicted sem for plotting 95% CI
        - **PARAMETER_NAME: The value of each parameter for the arm. Will be used
            for the tooltip.

    There will be one row for each arm in the model's training data and one for
    each arm in the generator runs of the candidate trial. If an arm is in both
    the training data and the candidate trial, it will only appear once for the
    candidate trial.

    Args:
        model: ModelBridge being used for prediction
        metric_name: Name of metric to plot
        candidate_trial: Trial to plot candidates for by generator run
    """
    predictions_for_observed_arms: list[dict[str, Any]] = get_predictions_by_arm(
        model=model,
        metric_name=metric_name,
        outcome_constraints=outcome_constraints,
    )
    candidate_generator_run_predictions: list[list[dict[str, Any]]] = (
        []
        if candidate_trial is None
        else [
            get_predictions_by_arm(
                model=model,
                metric_name=metric_name,
                outcome_constraints=outcome_constraints,
                gr=gr,
            )
            for gr in candidate_trial.generator_runs
        ]
    )
    df = pd.DataFrame.from_records(
        list(
            chain(
                predictions_for_observed_arms,
                *candidate_generator_run_predictions,
            )
        )
    )
    df.drop_duplicates(subset="arm_name", keep="last", inplace=True)
    return df
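
A minimal usage sketch, not part of the module source above: it assumes you already have an Ax experiment with at least one non-abandoned trial and a generation strategy whose current model supports prediction. The names experiment, generation_strategy, and "my_metric" below are placeholders, not definitions from this module.

# Usage sketch (illustrative only; `experiment`, `generation_strategy`, and
# "my_metric" are assumed to already exist in your code).
from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot

analysis = PredictedEffectsPlot(metric_name="my_metric")
card = analysis.compute(
    experiment=experiment,
    generation_strategy=generation_strategy,
)
# `card` is the PlotlyAnalysisCard built in compute() above from the computed
# DataFrame (df=) and Plotly figure (fig=) passed to _create_plotly_analysis_card.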