Pairplot Image
Generate a pairplot visualization. Automatically adjusts DPI based on feature count to prevent memory errors.
Pairplot Image
Processing
This function takes tabular data (Pandas DataFrame, Polars DataFrame, or PyArrow Table) and generates a pairplot visualization, which displays pairwise relationships between features. It automatically selects all numeric columns if no specific columns are provided, and allows conditioning the visualization using a hue column. The resulting image is rendered to memory and returned in a user-specified format (NumPy array, PIL Image, bytes, or BytesIO stream).
Inputs
- data
- Input data used for visualization, typically containing multiple numeric features.
Inputs Types
| Input | Types |
|---|---|
| data | DataFrame, ArrowTable |
You can check the list of supported types here: Available Type Hints.
Outputs
- image
- The generated pairplot visualization. The specific format depends on the 'Output Type' option selected.
Outputs Types
| Output | Types |
|---|---|
| image | MediaData, PILImage |
You can check the list of supported types here: Available Type Hints.
Options
The Pairplot Image brick contains some changeable options:
- Columns to Plot
- List of specific columns to include in the pairplot matrix. If left empty, the function defaults to using all numeric columns found in the input data.
- Hue Column
- Name of the column used to color code the points in the plot based on categorical values.
- Color Palette
- The color scheme used for rendering the plot. Available choices include standard Seaborn palettes like `husl`, `deep`, `muted`, etc.
- Diagonal Plot Type
- Specifies the type of plot drawn on the diagonal axes, such as `hist` (histogram) or `kde` (Kernel Density Estimate).
- Only Lower
- If enabled, only the lower triangle of the plot matrix is drawn. This gives a cleaner output when the mirrored plots in the upper triangle are not needed.
- Output Type
- Defines the format of the returned image object: NumPy array (`array`), PIL Image object (`pil`), raw bytes (`bytes`), or BytesIO stream (`bytesio`).
- Verbose
- If enabled, detailed logs and information about the execution process are printed.
import logging
import io
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from coded_flows.types import Union, DataFrame, ArrowTable, MediaData, PILImage
from coded_flows.utils import CodedFlowsLogger
# Module-level logger for this brick, named "Pairplot Image" and set to INFO level.
logger = CodedFlowsLogger(name="Pairplot Image", level=logging.INFO)
def pairplot(
    data: Union[DataFrame, ArrowTable], options=None
) -> Union[MediaData, PILImage]:
    """Render a seaborn pairplot of tabular data and return it as an image.

    The input is converted to a Pandas DataFrame, the columns to plot are
    resolved, the plot is rendered to an in-memory PNG, and the PNG is
    returned in the representation selected by ``output_type``.

    Args:
        data: A Pandas DataFrame, Polars DataFrame, or PyArrow Table.
        options: Optional dict of settings. Recognised keys:
            ``verbose`` (bool, default ``True``): emit log messages.
            ``output_type`` (str, default ``"array"``): one of ``"array"``,
                ``"pil"``, ``"bytes"`` or ``"bytesio"``.
            ``columns`` (list, default ``None``): columns to plot; when
                empty, all numeric columns are used.
            ``hue`` (str, default ``""``): column used to colour the points.
            ``palette`` (str, default ``"husl"``): palette passed to seaborn
                when a hue column is set.
            ``diag_kind`` (str, default ``"auto"``): diagonal plot type.
            ``corner`` (bool, default ``False``): draw only the lower
                triangle.

    Returns:
        The rendered PNG as a NumPy array, a PIL image, raw bytes, or a
        ``BytesIO`` positioned at the start, depending on ``output_type``.

    Raises:
        ValueError: If ``output_type`` is invalid, the data type is
            unsupported, the data is empty, a requested column or the hue
            column is missing, or no numeric column is available.
        RuntimeError: If the Matplotlib figure cannot be retrieved from the
            seaborn grid object.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    hue = options.get("hue", "")
    palette = options.get("palette", "husl")
    diag_kind = options.get("diag_kind", "auto")
    corner = options.get("corner", False)
    image = None

    def _log(message):
        # Informational logging is opt-out through the "verbose" option.
        if verbose:
            logger.info(message)

    _log(f"Starting pairplot generation with output type: '{output_type}'")
    try:
        # Reject a bad output type up front so that an expensive plot is not
        # rendered only to be thrown away.
        if output_type not in ("array", "pil", "bytes", "bytesio"):
            raise ValueError(f"Invalid output_type: '{output_type}'")

        if isinstance(data, pl.DataFrame):
            _log("Converting Polars DataFrame to Pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            _log("Converting Arrow Table to Pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            _log("Input is already Pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")

        if df.empty:
            raise ValueError("Input DataFrame is empty")
        _log(
            f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
        )

        if columns:
            missing_cols = [col for col in columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
            # Copy so that appending the hue column below never mutates the
            # caller's list.
            plot_cols = list(columns)
            _log(f"Using specified columns: {plot_cols}")
        else:
            plot_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if not plot_cols:
                raise ValueError("No numeric columns found in DataFrame")
            _log(f"Using all numeric columns: {plot_cols}")

        hue_col = None
        if hue and hue.strip():
            if hue not in df.columns:
                raise ValueError(f"Hue column '{hue}' not found in DataFrame")
            # The hue column must be part of the frame handed to seaborn.
            if hue not in plot_cols:
                plot_cols.append(hue)
            hue_col = hue
            _log(f"Using hue column: '{hue}'")

        # Pick the DPI from the widest threshold down. The larger threshold
        # has to be tested first, otherwise the ">20" branch captures every
        # large feature count and the 150-DPI tier is never reached.
        n_features = len(plot_cols)
        dpi = 300
        if n_features > 30:
            dpi = 150
            _log(
                f"Feature count is {n_features} (>30). Reducing DPI to {dpi} for performance."
            )
        elif n_features > 20:
            dpi = 200
            _log(
                f"Feature count is {n_features} (>20). Reducing DPI to {dpi} to prevent memory overflow."
            )
        else:
            _log(f"Feature count is {n_features}. Using standard high DPI ({dpi}).")

        pairplot_obj = sns.pairplot(
            df[plot_cols],
            hue=hue_col,
            palette=palette if hue_col else None,
            diag_kind=diag_kind,
            corner=corner,
        )
        _log("Rendering plot to buffer...")

        # Try the "figure" attribute first and fall back to "fig".
        fig = getattr(pairplot_obj, "figure", getattr(pairplot_obj, "fig", None))
        if fig is None:
            raise RuntimeError(
                "Could not retrieve Matplotlib Figure from Seaborn PairGrid object."
            )

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        plt.close(fig)

        # Remove Pillow's pixel-count limit so large renders can be opened.
        # NOTE: this is a process-wide setting and stays in effect afterwards.
        Image.MAX_IMAGE_PIXELS = None

        if output_type == "bytesio":
            image = buf
        elif output_type == "bytes":
            image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            # The buffer is left open because the returned image is backed
            # by it.
            image = Image.open(buf)
        else:  # "array" (validated at the top of the try block)
            pil_img = Image.open(buf)
            image = np.array(pil_img)
            buf.close()

        _log(f"Successfully generated pairplot as {output_type}")
    except Exception as e:
        if verbose:
            logger.error(f"Failed to generate pairplot: {e}")
        # Release any figure left open by the failed run before re-raising.
        plt.close("all")
        raise
    return image
Brick Info
- matplotlib
- pandas
- pyarrow
- polars[pyarrow]
- numpy
- seaborn
- pillow