SHAP Probability
Visualizes the relationship between SHAP values (Logits) and Probability. Plots the Sigmoid curve and highlights the specific probability for a given SHAP score.
Processing
This brick generates a visualization that translates abstract model scores (SHAP values or Logits) into a readable probability percentage.
Machine learning models often compute a raw score (logit) that is the sum of a "Base Value" (bias) and the "SHAP Values" (feature contributions). This brick plots the sigmoid curve, the function that maps these unbounded raw scores into the 0% to 100% probability range. It then marks the point on the curve that corresponds to the values you provide, so you can show visually why a model predicted a specific probability.
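The mapping described above can be checked numerically. The following is a minimal sketch that uses only the standard library; the function name `logit_to_probability` and the sample numbers are illustrative assumptions, not part of the brick.

```python
import math


def logit_to_probability(base_value: float, shap_sum: float) -> float:
    """Map a logit (base value + SHAP sum) to a probability with the sigmoid."""
    logit = base_value + shap_sum
    # For negative logits use the algebraically equivalent form exp(x) / (1 + exp(x))
    # so that math.exp never receives a large positive argument (avoids OverflowError).
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# Hypothetical numbers: a base value of -1.0 and contributions summing to 1.0
# give a logit of 0.0, which sits at the midpoint of the curve.
print(logit_to_probability(-1.0, 1.0))  # 0.5
```

A positive SHAP sum moves the point to the right of the base value on the curve, and a negative sum moves it to the left.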
Inputs
This brick acts as a generator and does not require specific upstream data inputs to function. It creates a new image based entirely on the settings configured in the Options section.
Outputs
- image
  - The generated visualization. Depending on your configuration, this is returned as a standard image object, a raw byte stream, or a matrix of pixels.
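When the output is requested as a raw byte stream, a consumer may want to confirm that it really received PNG data before passing it on. The sketch below reads the image size from the PNG header with the standard library only. The helper names `png_size` and `_chunk` are hypothetical, and the hand-built 1x1 image merely stands in for the brick's actual output.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(data: bytes) -> tuple:
    """Return (width, height) read from the IHDR chunk of PNG bytes."""
    # A PNG starts with an 8-byte signature followed by the IHDR chunk:
    # 4-byte length, 4-byte type, then width and height as big-endian uint32.
    if data[:8] != PNG_SIGNATURE or data[12:16] != b"IHDR":
        raise ValueError("not a PNG byte stream")
    width, height = struct.unpack(">II", data[16:24])
    return width, height


def _chunk(kind: bytes, payload: bytes) -> bytes:
    """Build one PNG chunk: length, type, payload, CRC-32 of type + payload."""
    crc = zlib.crc32(kind + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + kind + payload + struct.pack(">I", crc)


# Stand-in for the brick's byte output: a 1x1, 8-bit grayscale PNG.
ihdr = struct.pack(">IIBBBBB", 1, 1, 8, 0, 0, 0, 0)
sample = (
    PNG_SIGNATURE
    + _chunk(b"IHDR", ihdr)
    + _chunk(b"IDAT", zlib.compress(b"\x00\x00"))  # filter byte + one pixel
    + _chunk(b"IEND", b"")
)

print(png_size(sample))  # (1, 1)
```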
Outputs Types
| Output | Types |
|---|---|
| image | MediaData, PILImage |
You can check the list of supported types here: Available Type Hints.
Options
The SHAP Probability brick exposes the following options:
- SHAP Value
  - The specific score contribution you want to visualize. In a machine learning context, this is usually the sum of SHAP values for a specific prediction, representing how much the features pushed the result towards positive or negative.
- Base Value (Bias)
  - The starting point of the model. This is the "average" score the model outputs before it looks at any specific features of the data (often called the intercept or bias).
- Image Format
  - Determines the technical format of the output image. This selection changes the data type of the `image` output.
    - pil: Returns a standard Python Pillow Image object. Best for further image manipulation or displaying in Python-based UIs.
    - bytes: Returns the raw binary data of the PNG file.
    - array: Returns the image as a NumPy array (matrix of numbers). Best for scientific processing or pixel-level analysis.
- Verbose
  - Controls the amount of detail logged during processing. If enabled, the system logs the calculated Logit and Probability values to the console.
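As an illustration of how these options fit together, the sketch below sums per-feature contributions into a single SHAP value and assembles an options dictionary. The keys mirror the ones read by the brick's code (`shap_value`, `base_value`, `image_format`, `verbose`). The helper `build_options`, the feature names, and all numbers are hypothetical.

```python
ALLOWED_FORMATS = {"pil", "bytes", "array"}


def build_options(contributions, base_value, image_format="pil", verbose=False):
    """Assemble an options dict from per-feature SHAP contributions (log-odds)."""
    if image_format not in ALLOWED_FORMATS:
        raise ValueError(f"image_format must be one of {sorted(ALLOWED_FORMATS)}")
    return {
        # The brick expects one number: the sum of all feature contributions.
        "shap_value": sum(contributions.values()),
        "base_value": base_value,
        "image_format": image_format,
        "verbose": verbose,
    }


# Hypothetical contributions for a single prediction.
feature_contributions = {"age": 0.75, "income": -0.25, "tenure": 0.5}
options = build_options(feature_contributions, base_value=-0.5)
print(options["shap_value"] + options["base_value"])  # 0.5
```

Here the logit is 0.5, so the highlighted point would land slightly right of the curve's midpoint.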
import logging
import io
import numpy as np
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
from coded_flows.types import MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="SHAP Probability", level=logging.INFO)


def shap_probability(options: dict = None) -> Union[MediaData, PILImage]:
    options = options or {}
    verbose = options.get("verbose", True)
    shap_val = options.get("shap_value", 0.0)
    base_val = options.get("base_value", 0.0)
    image_format = options.get("image_format", "pil")
    image = None
    try:
        verbose and logger.info(
            f"Starting Probability Plot. Input SHAP: {shap_val}, Base: {base_val}"
        )
        # Logit = base value + SHAP sum; the sigmoid maps it to a probability.
        current_logit = shap_val + base_val
        current_prob = 1 / (1 + np.exp(-current_logit))
        verbose and logger.info(
            f"Calculated Logit: {current_logit:.4f}, Probability: {current_prob:.4%}"
        )
        (fig, ax) = plt.subplots(figsize=(8, 5))
        # Widen the x-range when the point would fall outside the default [-6, 6].
        limit = max(6.0, abs(current_logit) + 2.0)
        (y_min, y_max) = (-0.05, 1.05)
        x = np.linspace(-limit, limit, 500)
        y = 1 / (1 + np.exp(-x))
        ax.plot(x, y, label="Sigmoid Function", color="#3b82f6", linewidth=2)
        ax.scatter(
            [current_logit],
            [current_prob],
            color="#ef4444",
            s=100,
            zorder=5,
            label="Your Value",
        )
        # axvline/axhline take their end points as axes fractions (0 to 1),
        # so convert the data coordinates using the axis limits set below.
        ax.axvline(
            x=current_logit,
            ymax=(current_prob - y_min) / (y_max - y_min),
            color="#ef4444",
            linestyle="--",
            alpha=0.5,
        )
        ax.axhline(
            y=current_prob,
            xmax=(current_logit + limit) / (2 * limit),
            color="#ef4444",
            linestyle="--",
            alpha=0.5,
        )
        ax.annotate(
            f"Prob: {current_prob:.1%}\nLogit: {current_logit:.2f}",
            xy=(current_logit, current_prob),
            xytext=(10, -20),
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="#ef4444", alpha=0.8),
            arrowprops=dict(
                arrowstyle="->", connectionstyle="arc3,rad=0", color="#ef4444"
            ),
        )
        ax.set_title(
            "SHAP Value (Log Odds) → Probability", fontsize=12, fontweight="bold"
        )
        ax.set_xlabel("Logit (Base Value + SHAP Sum)")
        ax.set_ylabel("Probability")
        ax.grid(True, linestyle=":", alpha=0.6)
        ax.set_ylim(y_min, y_max)
        ax.set_xlim(-limit, limit)
        ax.legend(loc="upper left")
        # Render the figure to an in-memory PNG, then release the figure.
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(fig)
        # Convert the PNG buffer to the requested output type.
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            image = Image.open(buf)
            image = np.array(image)
        elif image_format == "bytes":
            image = buf.getvalue()
        verbose and logger.info("Plot generated successfully.")
    except Exception as e:
        verbose and logger.error(f"Error generating probability plot: {str(e)}")
        # Re-raise with the original traceback intact.
        raise
    return image
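Matplotlib's `axvline` and `axhline` take their `ymin`/`ymax` and `xmin`/`xmax` arguments as fractions of the axes (0 to 1), not as data coordinates. For a dashed guide line to end exactly at the highlighted point, the data value has to be converted using the axis limits. A small sketch of that conversion follows; the helper name `axes_fraction` is hypothetical.

```python
def axes_fraction(value: float, lower: float, upper: float) -> float:
    """Convert a data coordinate to a 0..1 fraction of the axis span."""
    if upper <= lower:
        raise ValueError("upper limit must be greater than lower limit")
    fraction = (value - lower) / (upper - lower)
    # Clamp so that values outside the limits stay inside the axes.
    return min(1.0, max(0.0, fraction))


# With y-limits of -0.05 to 1.05, a probability p maps to (p + 0.05) / 1.1.
y_fraction = axes_fraction(0.9, -0.05, 1.05)

# With symmetric x-limits of -6 to 6, a logit of 0 maps to the horizontal midpoint.
x_fraction = axes_fraction(0.0, -6.0, 6.0)
print(x_fraction)  # 0.5
```

Because the y-axis is padded by 0.05 on both sides, the fraction for a probability is close to, but not equal to, the probability itself.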
Brick Info
- shap>=0.47.0
- matplotlib
- pandas
- numpy
- pillow
- numba>=0.56.0