SHAP Cohorts

Automatically detects cohorts in the data and compares feature importance across them.


Processing

This brick analyzes your model's behavior by automatically splitting your data into distinct segments (cohorts) and comparing them. Instead of a single "average" feature importance for the whole dataset, it finds groups of data points that the model explains in similar ways and shows how the model treats each group differently.

For example, in a loan approval model it might reveal that "Income" is the most important factor for one group of applicants, while "Credit History" drives the predictions for another. The brick draws a bar chart that compares the cohorts side by side.
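The brick delegates cohort detection to SHAP's `Explanation.cohorts(n)`. As far as we can tell from SHAP's source, an integer argument makes SHAP fit a scikit-learn `DecisionTreeRegressor(max_leaf_nodes=n)` on the feature values, with the SHAP values as targets. Each leaf becomes a cohort, named after the feature thresholds on its path. The NumPy-only sketch below imitates that best-first tree growth on made-up loan data. `grow_cohorts` and its helpers are hypothetical names, and scikit-learn's real implementation differs in details such as tie-breaking and stopping rules.

```python
import numpy as np


def _sse(s):
    """Total squared deviation of a block of SHAP rows from their column means."""
    return float(((s - s.mean(axis=0)) ** 2).sum()) if len(s) else 0.0


def _best_split(X, S, idx):
    """Return (gain, feature, threshold) of the best split of rows `idx`, or None."""
    parent = _sse(S[idx])
    best = None
    for j in range(X.shape[1]):
        vals = np.unique(X[idx, j])
        # Candidate thresholds: midpoints between consecutive distinct values
        for t in (vals[:-1] + vals[1:]) / 2:
            mask = X[idx, j] < t
            gain = parent - _sse(S[idx[mask]]) - _sse(S[idx[~mask]])
            if best is None or gain > best[0]:
                best = (gain, j, t)
    return best


def grow_cohorts(X, S, n_cohorts, names):
    """Best-first growth: keep splitting the leaf whose split most reduces the
    squared error of the SHAP values, until there are n_cohorts leaves."""
    leaves = [(np.arange(len(X)), [])]  # (row indices, path conditions)
    while len(leaves) < n_cohorts:
        scored = [(_best_split(X, S, idx), k) for k, (idx, _) in enumerate(leaves)]
        scored = [(b, k) for b, k in scored if b is not None and b[0] > 0]
        if not scored:
            break  # no leaf can be split any further
        (_, j, t), k = max(scored, key=lambda bk: bk[0][0])
        idx, path = leaves.pop(k)
        mask = X[idx, j] < t
        leaves.append((idx[mask], path + [f"{names[j]} < {t:g}"]))
        leaves.append((idx[~mask], path + [f"{names[j]} >= {t:g}"]))
    return {" & ".join(path) or "All": idx for idx, path in leaves}


# Toy loan data: columns are income (k$) and credit history (years)
X = np.array([[20, 2], [25, 3], [70, 2], [75, 3], [72, 9], [78, 8]], dtype=float)
# Hypothetical SHAP values for the same rows and columns
S = np.array([[-0.8, -0.1], [-0.7, -0.1], [0.2, -0.6],
              [0.3, -0.5], [0.2, 0.6], [0.3, 0.7]])

for label, rows in grow_cohorts(X, S, 3, ["income", "credit_history"]).items():
    print(label, rows.tolist())
```

With two cohorts, only the first split (`credit_history < 5.5`) is made. Asking for three splits the lower-credit group again on income, which is how higher cohort counts yield finer subgroups.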

Inputs

explainer
A SHAP Explainer object (an instance of shap.Explainer or one of its subclasses) that has already been built for your model. The brick calls it on data to compute the SHAP values, and raises an error if any other kind of object is passed. It usually comes from a "Calculate SHAP Values" or "Build Explainer" brick.
data
The dataset to explain, in exactly the form the model expects as input (same columns in the same order, with any scaling, normalization, or encoding already applied).
original data
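(Optional) A human-readable version of the same dataset, for example the values before scaling. It must have the same rows and columns, in the same order, as data. When provided, the brick replaces the explanation's feature names and values with those of this dataset before computing the cohorts, so feature names and cohort thresholds appear in their original units. This makes the final chart much easier to interpret.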
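Supplying original data changes how the chart is labeled, not the SHAP values themselves. For per-feature increasing transforms such as standardization, a split on the scaled column and the corresponding raw threshold select exactly the same rows, so the raw labels describe the same cohorts. The small NumPy sketch below shows this on made-up income values. Transforms that change the number or meaning of columns, such as one-hot encoding or PCA, break the column-by-column correspondence the brick relies on.

```python
import numpy as np

# Made-up raw incomes and their standardized (z-score) version
raw_income = np.array([21_000.0, 35_000.0, 48_000.0, 62_000.0, 90_000.0])
mean, std = raw_income.mean(), raw_income.std()
scaled_income = (raw_income - mean) / std

# A hypothetical split learned on the scaled column
scaled_threshold = 0.1
# The same split in raw units: standardization is increasing and affine
raw_threshold = scaled_threshold * std + mean

left_scaled = scaled_income < scaled_threshold
left_raw = raw_income < raw_threshold

print(f"scaled label: income < {scaled_threshold}")
print(f"raw label:    income < {raw_threshold:,.0f}")
print((left_scaled == left_raw).all())  # True: the same rows either way
```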

Input Types

Input           Type
explainer       Any
data            DataFrame
original data   DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated comparison chart. For each detected cohort, it shows the mean absolute SHAP value of the top features, that is, how strongly each feature influences the model's output within that cohort, regardless of direction.
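The brick computes the bars with `cohorts.abs.mean(0)`. Within each cohort, it averages the absolute SHAP values over that cohort's rows, feature by feature. Taking the absolute value first stops positive and negative contributions from cancelling out. A NumPy sketch of the same arithmetic follows. The loan data and the `mean_abs_by_cohort` helper are hypothetical.

```python
import numpy as np

FEATURES = ["income", "credit_history"]


def mean_abs_by_cohort(shap_values, cohort_of_row):
    """Map each cohort label to its per-feature mean absolute SHAP value."""
    return {
        str(label): np.abs(shap_values[cohort_of_row == label]).mean(axis=0)
        for label in np.unique(cohort_of_row)
    }


# Hypothetical SHAP values: rows are applicants, columns follow FEATURES
shap_values = np.array([
    [-0.8, 0.1], [0.7, -0.1], [-0.9, 0.1],   # cohort A
    [0.1, -0.6], [-0.1, 0.7], [0.1, 0.5],    # cohort B
])
cohort_of_row = np.array(["A", "A", "A", "B", "B", "B"])

for label, bars in mean_abs_by_cohort(shap_values, cohort_of_row).items():
    top = FEATURES[int(np.argmax(bars))]
    print(label, np.round(bars, 3).tolist(), "top feature:", top)

# Without abs(), the opposite-signed income effects in cohort A partly cancel
print(shap_values[:3, 0].mean())
```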

Output Types

Output   Type
image    MediaData, PILImage

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Cohorts brick has the following options:

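Number of Cohorts
The maximum number of groups to find in your data (default: 2). Each additional cohort corresponds to one more split on a feature threshold, and fewer cohorts may be returned if the data cannot be split further.
  • 2: Splits the data into two groups using a single feature threshold.
  • Higher values (e.g., 3-10): Find finer, more specific subgroups, each containing fewer data points.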
Class Index (Multiclass)
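Used only when the explanation contains one set of SHAP values per class, for example for a model that predicts "Red", "Blue", or "Green". This zero-based integer selects the class to analyze: 0 is the first class, 1 the second, and so on. An out-of-range index (including a negative one) falls back to 0 with a warning. Some explainers also return per-class values for binary classifiers, in which case 1 usually selects the positive class. Default: 0.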
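For multiclass models, the explanation has the shape `(rows, features, classes)`. The brick keeps one class with `explanation[:, :, class_index]` and resets out-of-range indices to 0. Here is the same selection on a plain NumPy array. `select_class` is a hypothetical helper, and the shapes are only illustrative.

```python
import numpy as np


def select_class(values, class_index):
    """Pick one class's SHAP values from a (rows, features, classes) array,
    falling back to class 0 for an out-of-range index, as the brick does."""
    if values.ndim == 2:
        return values  # single-output explanation: nothing to select
    n_classes = values.shape[2]
    if not 0 <= class_index < n_classes:
        class_index = 0  # negative indices are treated as out of range too
    return values[:, :, class_index]


values = np.arange(24).reshape(4, 2, 3)  # 4 rows, 2 features, 3 classes
print(select_class(values, 1).shape)  # (4, 2)
```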
Max Display
The maximum number of features shown on the bar chart (default: 10). Limiting the chart to the most important features keeps it readable.
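The sketch below shows the truncation for a single cohort. It ranks features by importance and keeps only the top entries. Based on our reading of SHAP's plotting code, `shap.plots.bar` draws at most `max_display` bars. When there are more features than that, the last bar becomes a "Sum of N other features" bar. Treat that detail as an assumption. `truncate_for_display` and the importance values are hypothetical.

```python
import numpy as np


def truncate_for_display(names, importances, max_display):
    """Keep the top max_display - 1 features and fold the rest into one
    'Sum of N other features' bar, mimicking shap.plots.bar as we read it."""
    importances = np.asarray(importances, dtype=float)
    order = np.argsort(-importances)  # most important first
    if len(order) <= max_display:
        return [(names[i], importances[i]) for i in order]
    shown = [(names[i], importances[i]) for i in order[: max_display - 1]]
    rest = order[max_display - 1 :]
    shown.append((f"Sum of {len(rest)} other features", importances[rest].sum()))
    return shown


names = ["income", "credit_history", "age", "debt", "tenure"]
importances = [0.8, 0.6, 0.05, 0.3, 0.02]  # hypothetical mean |SHAP| values
for label, value in truncate_for_display(names, importances, max_display=3):
    print(f"{label:<26} {value:.2f}")
```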
Image Format
Controls the type of the output image (default: pil).
  • pil: Returns a PIL (Pillow) Image object. Best for further image manipulation within the flow.
  • bytes: Returns the encoded PNG file as raw bytes, suitable for saving to disk or sending over a network.
  • array: Returns the image as a NumPy array of pixel values with shape (height, width, channels). Best for numerical image processing.
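All three formats hold the same PNG rendering. In the sketch below, a blank Pillow image stands in for the chart. It shows how to turn the `bytes` output into the other two forms. `png_bytes` is a hypothetical name for the brick's output. Matplotlib PNGs are typically RGBA, so the array form usually has 4 channels.

```python
import io

import numpy as np
from PIL import Image

# Stand-in for the chart: a small RGBA image encoded as PNG bytes
buf = io.BytesIO()
Image.new("RGBA", (40, 20), (255, 255, 255, 255)).save(buf, format="PNG")
png_bytes = buf.getvalue()  # like the "bytes" output: an encoded PNG file

pil_image = Image.open(io.BytesIO(png_bytes))  # like the "pil" output
array = np.array(pil_image)  # like the "array" output: (height, width, channels)

print(png_bytes[:8])  # the PNG file signature
print(pil_image.size, array.shape)  # (40, 20) (20, 40, 4)

# Writing the bytes output to disk needs no image library:
# with open("cohorts.png", "wb") as f:
#     f.write(png_bytes)
```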
Verbose
When enabled (the default), the brick logs detailed progress information to the console, which is helpful for troubleshooting.
import logging
import io
import inspect
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot (or SHAP's plotting) is imported
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
import shap
from shap import Explainer
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="SHAP Cohorts", level=logging.INFO)
# Silence SHAP's own log messages below ERROR level
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    """Call the explainer on data, passing silent=True when its __call__ accepts it."""
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert a pandas or polars DataFrame (or any array-like) to a NumPy array."""
    if "polars" in type(df).__module__:
        return df.to_numpy()
    if hasattr(df, "values"):
        return df.values
    if hasattr(df, "to_numpy"):
        return df.to_numpy()
    return np.array(df)


def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_cohorts(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Union[MediaData, PILImage]:
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 10)
    n_cohorts = options.get("n_cohorts", 2)
    image_format = options.get("image_format", "pil")
    image = None
    try:
        verbose and logger.info("Starting SHAP Cohort Analysis.")
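        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expected a SHAP Explainer as input.")
            raise ValueError("Expected a SHAP Explainer as input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data."
            )
            original_values = _to_numpy(original_data)
            # The raw data must line up row-for-row and column-for-column with data
            expected_shape = shap_explanation.values.shape[:2]
            if original_values.shape != expected_shape:
                raise ValueError(
                    f"original_data has shape {original_values.shape}, "
                    f"expected {expected_shape} to match data."
                )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = original_values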
        ndim = len(shap_explanation.shape)
        explanation_to_process = None
        if ndim == 2:
            verbose and logger.info("Detected Regression or Binary Classification.")
            explanation_to_process = shap_explanation
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Class index {class_index} out of bounds. Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_process = shap_explanation[:, :, class_index]
        else:
            raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
        verbose and logger.info(f"Computing {n_cohorts} cohorts.")
        # Let SHAP split the rows into at most n_cohorts cohorts (a Cohorts object)
        cohorts = explanation_to_process.cohorts(n_cohorts)
        verbose and logger.info(f"Generating Cohort Bar Plot max_display={max_display}")
        fig = plt.figure()
        shap.plots.bar(cohorts.abs.mean(0), show=False, max_display=max_display)
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            image = np.array(Image.open(buf))
        elif image_format == "bytes":
            image = buf.getvalue()
        else:
            raise ValueError(
                f"Unsupported image_format {image_format!r}; "
                "expected 'pil', 'bytes' or 'array'."
            )
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP cohorts: {e}")
        # A bare raise re-raises the original exception with its traceback
        raise
    return image
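The `_call_explainer` helper passes `silent=True` only when the explainer's `__call__` declares a `silent` parameter. Explainers without one are called normally, so the keyword does not cause a `TypeError`. The pattern is shown below in isolation. It uses two hypothetical stand-in classes rather than real SHAP explainers.

```python
import inspect


class QuietExplainer:
    """Hypothetical explainer whose __call__ accepts `silent`."""

    def __call__(self, data, silent=False):
        return ("quiet", silent)


class PlainExplainer:
    """Hypothetical explainer without a `silent` parameter."""

    def __call__(self, data):
        return ("plain", None)


def call_explainer(explainer, data, **kwargs):
    # Only pass silent=True if the callable's signature declares it
    if "silent" in inspect.signature(explainer.__call__).parameters:
        return explainer(data, silent=True, **kwargs)
    return explainer(data, **kwargs)


print(call_explainer(QuietExplainer(), [[1, 2]]))  # ('quiet', True)
print(call_explainer(PlainExplainer(), [[1, 2]]))  # ('plain', None)
```

The check is conservative. A `__call__` that accepts `silent` only through `**kwargs` is not detected, so it is simply called without the keyword.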

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • pillow
  • numba>=0.56.0