#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from logging import Logger
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRunType
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import ScalarizedObjective
from ax.core.trial import BaseTrial
from ax.modelbridge import ModelBridge
from ax.modelbridge.cross_validation import cross_validate
from ax.plot.contour import interact_contour_plotly
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.plot.feature_importances import plot_feature_importance_by_feature_plotly
from ax.plot.pareto_frontier import (
scatter_plot_with_pareto_frontier_plotly,
_pareto_frontier_plot_input_processing,
)
from ax.plot.pareto_utils import _extract_observed_pareto_2d
from ax.plot.scatter import interact_fitted_plotly, plot_multiple_metrics
from ax.plot.slice import interact_slice_plotly
from ax.plot.trace import optimization_trace_single_method_plotly
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
logger: Logger = get_logger(__name__)
FEATURE_IMPORTANCE_CAPTION = (
"<b>NOTE:</b> This plot is intended for advanced users. Specifically,<br>"
"it is a measure of sensitivity/smoothness, so parameters of<br>"
"relatively low importance may still be important to tune."
)
CROSS_VALIDATION_CAPTION = (
'<b>NOTE:</b> Models that <a href="https://en.wikipedia.org/wiki/Winsorizing">'
"winsorize</a> may appear to plateau<br>far from the optimum. Remove outliers "
"using the<br>zoom tool before judging the fit."
)
def _get_hypervolume_trace() -> None:
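    """Placeholder: objective trace plots for multi-objective optimization are
    not yet implemented; this logs a warning and returns ``None``."""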
logger.warning(
"Objective trace plots not yet implemented for multi-objective optimization. "
"Returning `None`."
)
# pyre-ignore[11]: Annotation `go.Figure` is not defined as a type.
def _get_cross_validation_plots(model: ModelBridge) -> List[go.Figure]:
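    """Run cross-validation (leave-one-out by default) on ``model`` and return
    the interactive cross-validation plot."""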
cv = cross_validate(model=model)
return [
interact_cross_validation_plotly(
cv_results=cv, caption=CROSS_VALIDATION_CAPTION
)
]
def _get_objective_trace_plot(
experiment: Experiment,
data: Data,
model_transitions: List[int],
true_objective_metric_name: Optional[str] = None,
) -> Iterable[go.Figure]:
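    """Return plots of the best objective value found vs. iteration, one per
    objective-related metric. For multi-objective experiments, a 2d Pareto
    frontier scatter is returned instead."""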
if experiment.is_moo_problem:
# TODO: implement `_get_hypervolume_trace()`
return [_pareto_frontier_scatter_2d_plotly(experiment=experiment)]
optimization_config = experiment.optimization_config
if optimization_config is None:
return []
metric_names = (
metric_name
for metric_name in [
optimization_config.objective.metric.name,
true_objective_metric_name,
]
if metric_name is not None
)
plots = [
optimization_trace_single_method_plotly(
y=np.array([data.df[data.df["metric_name"] == metric_name]["mean"]]),
title=f"Best {metric_name} found vs. # of iterations",
ylabel=metric_name,
model_transitions=model_transitions,
            # Try to use the metric's lower_is_better property, but fall back on
            # the objective's minimize property if relevant
optimization_direction=(
(
"minimize"
if experiment.metrics[metric_name].lower_is_better is True
else "maximize"
)
if experiment.metrics[metric_name].lower_is_better is not None
else (
"minimize" if optimization_config.objective.minimize else "maximize"
)
),
plot_trial_points=True,
)
for metric_name in metric_names
]
return [plot for plot in plots if plot is not None]
def _get_objective_v_param_plots(
experiment: Experiment,
model: ModelBridge,
) -> List[go.Figure]:
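    """Return an interactive slice plot of the objective vs. each range
    parameter, plus per-metric contour plots when the search space has more
    than one range parameter."""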
search_space = experiment.search_space
range_params = list(search_space.range_parameters.keys())
if len(range_params) < 1:
# if search space contains no range params
logger.warning(
"`_get_objective_v_param_plot` requires a search space with at least one "
"`RangeParameter`. Returning an empty list."
)
return []
# parameter slice plot
output_plots = [
interact_slice_plotly(
model=model,
)
]
if len(range_params) > 1:
# contour plots
try:
output_plots += [
interact_contour_plotly(
model=not_none(model),
metric_name=metric_name,
)
for metric_name in model.metric_names
]
        # Handle `RuntimeError`s mentioning `mean shape torch.Size`, pending
        # resolution of https://github.com/cornellius-gp/gpytorch/issues/1853
except RuntimeError as e:
logger.warning(f"Contour plotting failed with error: {e}.")
return output_plots
def _get_suffix(input_str: str, delim: str = ".", n_chunks: int = 1) -> str:
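    """Return the last ``n_chunks`` ``delim``-separated chunks of ``input_str``,
    e.g. ``_get_suffix("a.b.c", n_chunks=2) == "b.c"``."""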
return delim.join(input_str.split(delim)[-n_chunks:])
def _get_shortest_unique_suffix_dict(
input_str_list: List[str], delim: str = "."
) -> Dict[str, str]:
"""Maps a list of strings to their shortest unique suffixes
Maps all original strings to the smallest number of chunks, as specified by
delim, that are not a suffix of any other original string. If the original
string was a suffix of another string, map it to its unaltered self.
Args:
input_str_list: a list of strings to create the suffix mapping for
delim: the delimiter used to split up the strings into meaningful chunks
Returns:
dict: A dict with the original strings as keys and their abbreviations as
values
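
    Example:
        A doctest-style sketch (metric names are illustrative); shared suffixes
        are extended chunk by chunk until unique:

        >>> _get_shortest_unique_suffix_dict(["branin.mean", "branin.sem"])
        {'branin.mean': 'mean', 'branin.sem': 'sem'}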
"""
# all input strings must be unique
assert len(input_str_list) == len(set(input_str_list))
if delim == "":
raise ValueError("delim must be a non-empty string.")
suffix_dict = defaultdict(list)
# initialize suffix_dict with last chunk
for istr in input_str_list:
suffix_dict[_get_suffix(istr, delim=delim, n_chunks=1)].append(istr)
max_chunks = max(len(istr.split(delim)) for istr in input_str_list)
if max_chunks == 1:
return {istr: istr for istr in input_str_list}
# the upper range of this loop is `max_chunks + 2` because:
# - `i` needs to take the value of `max_chunks`, hence one +1
# - the contents of the loop are run one more time to check if `all_unique`,
# hence the other +1
for i in range(2, max_chunks + 2):
new_dict = defaultdict(list)
all_unique = True
for suffix, suffix_str_list in suffix_dict.items():
if len(suffix_str_list) > 1:
all_unique = False
for istr in suffix_str_list:
new_dict[_get_suffix(istr, delim=delim, n_chunks=i)].append(istr)
else:
new_dict[suffix] = suffix_str_list
if all_unique:
if len(set(input_str_list)) != len(suffix_dict.keys()):
break
return {
suffix_str_list[0]: suffix
for suffix, suffix_str_list in suffix_dict.items()
}
suffix_dict = new_dict
# If this function has not yet exited, some input strings still share a suffix.
# This is not expected, but in this case, the function will return the identity
# mapping, i.e., a dict with the original strings as both keys and values.
    logger.warning(
        "Unexpected failure to disambiguate suffixes. Returning dictionary with "
        "original strings as keys and values."
    )
return {istr: istr for istr in input_str_list}
def get_standard_plots(
experiment: Experiment,
model: Optional[ModelBridge],
data: Optional[Data] = None,
model_transitions: Optional[List[int]] = None,
true_objective_metric_name: Optional[str] = None,
) -> List[go.Figure]:
"""Extract standard plots for single-objective optimization.
Extracts a list of plots from an ``Experiment`` and ``ModelBridge`` of general
interest to an Ax user. Currently not supported are
- TODO: multi-objective optimization
- TODO: ChoiceParameter plots
Args:
- experiment: The ``Experiment`` from which to obtain standard plots.
- model: The ``ModelBridge`` used to suggest trial parameters.
- data: If specified, data, to which to fit the model before generating plots.
- model_transitions: The arm numbers at which shifts in generation_strategy
occur.
Returns:
- a plot of objective value vs. trial index, to show experiment progression
- a plot of objective value vs. range parameter values, only included if the
model associated with generation_strategy can create predictions. This
consists of:
- a plot_slice plot if the search space contains one range parameter
- an interact_contour plot if the search space contains multiple
range parameters
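
    Example:
        A minimal usage sketch, assuming an ``AxClient``-style workflow whose
        generation strategy has already fit a model (names are illustrative):

        >>> plots = get_standard_plots(
        ...     experiment=ax_client.experiment,
        ...     model=ax_client.generation_strategy.model,
        ... )  # doctest: +SKIP
        >>> for plot in plots:  # doctest: +SKIP
        ...     plot.show()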
"""
if (
true_objective_metric_name is not None
and true_objective_metric_name not in experiment.metrics.keys()
):
raise ValueError(
f"true_objective_metric_name='{true_objective_metric_name}' is not present "
f"in experiment.metrics={experiment.metrics}. Please add a valid "
"true_objective_metric_name or remove the optional parameter to get "
"standard plots."
)
objective = not_none(experiment.optimization_config).objective
if isinstance(objective, ScalarizedObjective):
logger.warning(
"get_standard_plots does not currently support ScalarizedObjective "
"optimization experiments. Returning an empty list."
)
return []
if data is None:
data = experiment.fetch_data()
if data.df.empty:
logger.info(f"Experiment {experiment} does not yet have data, nothing to plot.")
return []
output_plot_list = []
output_plot_list.extend(
_get_objective_trace_plot(
experiment=experiment,
data=data,
model_transitions=model_transitions
if model_transitions is not None
else [],
true_objective_metric_name=true_objective_metric_name,
)
)
    # Objective vs. parameter plot requires a `Model`, so add it only if the model
    # is already available. In cases where custom trials are attached initially,
    # the model might not yet be set on the generation strategy.
if model:
# TODO: Check if model can predict in favor of try/catch.
try:
if true_objective_metric_name is not None:
output_plot_list.append(
_objective_vs_true_objective_scatter(
model=model,
objective_metric_name=objective.metric_names[0],
true_objective_metric_name=true_objective_metric_name,
)
)
output_plot_list.extend(
_get_objective_v_param_plots(
experiment=experiment,
model=model,
)
)
output_plot_list.extend(_get_cross_validation_plots(model=model))
feature_importance_plot = plot_feature_importance_by_feature_plotly(
model=model, relative=False, caption=FEATURE_IMPORTANCE_CAPTION
)
feature_importance_plot.layout.title = "[ADVANCED] " + str(
# pyre-fixme[16]: go.Figure has no attribute `layout`
feature_importance_plot.layout.title.text
)
output_plot_list.append(feature_importance_plot)
output_plot_list.append(interact_fitted_plotly(model=model, rel=False))
except NotImplementedError:
# Model does not implement `predict` method.
pass
return [plot for plot in output_plot_list if plot is not None]
def _merge_trials_dict_with_df(
df: pd.DataFrame, trials_dict: Dict[int, Any], column_name: str
) -> None:
"""Add a column ``column_name`` to a DataFrame ``df`` containing a column
``trial_index``. Each value of the new column is given by the element of
``trials_dict`` indexed by ``trial_index``.
Args:
df: Pandas DataFrame with column ``trial_index``, to be appended with a new
column.
        trials_dict: Dict mapping each ``trial_index`` to a value. The new column
            of ``df`` will be populated with the value corresponding to the
            ``trial_index`` of each row.
column_name: Name of the column to be appended to ``df``.
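
    Example:
        A minimal doctest-style sketch:

        >>> df = pd.DataFrame({"trial_index": [0, 0, 1]})
        >>> _merge_trials_dict_with_df(
        ...     df,
        ...     trials_dict={0: "RUNNING", 1: "COMPLETED"},
        ...     column_name="trial_status",
        ... )
        >>> list(df["trial_status"])
        ['RUNNING', 'RUNNING', 'COMPLETED']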
"""
if "trial_index" not in df.columns:
raise ValueError("df must have trial_index column")
if any(trials_dict.values()): # field present for any trial
if not all(trials_dict.values()): # not present for all trials
logger.warning(
f"Column {column_name} missing for some trials. "
"Filling with None when missing."
)
df[column_name] = [trials_dict[trial_index] for trial_index in df.trial_index]
else:
        logger.warning(
            f"Column {column_name} missing for all trials. Not appending column."
        )
def _get_generation_method_str(trial: BaseTrial) -> str:
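    """Return a comma-separated string of the model keys that generated
    ``trial``'s generator runs ("Manual" for manual runs; "Unknown" if no
    methods can be identified)."""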
generation_methods = {
not_none(generator_run._model_key)
for generator_run in trial.generator_runs
if generator_run._model_key is not None
}
# add "Manual" if any generator_runs are manual
if any(
generator_run.generator_run_type == GeneratorRunType.MANUAL.name
for generator_run in trial.generator_runs
):
generation_methods.add("Manual")
return ", ".join(generation_methods) if generation_methods else "Unknown"
def _merge_results_if_no_duplicates(
arms_df: pd.DataFrame,
data: Data,
key_components: List[str],
metrics: List[Metric],
) -> pd.DataFrame:
"""Formats ``data.df`` and merges it with ``arms_df`` if all of the following are
True:
- ``data.df`` is not empty
- ``data.df`` contains columns corresponding to ``key_components``
    - after any formatting, ``data.df`` contains no duplicates of the synthetic key
      column formed by joining ``key_components`` (``results_key_col``)
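
    Example:
        With ``key_components=["trial_index", "arm_name"]``, the synthetic key
        column is named ``"trial_index-arm_name"`` and each key is the simple
        concatenation of the stringified components, e.g. ``"00_0"`` for trial
        ``0`` and arm ``"0_0"``.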
"""
results = data.df
if len(results.index) == 0:
logger.info(
f"No results present for the specified metrics `{metrics}`. "
"Returning arm parameters and metadata only."
)
return arms_df
if not all(col in results.columns for col in key_components):
logger.warning(
f"At least one of key columns `{key_components}` not present in results df "
f"`{results}`. Returning arm parameters and metadata only."
)
return arms_df
# prepare results for merge
key_vals = results[key_components[0]].astype("str")
for key_component in key_components[1:]:
key_vals += results[key_component].astype("str")
results_key_col = "-".join(key_components)
results[results_key_col] = key_vals
# Don't return results if duplicates remain
if any(results.duplicated(subset=[results_key_col, "metric_name"])):
logger.warning(
"Experimental results dataframe contains multiple rows with the same "
f"keys {results_key_col}. Returning dataframe without results."
)
return arms_df
metric_vals = results.pivot(
index=results_key_col, columns="metric_name", values="mean"
).reset_index()
# dedupe results by key_components
metadata = results[key_components + [results_key_col]].drop_duplicates()
metrics_df = pd.merge(metric_vals, metadata, on=results_key_col)
# drop synthetic key column
metrics_df = metrics_df.drop(results_key_col, axis=1)
# merge and return
return pd.merge(metrics_df, arms_df, on=key_components, how="outer")
def exp_to_df(
exp: Experiment,
metrics: Optional[List[Metric]] = None,
run_metadata_fields: Optional[List[str]] = None,
trial_properties_fields: Optional[List[str]] = None,
**kwargs: Any,
) -> pd.DataFrame:
"""Transforms an experiment to a DataFrame with rows keyed by trial_index
and arm_name, metrics pivoted into one row. If the pivot results in more than
one row per arm (or one row per ``arm * map_keys`` combination if ``map_keys`` are
present), results are omitted and warning is produced. Only supports
``Experiment``.
Transforms an ``Experiment`` into a ``pd.DataFrame``.
Args:
exp: An ``Experiment`` that may have pending trials.
metrics: Override list of metrics to return. Return all metrics if ``None``.
run_metadata_fields: fields to extract from ``trial.run_metadata`` for trial
in ``experiment.trials``. If there are multiple arms per trial, these
fields will be replicated across the arms of a trial.
trial_properties_fields: fields to extract from ``trial._properties`` for trial
in ``experiment.trials``. If there are multiple arms per trial, these
fields will be replicated across the arms of a trial. Output columns names
will be prepended with ``"trial_properties_"``.
Returns:
DataFrame: A dataframe of inputs, metadata and metrics by trial and arm (and
``map_keys``, if present). If no trials are available, returns an empty
dataframe. If no metric ouputs are available, returns a dataframe of inputs and
metadata.
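
    Example:
        A hypothetical usage sketch; assumes an ``Experiment`` ``exp`` with
        attached data and a ``"job_id"`` key in each trial's ``run_metadata``
        (both illustrative):

        >>> df = exp_to_df(exp, run_metadata_fields=["job_id"])  # doctest: +SKIP
        >>> df[["trial_index", "arm_name", "trial_status"]]  # doctest: +SKIP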
"""
    if len(kwargs) > 0:
        logger.warning(
            "`kwargs` in exp_to_df is deprecated. Please remove extra arguments."
        )
    # `MultiTypeExperiment`s are not supported.
    if isinstance(exp, MultiTypeExperiment):
        raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.")
key_components = ["trial_index", "arm_name"]
    # Get each trial-arm with parameters. Build from a list of records rather
    # than `DataFrame.append`, which was removed in pandas 2.0.
    arms_df = pd.DataFrame(
        [
            {"arm_name": arm.name, "trial_index": trial_index, **arm.parameters}
            for trial_index, trial in exp.trials.items()
            for arm in trial.arms
        ]
    )
# Fetch results; in case arms_df is empty, return empty results (legacy behavior)
data = exp.lookup_data()
results = data.df
if len(arms_df.index) == 0:
if len(results.index) != 0:
raise ValueError(
"exp.lookup_data().df returned more rows than there are experimental "
"arms. This is an inconsistent experimental state. Please report to "
"Ax support."
)
return results
    # Ensure `trial_index` is integer-valued for merging and sorting
    arms_df["trial_index"] = arms_df["trial_index"].astype(int)
# Add trial status
trials = exp.trials.items()
trial_to_status = {index: trial.status.name for index, trial in trials}
_merge_trials_dict_with_df(
df=arms_df, trials_dict=trial_to_status, column_name="trial_status"
)
# Add generation_method, accounting for the generic case that generator_runs is of
# arbitrary length. Repeated methods within a trial are condensed via `set` and an
# empty set will yield "Unknown" as the method.
trial_to_generation_method = {
trial_index: _get_generation_method_str(trial) for trial_index, trial in trials
}
_merge_trials_dict_with_df(
df=arms_df,
trials_dict=trial_to_generation_method,
column_name="generation_method",
)
# Add any trial properties fields to arms_df
if trial_properties_fields is not None:
# add trial._properties fields
for field in trial_properties_fields:
trial_to_properties_field = {
trial_index: (
trial._properties[field] if field in trial._properties else None
)
for trial_index, trial in trials
}
_merge_trials_dict_with_df(
df=arms_df,
trials_dict=trial_to_properties_field,
column_name="trial_properties_" + field,
)
# Add any run_metadata fields to arms_df
if run_metadata_fields is not None:
# add run_metadata fields
for field in run_metadata_fields:
trial_to_metadata_field = {
trial_index: (
trial.run_metadata[field] if field in trial.run_metadata else None
)
for trial_index, trial in trials
}
_merge_trials_dict_with_df(
df=arms_df,
trials_dict=trial_to_metadata_field,
column_name=field,
)
exp_df = _merge_results_if_no_duplicates(
arms_df=arms_df,
data=data,
key_components=key_components,
metrics=metrics or list(exp.metrics.values()),
)
return not_none(not_none(exp_df).sort_values(["trial_index"]))
def _pareto_frontier_scatter_2d_plotly(
experiment: Experiment,
metric_names: Optional[Tuple[str, str]] = None,
reference_point: Optional[Tuple[float, float]] = None,
minimize: Optional[Union[bool, Tuple[bool, bool]]] = None,
) -> go.Figure:
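    """Plot observed metric values in a 2d metric space, highlighting the
    observed Pareto frontier. Defaults for unspecified inputs are inferred
    from the experiment's ``optimization_config``."""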
# Determine defaults for unspecified inputs using `optimization_config`
metric_names, reference_point, minimize = _pareto_frontier_plot_input_processing(
experiment=experiment,
metric_names=metric_names,
reference_point=reference_point,
minimize=minimize,
)
df = exp_to_df(experiment)
Y = df[list(metric_names)].to_numpy()
Y_pareto = (
_extract_observed_pareto_2d(
Y=Y, reference_point=reference_point, minimize=minimize
)
if minimize is not None
else None
)
return scatter_plot_with_pareto_frontier_plotly(
Y=Y,
Y_pareto=Y_pareto,
metric_x=metric_names[0],
metric_y=metric_names[1],
reference_point=reference_point,
minimize=minimize,
)
def _objective_vs_true_objective_scatter(
model: ModelBridge,
objective_metric_name: str,
true_objective_metric_name: str,
) -> go.Figure:
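    """Scatter the model's predictions for the objective metric against its
    predictions for the true objective metric."""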
plot = plot_multiple_metrics(
model=model,
metric_x=objective_metric_name,
metric_y=true_objective_metric_name,
rel_x=False,
rel_y=False,
)
fig = go.Figure(plot.data)
fig.layout.title.text = (
f"Objective {objective_metric_name} vs. True Objective "
f"Metric {true_objective_metric_name}"
)
return fig