SHAP Explanation
Explains a single prediction (row) using a waterfall plot. Shows how each feature contributed to pushing the model output from the base value to the final prediction.
Processing
This brick analyzes a single prediction from your machine learning model to explain why that result occurred. It generates a "Waterfall" plot and a detailed data table that visualize how each individual factor (feature) pushed the prediction higher or lower than the average (base) value.
For example, if a model predicts a house price, this brick lets you see that while "Size" added $50k to the value, "Age" removed $10k. It sorts the explanation data by impact, ensuring the most important factors are highlighted.
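To make that arithmetic concrete, here is a minimal, self-contained sketch using scikit-learn and shap directly (outside any brick; the model, feature names, and values are illustrative). It shows the additivity the waterfall plot visualizes: the base value plus the per-feature contributions reconstructs the row's prediction.

```python
import pandas as pd
import shap
from sklearn.linear_model import LinearRegression

# Toy house-price data (illustrative values only).
X = pd.DataFrame({"Size": [120.0, 80.0, 200.0], "Age": [5.0, 30.0, 12.0]})
y = [300_000.0, 180_000.0, 450_000.0]
model = LinearRegression().fit(X, y)

explainer = shap.Explainer(model, X)  # dispatches to shap's linear explainer
explanation = explainer(X)

row = explanation[0]  # the single row a waterfall plot would show
# Additivity: base value + sum of contributions equals this row's prediction.
print(row.base_values + row.values.sum())  # ≈ model.predict(X)[0]
```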
Inputs
- explainer
- The calculated SHAP Explainer object. This is usually the output from a previous "Calculate SHAP" brick. It contains the mathematical rules needed to explain the model.
- data
- The dataset containing the rows you want to explain. This must match the format of the data used to train the model.
- original data
- (Optional) A dataset containing the raw, human-readable values. If provided, the waterfall plot will display these values (e.g., "Married") instead of the numerical encodings (e.g., "1") used by the model. This makes the chart easier to read, as illustrated below.
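The difference between data and original data is easiest to see side by side. A hypothetical pairing (the column name and encoding are invented for illustration):

```python
import pandas as pd

# What the model was trained on: numerically encoded features.
data = pd.DataFrame({"marital_status": [1, 0, 1]})

# The same rows in human-readable form, shown on the plot if provided.
original_data = pd.DataFrame({"marital_status": ["Married", "Single", "Married"]})

# The two frames must align row-for-row and column-for-column.
```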
Input Types
| Input | Types |
|---|---|
| explainer | Any |
| data | DataFrame |
| original data | DataFrame |
You can check the list of supported types here: Available Type Hints.
Outputs
- image
- The visual Waterfall plot showing the positive and negative forces acting on the prediction.
- explanation
- A table of the numerical contributions for the specific row analyzed, returned as a DataFrame. The base value occupies the first row, and the remaining rows are sorted by the magnitude of their impact.
The explanation output contains the following data fields (see the example after this list):
- feature: The name of the data column (e.g., "Age", "Income"), or "Base Value (Expected)" for the first row.
- value: The actual value of that feature for this specific row (e.g., "25", "High").
- contribution: The numerical amount this feature added to or subtracted from the prediction.
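For orientation, the explanation output for the house-price example might look like the frame below (values are illustrative; the exact base-value label matches the brick's source further down):

```python
import numpy as np
import pandas as pd

explanation = pd.DataFrame(
    [
        # The base (expected) value always occupies the first row.
        {"feature": "Base Value (Expected)", "value": np.nan, "contribution": 310_000.0},
        # Remaining rows are sorted by |contribution|, largest first.
        {"feature": "Size", "value": 120.0, "contribution": 50_000.0},
        {"feature": "Age", "value": 5.0, "contribution": -10_000.0},
    ]
)
```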
Output Types
| Output | Types |
|---|---|
| image | MediaData, PILImage |
| explanation | DataFrame |
You can check the list of supported types here: Available Type Hints.
Options
The SHAP Explanation brick provides the following configurable options (a usage sketch follows the list):
- Row Index
- The specific row number in your dataset that you want to explain.
- Class Index (Multiclass)
- Used only if your model predicts multiple categories (e.g., "Red", "Blue", "Green"). This determines which category prediction to explain.
- Max Display
- The maximum number of features to show on the plot. If your data has 100 columns, setting this to 20 keeps the chart readable; SHAP groups the smallest contributions into a single combined "other features" entry.
- Image Format
- Determines the technical format of the output image.
- pil: Returns a Python Imaging Library (Pillow) image object. Best for further image processing in Python.
- bytes: Returns the raw file bytes.
- array: Returns the image as a grid of numbers (Numpy array). Best for computer vision tasks.
- Verbose
- When enabled, the brick logs detailed progress information to the console, which is helpful for troubleshooting.
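As a quick orientation, here is a hedged sketch of how these options map onto the brick's entry-point function, shown in the source listing that follows; the explainer and data objects would normally arrive from upstream bricks:

```python
# Option keys mirror the defaults read inside shap_explanation() below.
options = {
    "row_index": 3,         # explain the fourth row of `data`
    "class_index": 1,       # only consulted for multiclass (3-D) explanations
    "max_display": 15,      # cap the number of bars on the waterfall plot
    "image_format": "pil",  # "pil" | "bytes" | "array"
    "verbose": False,       # silence progress logging
}
image, explanation = shap_explanation(explainer, data, options=options)
```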
```python
import logging
import io
import inspect
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, Tuple, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

matplotlib.use("Agg")  # headless backend: render plots without a display
import shap

logger = CodedFlowsLogger(name="SHAP Explanation", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    """Call the explainer, passing silent=True only if it is supported."""
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert a pandas or polars DataFrame to a numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)


def _get_columns(df):
    """Get column names from a pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_explanation(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Tuple[Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    row_index = options.get("row_index", 0)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 20)
    image_format = options.get("image_format", "pil")
    image = None
    explanation = pd.DataFrame()
    try:
        verbose and logger.info(
            f"Starting SHAP Individual Waterfall for row {row_index}."
        )
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a SHAP Explainer as an input.")
            raise ValueError("Expects a SHAP Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data."
            )
            # Swap in human-readable feature names and values for the plot.
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        if row_index >= shap_explanation.shape[0] or row_index < 0:
            raise ValueError(
                f"Row index {row_index} is out of bounds (0-{shap_explanation.shape[0] - 1})."
            )
        ndim = len(shap_explanation.shape)
        single_explanation = None
        if ndim == 2:
            # (rows, features): regression or binary classification.
            verbose and logger.info("Detected Regression or Binary Classification.")
            single_explanation = shap_explanation[row_index]
        elif ndim == 3:
            # (rows, features, classes): multiclass output.
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Class index {class_index} out of bounds. Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            single_explanation = shap_explanation[row_index, :, class_index]
        else:
            raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
        feature_names = single_explanation.feature_names
        if not feature_names:
            feature_names = [f"Feature {i}" for i in range(single_explanation.shape[0])]
        values = np.array(single_explanation.data).flatten()
        contributions = np.array(single_explanation.values).flatten()
        rows = []
        for f, v, c in zip(feature_names, values, contributions):
            rows.append(
                {
                    "feature": f,
                    "value": v,
                    "contribution": c,
                    "abs_contribution": abs(c),
                }
            )
        # Sort by magnitude of impact, then drop the helper column.
        rows.sort(key=lambda x: x["abs_contribution"], reverse=True)
        for r in rows:
            del r["abs_contribution"]
        base_value = single_explanation.base_values
        if isinstance(base_value, (np.ndarray, list)):
            base_value = base_value if np.ndim(base_value) == 0 else base_value[0]
        # The base (expected) value is always the first row of the output.
        rows.insert(
            0,
            {
                "feature": "Base Value (Expected)",
                "value": np.nan,
                "contribution": base_value,
            },
        )
        explanation = pd.DataFrame(rows)
        verbose and logger.info(
            "Constructed sorted contribution DataFrame (Base Value + Features)."
        )
        verbose and logger.info("Generating Waterfall Plot.")
        fig = plt.figure()
        shap.plots.waterfall(single_explanation, max_display=max_display, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        # Close both the active figure and the pre-created one, in case
        # shap drew on a figure of its own.
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            img = Image.open(buf)
            image = np.array(img)
        elif image_format == "bytes":
            image = buf.getvalue()
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP waterfall: {str(e)}")
        raise e
    return (image, explanation)
```
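And a brief sketch of consuming the image output of the function above in each format (again assuming explainer and data come from upstream):

```python
# "pil" (default): a PIL.Image.Image you can save or post-process further.
img, df = shap_explanation(explainer, data, options={"image_format": "pil"})
img.save("waterfall.png")

# "bytes": raw PNG bytes, ready to write to disk or send over the wire.
raw, _ = shap_explanation(explainer, data, options={"image_format": "bytes"})
with open("waterfall.png", "wb") as f:
    f.write(raw)

# "array": a numpy array of pixel values (height x width x channels).
arr, _ = shap_explanation(explainer, data, options={"image_format": "array"})
```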
Brick Info
- shap>=0.47.0
- matplotlib
- numpy
- pandas
- pillow
- numba>=0.56.0