Source code for ax.analysis.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools

from ax.analysis.analysis import Analysis
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.analysis.plotly.interaction import InteractionPlot
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.analysis.plotly.scatter import ScatterPlot
from ax.analysis.summary import Summary
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective


[docs] def choose_analyses(experiment: Experiment) -> list[Analysis]: """ Choose a default set of Analyses to compute based on the current state of the Experiment. """ if (optimization_config := experiment.optimization_config) is None: return [] if isinstance(optimization_config.objective, MultiObjective) or isinstance( optimization_config.objective, ScalarizedObjective ): # Pareto frontiers for each objective objective_plots = [ *[ ScatterPlot(x_metric_name=x, y_metric_name=y, show_pareto_frontier=True) for x, y in itertools.combinations( optimization_config.objective.metric_names, 2 ) ], ] other_scatters = [] interactions = [ InteractionPlot(metric_name=name) for name in optimization_config.objective.metric_names ] else: objective_name = optimization_config.objective.metric.name # ParallelCoorindates and leave-one-out cross validation objective_plots = [ ParallelCoordinatesPlot(metric_name=objective_name), ] # Up to six ScatterPlots for other metrics versus the objective, # prioritizing optimization config metrics over tracking metrics tracking_metric_names = [metric.name for metric in experiment.tracking_metrics] other_scatters = [ ScatterPlot( x_metric_name=objective_name, y_metric_name=name, show_pareto_frontier=False, ) for name in [ *optimization_config.metrics, *tracking_metric_names, ] if name != objective_name ][:6] interactions = [InteractionPlot(metric_name=objective_name)] # Leave-one-out cross validation for each objective and outcome constraint cv_plots = [ CrossValidationPlot(metric_name=name) for name in optimization_config.metrics ] return [*objective_plots, *other_scatters, *interactions, *cv_plots, Summary()]