Source code for ax.plot.feature_importances

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
import plotly.graph_objs as go
from ax.core.parameter import ChoiceParameter
from ax.exceptions.core import NoDataError
from ax.modelbridge import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.helper import compose_annotation
from ax.utils.common.logger import get_logger
from plotly import subplots

logger: Logger = get_logger(__name__)


[docs] def plot_feature_importance_plotly(df: pd.DataFrame, title: str) -> go.Figure: if df.empty: raise NoDataError("No Data on Feature Importances found.") df.set_index(df.columns[0], inplace=True) data = [ go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h") for column_name in df.columns ] fig = subplots.make_subplots( rows=len(df.columns), cols=1, subplot_titles=df.columns, print_grid=False, shared_xaxes=True, ) for idx, item in enumerate(data): fig.append_trace(item, idx + 1, 1) fig.layout.showlegend = False fig.layout.margin = go.layout.Margin( l=8 * min(max(len(idx) for idx in df.index), 75) # noqa E741 ) fig.layout.title = title return fig
[docs] def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig: """Wrapper method to convert plot_feature_importance_plotly to AxPlotConfig""" return AxPlotConfig( # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got # `Figure`. data=plot_feature_importance_plotly(df, title), plot_type=AxPlotTypes.GENERIC, )
[docs] def plot_feature_importance_by_metric_plotly(model: ModelBridge) -> go.Figure: """One plot per feature, showing importances by metric.""" importances = [] for metric_name in sorted(model.metric_names): try: vals: dict[str, Any] = model.feature_importances(metric_name) vals["index"] = metric_name importances.append(vals) except NotImplementedError: logger.warning( f"Model for {metric_name} does not support feature importances." ) if not importances: raise NotImplementedError( "Feature importances could not be calculated for any metric" ) df = pd.DataFrame(importances) # plot_feature_importance expects index in first column df = df.reindex(columns=(["index"] + [a for a in df.columns if a != "index"])) plot_fi = plot_feature_importance_plotly(df, "Parameter Sensitivity by Metric") num_subplots = len(df.columns) num_features = len(df) # Include per-subplot margin for subplot titles (feature names). plot_fi["layout"]["height"] = num_subplots * (num_features + 1) * 50 return plot_fi
[docs] def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig: """Wrapper method to convert plot_feature_importance_by_metric_plotly to AxPlotConfig""" return AxPlotConfig( # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got # `Figure`. data=plot_feature_importance_by_metric_plotly(model), plot_type=AxPlotTypes.GENERIC, )
[docs] def plot_feature_importance_by_feature_plotly( model: ModelBridge | None = None, sensitivity_values: dict[str, dict[str, float | npt.NDArray]] | None = None, relative: bool = False, caption: str = "", importance_measure: str = "", label_dict: dict[str, str] | None = None, ) -> go.Figure: """One plot per metric, showing importances by feature. If sensitivity values are not all positive, the absolute value will be shown and color will indicate positive or negative sign. Args: model: A model with a ``feature_importances`` method. sensitivity_values: The sensitivity values for each metric in a dict format. It takes the following format if only the sensitivity value is plotted: `{"metric1":{"parameter1":value1,"parameter2":value2 ...} ...}` It takes the following format if the sensitivity value and standard error are plotted: `{"metric1":{"parameter1":[value1,var,se], "parameter2":[[value2,var,se]]...}...}}`. relative: Whether to normalize feature importances so that they add to 1. caption: An HTML-formatted string to place at the bottom of the plot. importance_measure: The name of the importance metric to be added to the title. label_dict: A dictionary mapping metric names to short labels. Returns a go.Figure of feature importances. """ if sensitivity_values is None: if model is None: raise ValueError( "A model is required when sensitivity values are not provided" ) try: sensitivity_values = { metric_name: model.feature_importances(metric_name) for i, metric_name in enumerate(sorted(model.metric_names)) } except NotImplementedError: raise NotImplementedError( "Feature importances cannot be computed by the model." ) if label_dict is not None: sensitivity_values = { # pyre-ignore label_dict.get(metric_name, metric_name): v for metric_name, v in sensitivity_values.items() } traces = [] categorical_features = [] if model is not None: categorical_features = [ name for name, par in model.model_space.parameters.items() if isinstance(par, ChoiceParameter) and not par.is_ordered ] for i, metric_name in enumerate(sorted(sensitivity_values.keys())): importances = sensitivity_values[metric_name] factor_col = "Factor" importance_col = "Importance" sign_col = "Sign" error_plot = np.asarray(next(iter(importances.values()))).size > 1 if error_plot: importance_col_se = "SE" df = pd.DataFrame( [ { factor_col: factor, importance_col: np.asarray(importance)[0], importance_col_se: np.asarray(importance)[2], sign_col: ( 0 if factor in categorical_features else 2 * (np.asarray(importance)[0] >= 0).astype(int) - 1 ), } for factor, importance in importances.items() ] ) df[importance_col] = df[importance_col].abs() df = df.sort_values(importance_col) error_x = {"type": "data", "array": df[importance_col_se], "visible": True} else: df = pd.DataFrame( [ { factor_col: factor, importance_col: importance, sign_col: ( 0 if factor in categorical_features # pyre-fixme[16]: Item `bool` of # `Union[ndarray[typing.Any, np.dtype[typing.Any]], bool]` # has no attribute `astype`. else 2 * (importance >= 0).astype(int) - 1 ), } for factor, importance in importances.items() ] ) df[importance_col] = df[importance_col].abs() df = df.sort_values(importance_col) error_x = None if relative: df[importance_col] = df[importance_col].div(df[importance_col].sum()) colors = {-1: "darkorange", 0: "gray", 1: "steelblue"} names = { -1: "Decreases metric", 0: "Affects metric (categorical choice)", 1: "Increases metric", } legend_counter = {-1: 0, 0: 0, 1: 0} all_positive = all(df[sign_col] >= 0) for _, row in df.iterrows(): traces.append( go.Bar( name=names[row[sign_col]], orientation="h", visible=i == 0, x=np.array([row[importance_col]]), y=np.array([row[factor_col]]), error_x=error_x, opacity=0.8, marker_color=colors[row[sign_col]], showlegend=(not all_positive) and (legend_counter[row[sign_col]] == 0), legendgroup=str(row[sign_col]), ) ) legend_counter[row[sign_col]] += 1 is_visible = [False] * (len(sensitivity_values) * len(df)) for j in range(i * len(df), (i + 1) * len(df)): is_visible[j] = True if not traces: raise NotImplementedError("No traces found for metric") features = list(list(sensitivity_values.values())[0].keys()) longest_label = max(len(f) for f in features) longest_metric = max(len(m) for m in sensitivity_values.keys()) layout = go.Layout( height=len(features) * 20, width=10 * longest_label + max(10 * longest_metric, 400), hovermode="closest", annotations=compose_annotation(caption=caption), ) if relative: layout.update({"xaxis": {"tickformat": ".0%"}}) return go.Figure(data=traces, layout=layout)
[docs] def plot_feature_importance_by_feature( model: ModelBridge | None = None, sensitivity_values: dict[str, dict[str, float | npt.NDArray]] | None = None, relative: bool = False, caption: str = "", importance_measure: str = "", label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Wrapper method to convert `plot_feature_importance_by_feature_plotly` to AxPlotConfig""" return AxPlotConfig( # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got # `Figure`. data=plot_feature_importance_by_feature_plotly( model=model, sensitivity_values=sensitivity_values, relative=relative, caption=caption, importance_measure=importance_measure, label_dict=label_dict, ), plot_type=AxPlotTypes.GENERIC, )
[docs] def plot_relative_feature_importance_plotly(model: ModelBridge) -> go.Figure: """Create a stacked bar chart of feature importances per metric""" importances = [] for metric_name in sorted(model.metric_names): try: vals: dict[str, Any] = model.feature_importances(metric_name) vals["index"] = metric_name importances.append(vals) except Exception: logger.warning( f"Model for {metric_name} does not support feature importances." ) df = pd.DataFrame(importances) df.set_index("index", inplace=True) df = df.div(df.sum(axis=1), axis=0) data = [ go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h") for column_name in df.columns ] layout = go.Layout( margin=go.layout.Margin(l=250), # noqa E741 barmode="group", yaxis={"title": ""}, xaxis={"title": "Relative Parameter importance"}, showlegend=False, title="Relative Parameter Importance per Metric", ) return go.Figure(data=data, layout=layout)
[docs] def plot_relative_feature_importance(model: ModelBridge) -> AxPlotConfig: """Wrapper method to convert plot_relative_feature_importance_plotly to AxPlotConfig""" return AxPlotConfig( # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got # `Figure`. data=plot_relative_feature_importance_plotly(model), plot_type=AxPlotTypes.GENERIC, )