SHAP Heatmap

Generates a SHAP heatmap visualizing feature impacts across instances. Supports ordering by similarity or sum of SHAP values.


Processing

This brick generates a heatmap visualization to help you understand how your machine learning model makes decisions across a whole group of data points (instances).

The heatmap shows the SHAP value of every feature for every instance in your dataset at once. Instances are laid out along the horizontal axis and features are stacked on the vertical axis, so each cell is the contribution of one feature to one prediction, with its color encoding the sign and size of that contribution. By default, instances with similar explanations are placed next to each other. This lets you identify clusters of data with similar model behavior and see how feature importance shifts across different segments of your population.
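To make the axis layout concrete, here is a minimal NumPy sketch of how an (instances x features) SHAP array maps onto the chart: rows become features and columns become instances, in a chosen instance order. The helper `heatmap_matrix` is hypothetical, and the symmetric color limits are an assumption of this sketch, not a description of SHAP's exact rendering.

```python
import numpy as np


def heatmap_matrix(values, instance_order=None):
    """Arrange an (n_instances, n_features) SHAP array as the chart reads it:
    one row per feature, one column per instance."""
    values = np.asarray(values, dtype=float)
    if instance_order is None:
        instance_order = np.arange(values.shape[0])  # keep input order
    matrix = values[instance_order].T  # reorder instances, then put features on rows
    # Symmetric color limits (an assumption of this sketch) keep 0 at the
    # center of a diverging colormap, so color reflects sign and size.
    limit = float(np.abs(matrix).max())
    return matrix, (-limit, limit)


# Three instances, two features.
shap_values = np.array([[0.5, -0.2], [-1.0, 0.3], [0.1, 0.0]])
matrix, (vmin, vmax) = heatmap_matrix(shap_values, instance_order=[2, 0, 1])
print(matrix.shape)  # (2, 3): 2 feature rows, 3 instance columns
print(vmin, vmax)    # -1.0 1.0
```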

Inputs

explainer
The SHAP Explainer object created by a previous brick (e.g., "Calculate SHAP Values"). It wraps your model and is called on data to compute SHAP values. It must be a shap.Explainer instance; any other object raises an error.
data
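The dataset to explain, either the full dataset or a subset of it. The explainer is applied to it to compute one SHAP value per feature for each row, so it must contain the features the model expects, in the same format. Each row becomes one instance on the heatmap.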
original data
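(Optional) If data was pre-processed (e.g., normalized between 0 and 1 or encoded), you can provide the original, human-readable version of the same dataset here. SHAP values are still computed from data; the brick then replaces the explanation's feature names and feature values with the column names and values of original data, so the chart uses readable labels. original data must have the same rows and columns, in the same order, as data; the brick does not check this.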
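Because the brick copies original data onto the explanation purely by position, a mismatched frame can silently mislabel features. A minimal pre-flight check, assuming pandas inputs (the brick also accepts polars), might look like this; `check_original_data` is a hypothetical helper, not part of the brick.

```python
import pandas as pd


def check_original_data(data: pd.DataFrame, original_data: pd.DataFrame) -> None:
    """Hypothetical pre-flight check. The brick copies original_data's column
    names and values onto the SHAP explanation by position, so both frames
    must line up row-for-row and column-for-column."""
    if original_data.shape != data.shape:
        raise ValueError(
            f"original_data shape {original_data.shape} "
            f"does not match data shape {data.shape}"
        )


# Scaled features fed to the model, and the same two rows in readable form.
data = pd.DataFrame({"age_scaled": [0.1, 0.9], "income_scaled": [0.0, 1.0]})
original = pd.DataFrame({"age": [23, 67], "income": [18_000, 120_000]})
check_original_data(data, original)  # passes: both are 2 rows x 2 columns
```

Column names are allowed to differ (that is the point of supplying readable names); only the shape has to agree.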

Input Types

Input           Types
explainer       Any
data            DataFrame
original data   DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated heatmap visualization. Depending on the Output Image Format option, it is returned as a PIL image, PNG-encoded bytes, or a NumPy array, ready to be displayed or saved.

Output Types

Output   Types
image    MediaData, PILImage

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Heatmap brick has the following configurable options:

Class Index (Multiclass)
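Used only when the model has one output per class, such as a multiclass classifier predicting "Red", "Blue", or "Green". This integer selects which class to visualize (0 for the first class, 1 for the second, and so on; default 0). If the index is out of range, the brick falls back to 0, logging a warning when Verbose is enabled.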
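The class selection logic can be sketched on a raw NumPy array of SHAP values. The brick itself slices a shap.Explanation object; `select_class` is a hypothetical helper that mirrors the same shape checks and the fallback to class 0.

```python
import numpy as np


def select_class(values: np.ndarray, class_index: int = 0) -> np.ndarray:
    """Mirror the brick's class handling on a raw SHAP value array."""
    if values.ndim == 2:  # regression / single output: nothing to select
        return values
    if values.ndim == 3:  # shape (instances, features, classes)
        n_classes = values.shape[2]
        if not 0 <= class_index < n_classes:
            class_index = 0  # same fallback as the brick
        return values[:, :, class_index]
    raise ValueError(f"Unexpected SHAP dimensionality: {values.ndim}")


values = np.arange(24, dtype=float).reshape(4, 2, 3)  # 4 instances, 2 features, 3 classes
print(select_class(values, 2).shape)        # (4, 2)
print(select_class(values, 5)[0].tolist())  # [0.0, 3.0]: out of range, so class 0
```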
Max Display
Controls how many rows (features) appear on the vertical axis (default 10). Features are ranked by their mean absolute SHAP value, and any features beyond this limit are combined into a single summary row, which keeps the chart readable.
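The ranking and collapsing step can be approximated in a few lines of NumPy. This sketch assumes SHAP keeps the top `max_display - 1` features and sums the rest into one "Sum of N other features" row; `top_features` is a hypothetical helper, so treat it as an illustration rather than SHAP's exact code.

```python
import numpy as np


def top_features(values, feature_names, max_display=10):
    """Rank features by mean |SHAP| and, if there are more than max_display,
    keep the top max_display - 1 and fold the rest into one summary row."""
    values = np.asarray(values, dtype=float)
    importance = np.abs(values).mean(axis=0)  # global importance per feature
    order = np.argsort(-importance)           # most important first
    if values.shape[1] <= max_display:
        return values[:, order], [feature_names[i] for i in order]
    keep, rest = order[: max_display - 1], order[max_display - 1 :]
    summary = values[:, rest].sum(axis=1, keepdims=True)  # per-instance sum of the tail
    names = [feature_names[i] for i in keep] + [f"Sum of {len(rest)} other features"]
    return np.hstack([values[:, keep], summary]), names


shap_values = np.array([
    [1.0, -5.0, 0.5, 2.0],
    [-1.0, 4.0, 0.5, -2.0],
    [1.0, 6.0, -0.5, 2.0],
])
matrix, names = top_features(shap_values, ["a", "b", "c", "d"], max_display=3)
print(names)                   # ['b', 'd', 'Sum of 2 other features']
print(matrix[:, -1].tolist())  # [1.5, -0.5, 0.5]
```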
Order Instances By
Determines how the data points are arranged along the horizontal axis.
  • similarity: Groups instances that have similar explanations together using hierarchical clustering. This is useful for finding patterns or distinct segments in your data.
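  • sum: Orders instances by the sum of their SHAP values, from highest to lowest. The sum is signed, not an absolute magnitude: instances whose features push the prediction up the most appear on the left, and those pushed down the most appear on the right. Because an instance's SHAP values add up to its prediction minus the base value, this is equivalent to ordering by model output.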
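The sum ordering uses the same expression as the brick's code, `np.argsort(-shap_sums)`. The toy values below are chosen so that ordering by signed sum and ordering by absolute magnitude give different results, which shows why the distinction matters.

```python
import numpy as np

# Rows are instances, columns are features (toy SHAP values).
values = np.array([
    [0.2, -0.1, 0.4],    # sum  0.5
    [-0.5, -0.3, -0.1],  # sum -0.9
    [1.0, 0.5, -0.2],    # sum  1.3
])
sums = values.sum(axis=1)

order = np.argsort(-sums)                 # what the brick does: descending signed sum
by_magnitude = np.argsort(-np.abs(sums))  # NOT what the brick does

print(order)         # [2 0 1]
print(by_magnitude)  # [2 1 0]
```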
Output Image Format
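Defines the format of the output image. The chart is first rendered as a PNG at 300 DPI, then converted.
  • pil: Returns a PIL (Pillow) Image object (default). Best for standard image processing.
  • bytes: Returns the PNG-encoded image file as raw bytes, ready to be written to disk or sent elsewhere.
  • array: Returns the image as a NumPy array of pixel values (height x width x channels). Best for mathematical image manipulation.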
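The three formats are three views of the same in-memory PNG. The sketch below reproduces the conversion with Pillow and NumPy; `convert_png` is a hypothetical helper, and a tiny solid-color image stands in for the rendered chart. matplotlib's PNG output is typically RGBA, in which case the array has four channels.

```python
import io

import numpy as np
from PIL import Image


def convert_png(buf: io.BytesIO, image_format: str = "pil"):
    """Turn an in-memory PNG into one of the brick's three output formats."""
    buf.seek(0)
    if image_format == "pil":
        return Image.open(buf)
    if image_format == "array":
        return np.array(Image.open(buf))
    if image_format == "bytes":
        return buf.getvalue()
    raise ValueError(f"Unknown image_format: {image_format!r}")


# Stand-in for the rendered chart: a tiny 4x3 red RGBA PNG written to memory.
buf = io.BytesIO()
Image.new("RGBA", (4, 3), (255, 0, 0, 255)).save(buf, format="PNG")

png_bytes = convert_png(buf, "bytes")
pixels = convert_png(buf, "array")
print(png_bytes[:8])  # b'\x89PNG\r\n\x1a\n' (the PNG file signature)
print(pixels.shape)   # (3, 4, 4): height, width, RGBA channels
```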
Verbose
If enabled (the default), detailed logs about the heatmap generation process are written to the console.
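When calling `shap_heatmap` directly, the options above are passed as a plain `options` dict. The keys and defaults below are taken from the function body that follows; the mapping from UI labels to keys (for example Order Instances By to `order_by`) is inferred from the code, and `resolve_options` is a hypothetical helper that fills in defaults the same way.

```python
# Defaults read by shap_heatmap() from its `options` dict.
DEFAULTS = {
    "verbose": True,
    "class_index": 0,
    "max_display": 10,
    "order_by": "similarity",
    "image_format": "pil",
}


def resolve_options(options=None):
    """Fill in any missing option with the same default the brick uses."""
    options = options or {}
    return {key: options.get(key, default) for key, default in DEFAULTS.items()}


print(resolve_options({"order_by": "sum", "verbose": False}))
# {'verbose': False, 'class_index': 0, 'max_display': 10, 'order_by': 'sum', 'image_format': 'pil'}
```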
import logging
import io
import inspect
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

matplotlib.use("Agg")
import shap

logger = CodedFlowsLogger(name="SHAP Heatmap", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert pandas or polars DataFrame to numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)


def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_heatmap(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Union[MediaData, PILImage]:
    """
    Generates a SHAP heatmap visualizing feature impacts across instances.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 10)
    order_by = options.get("order_by", "similarity")
    image_format = options.get("image_format", "pil")
    image = None
    try:
        verbose and logger.info("Starting SHAP Heatmap generation.")
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a Shap Explainer as an input.")
            raise ValueError("Expects a Shap Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data for visualization."
            )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        ndim = len(shap_explanation.shape)
        explanation_to_process = None
        if ndim == 2:
            verbose and logger.info("Detected Regression or Binary Classification.")
            explanation_to_process = shap_explanation
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Class index {class_index} out of bounds. Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_process = shap_explanation[:, :, class_index]
        else:
            raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
        instance_ordering = None
        if order_by == "sum":
            verbose and logger.info("Ordering instances by sum of SHAP values.")
            shap_sums = explanation_to_process.values.sum(axis=1)
            instance_ordering = np.argsort(-shap_sums)
        elif order_by == "similarity":
            verbose and logger.info(
                "Ordering instances by similarity (default SHAP behavior)."
            )
        else:
            verbose and logger.warning(
                f"Unknown order_by type: {order_by}. Using default similarity ordering."
            )
        verbose and logger.info("Generating SHAP Heatmap.")
        fig = plt.figure()
        if order_by == "sum":
            shap.plots.heatmap(
                explanation_to_process,
                instance_order=instance_ordering,
                max_display=max_display,
                show=False,
            )
        else:
            shap.plots.heatmap(
                explanation_to_process, max_display=max_display, show=False
            )
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            img = Image.open(buf)
            image = np.array(img)
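        elif image_format == "bytes":
            image = buf.getvalue()
        else:
            # Fail loudly instead of silently returning None.
            raise ValueError(
                f"Unknown image_format: {image_format!r}. Expected 'pil', 'bytes' or 'array'."
            )
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP heatmap: {e}")
        # A bare raise re-raises the original exception with its traceback intact.
        raise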
    return image

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • pillow
  • numba>=0.56.0