# Source code for ax.plot.feature_importances

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from logging import Logger
from typing import Any, Dict, Optional, Union

import numpy as np
import pandas as pd
import plotly.graph_objs as go
from ax.exceptions.core import NoDataError
from ax.modelbridge import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.helper import compose_annotation
from ax.utils.common.logger import get_logger
from plotly import subplots

# Module-level logger; the plotting functions in this module use it to warn
# when a metric's model does not support feature importances.
logger: Logger = get_logger(__name__)


def plot_feature_importance_plotly(df: pd.DataFrame, title: str) -> go.Figure:
    """Plot one horizontal bar-chart subplot per value column of ``df``.

    The first column of ``df`` supplies the bar labels (y-axis). Every remaining
    column becomes its own subplot, titled with the column name, and all
    subplots share the x-axis.

    Args:
        df: Table whose first column holds the row labels and whose remaining
            columns hold the values to plot. It is not modified.
        title: Figure title.

    Returns:
        A ``go.Figure`` with one subplot per value column.

    Raises:
        NoDataError: If ``df`` is empty or has no value columns besides the
            label column.
    """
    if df.empty:
        raise NoDataError("No Data on Feature Importances found.")
    # Work on a re-indexed copy rather than ``set_index(..., inplace=True)``.
    # Mutating the caller's frame meant a second call on the same DataFrame
    # would re-index by what was originally a data column.
    df = df.set_index(df.columns[0])
    if len(df.columns) == 0:
        # Only a label column was supplied; make_subplots cannot build 0 rows.
        raise NoDataError("No Feature Importance columns found.")
    fig = subplots.make_subplots(
        rows=len(df.columns),
        cols=1,
        subplot_titles=df.columns,
        print_grid=False,
        shared_xaxes=True,
    )
    for row, column_name in enumerate(df.columns, start=1):
        fig.add_trace(
            go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h"),
            row=row,
            col=1,
        )
    fig.layout.showlegend = False
    # Reserve 8px of left margin per character of the longest label, capped at
    # 75 characters. str() keeps this working for non-string labels.
    longest_label = max(len(str(label)) for label in df.index)
    fig.layout.margin = go.layout.Margin(l=8 * min(longest_label, 75))  # noqa E741
    fig.layout.title = title
    return fig
def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig:
    """Build the per-column feature-importance figure as a generic AxPlotConfig.

    Args:
        df: Forwarded unchanged to ``plot_feature_importance_plotly``.
        title: Figure title.
    """
    figure = plot_feature_importance_plotly(df, title)
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def plot_feature_importance_by_metric_plotly(model: ModelBridge) -> go.Figure:
    """One plot per feature, showing importances by metric.

    Args:
        model: Model bridge queried with ``feature_importances(metric_name)``
            for each name in ``model.metric_names``.

    Returns:
        A ``go.Figure`` with one subplot per feature. Each subplot has one bar
        per metric whose model supports feature importances.

    Raises:
        NotImplementedError: If no metric's model supports feature importances.
    """
    importances = []
    for metric_name in sorted(model.metric_names):
        try:
            metric_importances = model.feature_importances(metric_name)
        except NotImplementedError:
            logger.warning(
                f"Model for {metric_name} does not support feature importances."
            )
            continue
        # Copy before tagging with the metric name so the dict returned by the
        # model is never mutated.
        row: Dict[str, Any] = dict(metric_importances)
        row["index"] = metric_name
        importances.append(row)
    if not importances:
        raise NotImplementedError(
            "Feature importances could not be calculated for any metric"
        )
    df = pd.DataFrame(importances)
    # plot_feature_importance expects index in first column
    df = df.reindex(columns=(["index"] + [a for a in df.columns if a != "index"]))
    # One subplot per feature, i.e. every column except "index", and one bar
    # (row) per metric. Count them here instead of relying on the plotting
    # helper having re-indexed ``df`` in place.
    num_subplots = len(df.columns) - 1
    num_rows = len(df)
    plot_fi = plot_feature_importance_plotly(
        df, "Absolute Parameter Importances by Metric"
    )
    # Include per-subplot margin for subplot titles (feature names).
    plot_fi["layout"]["height"] = num_subplots * (num_rows + 1) * 50
    return plot_fi
def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig:
    """Build the per-feature, by-metric importance figure as a generic AxPlotConfig.

    Args:
        model: Forwarded unchanged to ``plot_feature_importance_by_metric_plotly``.
    """
    figure = plot_feature_importance_by_metric_plotly(model)
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def _sensitivity_frame(
    importances: Dict[str, Union[float, np.ndarray]], relative: bool
) -> pd.DataFrame:
    """Tabulate one metric's sensitivities as a Factor/Importance(/SE) frame.

    Rows are sorted by ascending importance. ``importances`` must be non-empty.
    If the first value has more than one element, every value is read as
    ``[value, variance, se]`` and an ``SE`` column is added. With
    ``relative=True``, importances are divided by their sum, and SEs are divided
    by the absolute value of that same sum.
    """
    with_se = np.asarray(next(iter(importances.values()))).size > 1
    if with_se:
        rows = [
            {
                "Factor": factor,
                "Importance": np.asarray(importance)[0],
                "SE": np.asarray(importance)[2],
            }
            for factor, importance in importances.items()
        ]
    else:
        rows = [
            {"Factor": factor, "Importance": importance}
            for factor, importance in importances.items()
        ]
    df = pd.DataFrame(rows).sort_values("Importance")
    if relative:
        total = df["Importance"].sum()
        df["Importance"] = df["Importance"].div(total)
        if with_se:
            # Scale the standard errors by the same constant as the importances.
            # Otherwise the error bars stay in absolute units while the bars are
            # plotted as fractions of 1.
            df["SE"] = df["SE"].div(abs(total))
    return df


def plot_feature_importance_by_feature_plotly(
    model: Optional[ModelBridge] = None,
    sensitivity_values: Optional[Dict[str, Dict[str, Union[float, np.ndarray]]]] = None,
    relative: bool = False,
    caption: str = "",
    importance_measure: str = "",
) -> go.Figure:
    """One plot per metric, showing importances by feature.

    A dropdown switches between metrics; only the selected metric's bars are
    visible.

    Args:
        model: A model with a ``feature_importances`` method. It is only used,
            and then required, when ``sensitivity_values`` is not provided.
        sensitivity_values: The sensitivity values for each metric in a dict
            format. When only the sensitivity value is plotted it takes the form
            `{"metric1": {"parameter1": value1, "parameter2": value2, ...}, ...}`.
            When the sensitivity value and standard error are plotted it takes
            the form
            `{"metric1": {"parameter1": [value1, var, se],
            "parameter2": [value2, var, se], ...}, ...}`.
        relative: Whether to normalize feature importances so that they add to 1.
            Standard errors, if present, are rescaled by the same factor.
        caption: An HTML-formatted string to place at the bottom of the plot.
        importance_measure: The name of the importance metric to be added to the
            title.

    Returns a go.Figure of feature importances.

    Raises:
        ValueError: If neither ``model`` nor ``sensitivity_values`` is given, or
            if a metric has no sensitivity values.
        NotImplementedError: If ``model`` cannot compute feature importances, or
            if there are no metrics to plot.
    """
    if sensitivity_values is None:
        if model is None:
            raise ValueError(
                "A model is required when sensitivity values are not provided"
            )
        try:
            sensitivity_values = {
                metric_name: model.feature_importances(metric_name)
                for metric_name in sorted(model.metric_names)
            }
        except NotImplementedError as err:
            raise NotImplementedError(
                "Feature importances cannot be computed by the model."
            ) from err

    metric_names = sorted(sensitivity_values.keys())
    traces = []
    dropdown = []
    # Size the layout for the largest feature set across all metrics, because
    # the dropdown can switch to any of them.
    max_num_features = 0
    longest_label = 0
    for i, metric_name in enumerate(metric_names):
        importances = sensitivity_values[metric_name]
        if not importances:
            raise ValueError(f"No sensitivity values found for metric {metric_name}.")
        df = _sensitivity_frame(importances, relative=relative)
        max_num_features = max(max_num_features, len(df))
        longest_label = max(
            longest_label, max(len(str(factor)) for factor in df["Factor"])
        )
        error_x = (
            {"type": "data", "array": df["SE"], "visible": True}
            if "SE" in df.columns
            else None
        )
        traces.append(
            go.Bar(
                name="Importance",
                orientation="h",
                visible=i == 0,
                x=df["Importance"],
                y=df["Factor"],
                error_x=error_x,
                opacity=0.8,
            )
        )
        # Selecting this entry shows only this metric's trace.
        is_visible = [j == i for j in range(len(metric_names))]
        dropdown.append(
            {"args": ["visible", is_visible], "label": metric_name, "method": "restyle"}
        )
    if not traces:
        raise NotImplementedError("No traces found for metric")

    updatemenus = [
        {
            "x": 0,
            "y": 1,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": dropdown,
            # hack to put dropdown below title regardless of number of features
            "pad": {"t": -40},
        }
    ]
    title = (
        "Relative Parameter Importances"
        if relative
        else "Absolute Parameter Importances"
    )
    if importance_measure:
        title = title + " based on " + importance_measure
    layout = go.Layout(
        height=200 + max_num_features * 20,
        hovermode="closest",
        # 8px per character of the longest feature name, capped at 75 chars.
        margin=go.layout.Margin(l=8 * min(longest_label, 75)),  # noqa E741
        showlegend=False,
        title=title,
        updatemenus=updatemenus,
        annotations=compose_annotation(caption=caption),
    )
    if relative:
        layout.update({"xaxis": {"tickformat": ".0%"}})
    return go.Figure(data=traces, layout=layout)
def plot_feature_importance_by_feature(
    model: Optional[ModelBridge] = None,
    sensitivity_values: Optional[Dict[str, Dict[str, Union[float, np.ndarray]]]] = None,
    relative: bool = False,
    caption: str = "",
    importance_measure: str = "",
) -> AxPlotConfig:
    """Build the per-metric, by-feature importance figure as a generic AxPlotConfig.

    Every argument is passed straight through to
    ``plot_feature_importance_by_feature_plotly``; see that function for details.
    """
    figure = plot_feature_importance_by_feature_plotly(
        model=model,
        sensitivity_values=sensitivity_values,
        relative=relative,
        caption=caption,
        importance_measure=importance_measure,
    )
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def plot_relative_feature_importance_plotly(model: ModelBridge) -> go.Figure:
    """Create a grouped bar chart of relative feature importances per metric.

    Each metric's importances are divided by their sum so that they add to 1.
    Each feature becomes one bar trace, and bars are grouped per metric
    (``barmode="group"``).

    Args:
        model: Model bridge queried with ``feature_importances(metric_name)``
            for each name in ``model.metric_names``.

    Returns:
        A ``go.Figure`` with one horizontal bar trace per feature.

    Raises:
        NotImplementedError: If no metric's model supports feature importances.
    """
    importances = []
    for metric_name in sorted(model.metric_names):
        try:
            # Copy so the dict returned by the model is not mutated below.
            row: Dict[str, Any] = dict(model.feature_importances(metric_name))
        except Exception:
            # Deliberate best effort: skip any metric whose importances fail.
            logger.warning(
                "Model for {} does not support feature importances.".format(metric_name)
            )
            continue
        row["index"] = metric_name
        importances.append(row)
    if not importances:
        # Without this guard, set_index("index") fails with an opaque KeyError.
        raise NotImplementedError(
            "Feature importances could not be calculated for any metric"
        )
    df = pd.DataFrame(importances).set_index("index")
    # Normalize each metric's row so that its importances sum to 1.
    df = df.div(df.sum(axis=1), axis=0)
    data = [
        go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h")
        for column_name in df.columns
    ]
    layout = go.Layout(
        margin=go.layout.Margin(l=250),  # noqa E741
        barmode="group",
        yaxis={"title": ""},
        xaxis={"title": "Relative Parameter importance"},
        showlegend=False,
        title="Relative Parameter Importance per Metric",
    )
    return go.Figure(data=data, layout=layout)
def plot_relative_feature_importance(model: ModelBridge) -> AxPlotConfig:
    """Build the relative-importance-per-metric figure as a generic AxPlotConfig.

    Args:
        model: Forwarded unchanged to ``plot_relative_feature_importance_plotly``.
    """
    figure = plot_relative_feature_importance_plotly(model)
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)