# Source code for ax.plot.marginal_effects

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, List

import pandas as pd
import plotly.graph_objs as go
from ax.modelbridge.base import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes, DECIMALS
from ax.plot.helper import get_plot_data
from ax.utils.stats.statstools import marginal_effects
from plotly import subplots


def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig:
    """
    Calculates and plots the marginal effects -- the effect of changing one
    factor away from the randomized distribution of the experiment and fixing
    it at a particular level.

    Args:
        model: Model to use for estimating effects
        metric: The metric for which to plot marginal effects.

    Returns:
        AxPlotConfig of the marginal effects, with one bar-chart subplot per
        factor (sharing the y axis).

    Raises:
        ValueError: If the model's plot data contains no in-sample arms, so
            there is nothing to compute marginal effects from.
    """
    plot_data, _, _ = get_plot_data(model, {}, {metric})

    # One single-row frame per in-sample arm: the arm's parameter values plus
    # its `y_hat` / `se_hat` for `metric`, stored as "mean" / "sem" columns.
    arm_dfs = []
    for arm in plot_data.in_sample.values():
        arm_df = pd.DataFrame(arm.parameters, index=[arm.name])
        arm_df["mean"] = arm.y_hat[metric]
        arm_df["sem"] = arm.se_hat[metric]
        arm_dfs.append(arm_df)
    if not arm_dfs:
        # Without this check pd.concat([]) fails with an opaque
        # "No objects to concatenate" error.
        raise ValueError(
            f"Cannot plot marginal effects for metric `{metric}`: "
            "the model has no in-sample arms."
        )
    effect_table = marginal_effects(pd.concat(arm_dfs, axis=0))

    # Factor names in order of first appearance; each gets its own column.
    varnames = effect_table["Name"].unique()
    fig = subplots.make_subplots(
        cols=len(varnames),
        rows=1,
        subplot_titles=list(varnames),
        print_grid=False,
        shared_yaxes=True,
    )
    for col, varname in enumerate(varnames, start=1):
        var_df = effect_table[effect_table["Name"] == varname]
        # `add_trace(..., row=, col=)` replaces the deprecated `append_trace`.
        fig.add_trace(
            go.Bar(
                x=var_df["Level"],
                y=var_df["Beta"],
                error_y={"type": "data", "array": var_df["SE"]},
                name=varname,
            ),
            row=1,
            col=col,
        )

    # pyre-ignore[16]
    fig.layout.showlegend = False
    fig.layout.title = "Marginal Effects by Factor"
    fig.layout.yaxis = {
        "title": "% higher than experiment average",
        "hoverformat": ".{}f".format(DECIMALS),
    }
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)