SHAP Feature Importance

Calculates feature importance using a pre-built SHAP explainer.

Processing

This brick analyzes your machine learning model to reveal which data features (columns) have the strongest influence on the model's predictions. It calculates the "Global Feature Importance" by averaging the absolute SHAP values across your dataset.

Essentially, it answers the question: "Which variables matter the most to this model?"

The brick produces two results: a ranked table of features and a visual bar chart, allowing you to quickly identify the key drivers behind your model's decisions (e.g., determining that "House Size" is the most important factor in a Price Prediction model).
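
The averaging described above can be sketched with NumPy. This is a minimal illustration under stated assumptions: the `shap_values` matrix and the feature names are made-up numbers (one row per sample, one column per feature), chosen only to show the arithmetic.

```python
import numpy as np

# Hypothetical SHAP values: 4 samples (rows) x 3 features (columns).
feature_names = ["House Size", "Bedrooms", "Year Built"]
shap_values = np.array(
    [
        [2.0, -0.5, 0.1],
        [-3.0, 0.5, -0.1],
        [1.0, -1.5, 0.3],
        [-2.0, 0.5, -0.3],
    ]
)

# Global importance = mean of the absolute SHAP values, per feature (column).
importance = np.abs(shap_values).mean(axis=0)

# Rank features from most to least important.
order = np.argsort(importance)[::-1]
ranking = [(feature_names[i], float(importance[i])) for i in order]

for name, score in ranking:
    print(f"{name}: {score:.2f}")
```

Taking the absolute value first matters: the signed SHAP values for "House Size" partly cancel out, while their magnitudes show that the feature moves predictions strongly in both directions.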

Inputs

explainer
The pre-configured SHAP Explainer object. This is a special object created by a previous brick (like "Create SHAP Explainer") that contains the logic for interpreting your specific model.
data
The dataset you want to analyze. This should be the data formatted exactly as the model expects it (e.g., numeric, normalized values).
original data (optional)
The original, human-readable version of the dataset.
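  • Why use this? If your data input has been transformed (e.g., scaled between 0 and 1), the column names might be lost or hard to read. Providing the original data allows the brick to restore the original column names (like "Customer Age" instead of "col_01") for the final report and chart.
  • Requirement: the original data must have the same columns, in the same order, as data. The brick copies its column names onto the SHAP values, so a transformation that changes the number of columns (e.g., one-hot encoding) will not line up.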
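
Because the brick takes the column names from original data and applies them to SHAP values computed on data, the two tables need to match column for column. A minimal pre-check, assuming pandas DataFrames; `check_alignment` is a hypothetical helper written for this illustration and is not part of the brick:

```python
import pandas as pd


def check_alignment(data: pd.DataFrame, original_data: pd.DataFrame) -> list:
    """Return the readable column names if both tables have the same shape.

    Hypothetical helper for illustration; not part of the brick.
    """
    if data.shape != original_data.shape:
        raise ValueError(
            f"Shape mismatch: data is {data.shape}, "
            f"original data is {original_data.shape}."
        )
    return list(original_data.columns)


# Model-ready data: scaled values with uninformative column names.
data = pd.DataFrame({"col_01": [0.10, 0.55, 1.00], "col_02": [0.0, 0.5, 1.0]})

# Human-readable version of the same rows and columns.
original_data = pd.DataFrame(
    {"Customer Age": [22, 40, 58], "Income": [20000, 50000, 80000]}
)

readable_names = check_alignment(data, original_data)
print(readable_names)
```

Running a check like this before the brick turns a confusing downstream failure into an explicit message about which table is the wrong size.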

Inputs Types

Input           Types
explainer       Any
data            DataFrame
original data   DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
A visual bar chart displaying the top features ranked by importance. By default, this is returned as a PIL image, ready to be displayed or saved.
features importance
A table (DataFrame) with one row per feature, containing every feature and its calculated importance score, sorted from most important to least important.

The features importance output contains the following specific data fields:

  • feature: The name of the data column (e.g., "Age", "Income").
  • importance: The calculated importance score (mean absolute SHAP value). Higher numbers mean the feature has a bigger impact on the model's output.
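
As a rough sketch of how this table can be consumed downstream, the snippet below builds a DataFrame with the same two fields and reads from it. The feature names, scores, and the 0.2 threshold are invented for illustration.

```python
import pandas as pd

# Hypothetical importance scores, in the model's column order.
feature_names = ["Age", "Income", "Tenure"]
importance_values = [0.12, 0.48, 0.30]

# Same layout as the brick's output: two columns, sorted by importance.
features_importance = (
    pd.DataFrame({"feature": feature_names, "importance": importance_values})
    .sort_values(by="importance", ascending=False)
    .reset_index(drop=True)
)

# The first row is the most influential feature.
top_feature = features_importance.loc[0, "feature"]

# Keep only features above an arbitrary, illustrative threshold.
strong = features_importance[features_importance["importance"] >= 0.2]

print(features_importance)
```

Because the index is reset after sorting, row 0 is always the top-ranked feature.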

Outputs Types

Output                Types
image                 MediaData, PILImage
features importance   DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Feature Importance brick exposes the following options:

Class Index (Multiclass)
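Used only when analyzing a Multiclass Classification model (a model that predicts one of 3+ categories), where the explainer returns one set of SHAP values per class.
  • This integer determines which "class" (category) you want to explain.
  • For example, if your model predicts [Red, Green, Blue], 0 explains "Red", 1 explains "Green", and 2 explains "Blue".
  • If the index is outside the valid range, the brick falls back to 0 (and logs a warning when Verbose is enabled).
  • If the SHAP values contain a single output per feature, as the brick expects for Regression (predicting a number) or Binary Classification models, this setting is ignored.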
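
The class selection can be pictured as slicing a three-dimensional array of shape (samples, features, classes). The sketch below uses plain NumPy rather than a real SHAP explanation; `select_class` and the array contents are hypothetical, and the fallback to class 0 mirrors the out-of-bounds check in the brick's code.

```python
import numpy as np

class_names = ["Red", "Green", "Blue"]

# Hypothetical multiclass SHAP values: 2 samples x 2 features x 3 classes.
shap_values = np.arange(12, dtype=float).reshape(2, 2, 3)


def select_class(values: np.ndarray, class_index: int) -> np.ndarray:
    """Return the (samples x features) slice for one class.

    Falls back to class 0 when the index is out of range.
    """
    num_classes = values.shape[2]
    if class_index < 0 or class_index >= num_classes:
        class_index = 0
    return values[:, :, class_index]


green = select_class(shap_values, 1)     # explains "Green"
fallback = select_class(shap_values, 5)  # out of range, so class 0 ("Red")

print(green)
print(fallback)
```

After the slice, the values are two-dimensional again, so the importance calculation is the same as in the single-output case.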
Max Display
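Controls the maximum number of bars shown in the generated chart (the code falls back to 10 when the option is not set). Increasing this lets you see more features but may make the chart cluttered. It affects only the chart; the features importance table always lists every feature.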
Output Image Format
Determines the technical format of the output image.
  • pil: Returns a Python Imaging Library object (standard for image manipulation).
  • bytes: Returns the raw image file data (useful for saving directly to a file or database).
  • array: Returns the image as a NumPy array (useful for pixel-level processing).
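
The three formats are different views of the same rendered PNG. A minimal sketch, assuming Pillow and NumPy are installed; the small solid-color image stands in for the real chart, which the brick renders with matplotlib.

```python
import io

import numpy as np
from PIL import Image

# Stand-in for the rendered chart: a 4x2 solid red image saved as PNG.
buf = io.BytesIO()
Image.new("RGB", (4, 2), color=(255, 0, 0)).save(buf, format="PNG")

# "bytes": the raw PNG file content, ready to write to disk or a database.
as_bytes = buf.getvalue()

# "pil": a PIL image object opened from those bytes.
as_pil = Image.open(io.BytesIO(as_bytes))

# "array": a NumPy array laid out as height x width x channels.
as_array = np.array(as_pil)

print(type(as_bytes).__name__, as_pil.size, as_array.shape)
```

Note that PIL reports size as (width, height) while the array shape is (height, width, channels).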
Verbose
Toggles detailed logging. When enabled, the brick prints progress updates and warnings to the console.

import logging
import io
import inspect
import pandas as pd
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot (or shap) is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import shap
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, Tuple, MediaData, PILImage, Union
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="SHAP Feature Importance", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    """Call the explainer, passing silent=True only if its __call__ accepts it."""
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert pandas or polars DataFrame to numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)


def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_features_importance(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Tuple[Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 10)
    image_format = options.get("image_format", "pil")
    image = None
    features_importance = pd.DataFrame()
    try:
        verbose and logger.info("Starting SHAP feature importance calculation.")
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a Shap Explainer as an input.")
            raise ValueError("Expects a Shap Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data for visualization."
            )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        ndim = len(shap_explanation.shape)
        explanation_to_plot = None
        if ndim == 2:
            verbose and logger.info(
                "Detected Regression or Binary Classification (single output)."
            )
            explanation_to_plot = shap_explanation
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Selected class_index {class_index} is out of bounds (0-{num_classes - 1}). Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_plot = shap_explanation[:, :, class_index]
        else:
            raise ValueError(
                f"Unexpected SHAP explanation dimensionality: {ndim}. Expected 2 or 3."
            )
        feature_names = explanation_to_plot.feature_names
        if not feature_names:
            feature_names = [
                f"Feature {i}" for i in range(explanation_to_plot.shape[1])
            ]
        importance_values = np.abs(explanation_to_plot.values).mean(axis=0)
        features_importance = (
            pd.DataFrame({"feature": feature_names, "importance": importance_values})
            .sort_values(by="importance", ascending=False)
            .reset_index(drop=True)
        )
        verbose and logger.info(
            f"Calculated importance for {len(features_importance)} features."
        )
        verbose and logger.info(f"Generating Bar Plot with max_display={max_display}.")
        fig = plt.figure()
        shap.plots.bar(explanation_to_plot, max_display=max_display, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
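        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            img = Image.open(buf)
            image = np.array(img)
        elif image_format == "bytes":
            image = buf.getvalue()
        else:
            # Fail loudly instead of silently returning image=None.
            raise ValueError(
                f"Unsupported image_format: {image_format!r}. Expected 'pil', 'bytes', or 'array'."
            )
        verbose and logger.info("Image generation complete.")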
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP importance: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
    return (image, features_importance)
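
The `_call_explainer` helper in the listing above passes `silent=True` only when the explainer's `__call__` accepts that parameter. The same pattern is shown in isolation below; `QuietExplainer`, `PlainExplainer`, and `call_quietly` are hypothetical stand-ins written for this sketch, not SHAP classes.

```python
import inspect


class QuietExplainer:
    """Hypothetical stand-in whose __call__ accepts `silent`."""

    def __call__(self, data, silent=False):
        return {"rows": len(data), "silent": silent}


class PlainExplainer:
    """Hypothetical stand-in whose __call__ has no `silent` parameter."""

    def __call__(self, data):
        return {"rows": len(data)}


def call_quietly(explainer, data, **kwargs):
    # Inspect the bound __call__ and pass silent=True only if it is accepted.
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **kwargs)
    return explainer(data, **kwargs)


quiet_result = call_quietly(QuietExplainer(), [1, 2, 3])
plain_result = call_quietly(PlainExplainer(), [1, 2, 3])
print(quiet_result, plain_result)
```

Checking the signature up front avoids a `TypeError` from explainer types that do not define the keyword, without needing a try/except around the actual computation.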

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • pillow
  • numba>=0.56.0