PCA Analysis
Perform Principal Component Analysis and generate a visualization.
This brick performs Principal Component Analysis (PCA) to reduce the complexity of your dataset while retaining the most important patterns. It mathematically transforms a table with many numeric columns into a streamlined version with fewer columns (called Principal Components).
Practically, this is used to:
- Visualize high-dimensional data: Turn a table with 10+ columns into a 2D scatter plot to spot clusters or outliers.
- Simplify data for machine learning: Reduce noise and the number of input features by keeping the directions with the highest variance, which can speed up model training.
The brick can standardize the data before fitting (enabled by default) and returns both the transformed data and a visualization.
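The two core steps can be sketched directly with scikit-learn. This is a minimal illustration, not the brick's own code: the DataFrame and its `age`, `salary`, and `tenure` columns are made-up example data, and the brick wraps these steps with input conversion, column selection, missing-value handling, caching, and plotting.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Made-up example data: three numeric columns on very different scales.
rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "age": rng.normal(40, 10, 200),
        "salary": rng.normal(50_000, 15_000, 200),
        "tenure": rng.normal(5, 2, 200),
    }
)

# Step 1: standardize every column to mean 0 and variance 1.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(df)

# Step 2: project the rows onto the first two principal components.
pca = PCA(n_components=2, random_state=42)
components = pca.fit_transform(X_scaled)

print(components.shape)  # (200, 2)
print(pca.explained_variance_ratio_)
```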
Inputs
- data
- The dataset containing the numeric values you want to analyze. This can be a Pandas, Polars, or Arrow table. Rows with a missing value in any analyzed column are dropped before fitting.
Input Types
| Input | Types |
|---|---|
| data | DataFrame, ArrowTable |
You can check the list of supported types here: Available Type Hints.
Outputs
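- PCA Model
- The fitted scikit-learn PCA object. It can project new data (after applying the Scaler, if scaling was on) onto the same components with its transform method.
- Scaler
- The StandardScaler fitted on the analyzed columns, or None when Scale Data is off. It can be reused to apply the same scaling to new data or to inverse-transform scaled values.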
- PCA Image
- The plot selected by Plot Type, rendered as a PNG at 300 DPI and returned in the format chosen by Output Type.
- PCA Components
- The transformed dataset: one row per analyzed row, with one column per Principal Component (PC1, PC2, ...). If an ID column was provided, it is included as the first column. Rows dropped because of missing values do not appear.
- PCA Summary
- A statistical summary with one row per component: its eigenvalue, the fraction of variance it explains, the cumulative variance, and the loading of each original feature on that component.
- PCA Loadings
- One row per original feature (in a Feature column) with its coefficient on each Principal Component, sorted by the absolute value of the PC1 coefficient, largest first.
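The ordering of the PCA Loadings table can be reproduced as follows. This is a sketch assuming a small synthetic dataset (hypothetical `height`, `weight`, and `noise` columns) in which the first two columns are strongly correlated; the transpose turns scikit-learn's `components_` matrix (one row per component) into one row per feature.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Made-up data: height and weight share a common signal, noise is independent.
rng = np.random.default_rng(1)
base = rng.normal(size=300)
df = pd.DataFrame(
    {
        "height": base + rng.normal(scale=0.3, size=300),
        "weight": base + rng.normal(scale=0.3, size=300),
        "noise": rng.normal(size=300),
    }
)

pca = PCA(n_components=2).fit(StandardScaler().fit_transform(df))
pc_columns = [f"PC{i + 1}" for i in range(pca.n_components_)]

# components_ is (n_components, n_features); transpose to get one row per feature.
loadings = pd.DataFrame(pca.components_.T, columns=pc_columns)
loadings.insert(0, "Feature", df.columns.tolist())

# Order rows by the absolute PC1 coefficient, largest first.
order = loadings["PC1"].abs().sort_values(ascending=False).index
loadings = loadings.loc[order].reset_index(drop=True)
print(loadings)
```

Because `height` and `weight` move together, they dominate PC1 and sort to the top, while `noise` ends up last.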
The PCA Components output contains the following specific data fields:
- {ID Column}: The identifier column (if provided in options).
- PC1: The value of the first Principal Component for this row.
- PC2: The value of the second Principal Component.
- ...: Additional columns up to the selected "Number of Components".
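The sketch below shows how an ID column can travel alongside the component scores, assuming a toy table with a hypothetical `Customer_ID` column and three features, one of which has a missing value. Keeping the original row index on the scores is what lets the identifier be looked up for exactly the rows that were fitted.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Made-up table with a hypothetical identifier column and one missing value.
df = pd.DataFrame(
    {
        "Customer_ID": ["a", "b", "c", "d", "e"],
        "x1": [1.0, 2.0, np.nan, 4.0, 5.0],
        "x2": [2.0, 1.0, 0.0, 3.0, 5.0],
        "x3": [0.5, 0.1, 0.2, 0.9, 0.4],
    }
)
id_column = "Customer_ID"
feature_cols = ["x1", "x2", "x3"]

# Keep only rows with no missing value in any feature column.
complete = df[feature_cols].notna().all(axis=1)
X_raw = df.loc[complete, feature_cols]
X = StandardScaler().fit_transform(X_raw)

scores = pd.DataFrame(
    PCA(n_components=2).fit_transform(X),
    columns=["PC1", "PC2"],
    index=X_raw.index,  # keep the original row labels
)

# Put the identifier first so each point can be traced back to its row.
scores.insert(0, id_column, df.loc[scores.index, id_column].to_numpy())
print(scores)
```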
The PCA Summary output contains the following specific data fields:
- Component: The label of the component (e.g., "PC1").
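- Eigenvalue: The variance captured by this component (the eigenvalue of the covariance matrix of the analyzed, possibly scaled, data).
- Variance_Explained: The fraction of the dataset's total variance captured by this component (0.0 to 1.0).
- Cumulative_Variance: The running total of Variance_Explained up to and including this component.
- Loading_{Feature}: One column per analyzed feature, holding that feature's coefficient on this component.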
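The summary columns map onto attributes of a fitted scikit-learn `PCA` object. In this sketch on random, standardized data (used only for illustration), `explained_variance_` supplies the eigenvalues and `explained_variance_ratio_` supplies the fractions; each fraction equals the eigenvalue divided by the total variance of the scaled data.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Made-up data, standardized as the brick does by default.
rng = np.random.default_rng(2)
X = StandardScaler().fit_transform(rng.normal(size=(100, 4)))

pca = PCA(n_components=3).fit(X)
summary = pd.DataFrame(
    {
        "Component": [f"PC{i + 1}" for i in range(pca.n_components_)],
        "Eigenvalue": pca.explained_variance_,
        "Variance_Explained": pca.explained_variance_ratio_,
    }
)
summary["Cumulative_Variance"] = summary["Variance_Explained"].cumsum()

# Ratio = eigenvalue / total variance; scikit-learn uses the n - 1 denominator.
total_variance = np.var(X, axis=0, ddof=1).sum()
print(summary)
print(np.allclose(summary["Eigenvalue"] / total_variance, summary["Variance_Explained"]))
```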
Output Types
| Output | Types |
|---|---|
| PCA Model | Any |
| Scaler | Any |
| PCA Image | MediaData, PILImage |
| PCA Components | DataFrame |
| PCA Summary | DataFrame |
| PCA Loadings | DataFrame |
You can check the list of supported types here: Available Type Hints.
Options
The PCA Analysis brick exposes the following options:
- Columns for PCA
- Select specific numeric columns to analyze. If left empty, the brick uses all numeric columns in the dataset. The ID column is always excluded from the analysis.
- Number of Components
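- The number of principal components to keep (e.g., 2 or 3). Two components are enough for a 2D plot. The value cannot exceed the number of analyzed columns or rows, and it must be at least as large as the PC numbers chosen for the X and Y axes.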
- Scale Data
- Determines if the data should be standardized (mean=0, variance=1) before analysis.
- True (On): Highly recommended. Ensures that columns with large values (e.g., Salary) don't dominate columns with small values (e.g., Age) simply because of their scale.
- False (Off): Uses raw values. Only use this if your columns are already on comparable scales.
- ID Column
- The name of a column (e.g., "Customer_ID" or "Product_Name") that uniquely identifies each row. This column is excluded from the mathematical calculation but is added back to the results so you can identify which point belongs to which item.
- Hue Column (for scatter)
- The name of a column to use for coloring the points in the scatter plot and the biplot. Useful for checking whether specific groups (e.g., "Category" or "Status") cluster together.
- PC for X-axis
- Which Principal Component to plot on the horizontal axis (usually 1).
- PC for Y-axis
- Which Principal Component to plot on the vertical axis (usually 2).
- Plot Type
- The style of visualization to generate.
- scatter: A standard 2D plot showing the data points mapped to the selected components.
- biplot: A scatter plot that also includes arrows (vectors) indicating how the original features influence the axes.
- variance: A bar chart of the variance ratio explained by each component, with a line showing the cumulative total.
- correlation_circle: A unit circle plot in which each original feature is drawn as a vector from the origin, using its loadings on the two selected components as coordinates.
- all: Generates a single 2×2 figure containing the scatter plot, variance chart, biplot, and correlation circle.
- Color Palette
- The seaborn palette used to color points by the Hue column (default: husl).
- Output Type
- The format of the returned image: array (NumPy array, the default), pil (PIL image), bytes (raw PNG bytes), or bytesio (an in-memory PNG stream).
- Random State
- A seed passed to scikit-learn's PCA so that results are reproducible when a randomized solver is used.
- Brick Caching
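- When enabled, the fitted model, scaler, and output tables are saved to a coded-flows-cache folder in the system temporary directory and reused on later runs with the same data, columns, and analysis settings. The plot itself is always re-rendered.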
- Verbose
- If enabled, detailed logs about the process will be printed to the console.
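The Scale Data recommendation is easy to check. This sketch uses made-up, independent Age and Salary columns; without standardization, PC1 simply follows the column with the largest numeric spread.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Made-up, independent columns with very different spreads.
rng = np.random.default_rng(3)
age = rng.normal(40, 10, 500)
salary = rng.normal(50_000, 15_000, 500)
X = np.column_stack([age, salary])

raw = PCA(n_components=2).fit(X)
scaled = PCA(n_components=2).fit(StandardScaler().fit_transform(X))

# Unscaled: PC1 is essentially the salary axis and holds almost all the variance.
print(raw.explained_variance_ratio_)
print(np.abs(raw.components_[0]))
# Scaled: both columns count equally, so PC1 holds only about half the variance.
print(scaled.explained_variance_ratio_)
```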
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import scipy
import joblib
from pathlib import Path

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import seaborn as sns
from PIL import Image
from sklearn.decomposition import PCA as _PCA
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Dim. Reduction PCA", level=logging.INFO)

DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str
class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """Optimized Polars hashing using native Rust execution."""
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed. Falling back."
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """Optimized NumPy hashing using Buffer Protocol (Zero-Copy)."""
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """Optimized sparse hashing."""
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))
def _plot_scatter(
    ax, components, hue_data, hue_col, palette, explained_var, pc_x, pc_y
):
    """Plot PCA scatter plot."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.7, s=50)
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Scatter Plot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_biplot(
    ax,
    components,
    loadings,
    feature_names,
    explained_var,
    hue_data=None,
    hue_col=None,
    palette=None,
    pc_x=1,
    pc_y=2,
):
    """Plot PCA biplot with loadings."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.5,
            s=30,
            legend=True,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.5, s=30)
    scale = components[:, [pc_x_idx, pc_y_idx]].max() * 0.8
    for i, feature in enumerate(feature_names):
        ax.arrow(
            0,
            0,
            loadings[pc_x_idx, i] * scale,
            loadings[pc_y_idx, i] * scale,
            head_width=0.05,
            head_length=0.05,
            fc="red",
            ec="red",
            alpha=0.6,
        )
        ax.text(
            loadings[pc_x_idx, i] * scale * 1.1,
            loadings[pc_y_idx, i] * scale * 1.1,
            feature,
            fontsize=7,
            ha="center",
            va="center",
        )
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Biplot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_correlation_circle(
    ax, loadings, feature_names, explained_var, pc_x=1, pc_y=2, text_size=7
):
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    x_loadings = loadings[pc_x_idx, :]
    y_loadings = loadings[pc_y_idx, :]
    circle_1 = plt.Circle(
        (0, 0), 1, color="#333333", fill=False, linestyle="-", linewidth=1.2, alpha=0.6
    )
    circle_05 = plt.Circle(
        (0, 0), 0.5, color="gray", fill=False, linestyle=":", linewidth=0.8, alpha=0.5
    )
    ax.add_artist(circle_1)
    ax.add_artist(circle_05)
    ax.axhline(0, color="black", linewidth=0.8, alpha=0.5)
    ax.axvline(0, color="black", linewidth=0.8, alpha=0.5)
    ticks = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.4, color="gray", zorder=0)
    ax.tick_params(axis="both", which="major", labelsize=8, colors="#666666")
    color_code = "#E74C3C"
    for i, (feature, x, y) in enumerate(zip(feature_names, x_loadings, y_loadings)):
        ax.plot([0, x], [0, y], color=color_code, linewidth=0.8, alpha=0.9, zorder=5)
        ax.scatter(x, y, color=color_code, s=4, zorder=6)
        text_x = x * 1.15
        text_y = y * 1.15
        ha = "center"
        if x > 0.1:
            ha = "left"
        elif x < -0.1:
            ha = "right"
        va = "center"
        if y > 0.1:
            va = "bottom"
        elif y < -0.1:
            va = "top"
        ax.text(
            text_x,
            text_y,
            feature,
            color="#333333",
            ha=ha,
            va=va,
            fontsize=text_size,
            fontweight="normal",
            zorder=7,
        )
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect("equal")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.set_xlabel(
        f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} var)",
        fontsize=12,
        fontweight="bold",
        color="#444444",
    )
    ax.set_ylabel(
        f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} var)",
        fontsize=12,
        fontweight="bold",
        color="#444444",
    )
    ax.set_title(
        f"Correlation Circle (PC{pc_x} vs PC{pc_y})",
        fontsize=14,
        pad=15,
        color="#333333",
        fontweight="bold",
    )


def _plot_variance(ax, explained_var):
    """Plot explained variance."""
    n_components = len(explained_var)
    cumulative_var = np.cumsum(explained_var)
    x = np.arange(1, n_components + 1)
    ax.bar(x, explained_var, alpha=0.6, label="Individual")
    ax.plot(x, cumulative_var, "ro-", linewidth=2, label="Cumulative")
    ax.set_xlabel("Principal Component")
    ax.set_ylabel("Explained Variance Ratio")
    ax.set_title("Explained Variance by Component")
    ax.set_xticks(x)
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
def pca_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Any, Union[MediaData, PILImage], DataFrame, DataFrame, DataFrame]:
    """
    Run PCA on the numeric columns of `data` and render a plot.

    Returns (PCA model, Scaler, PCA image, PCA components, PCA summary, PCA loadings).
    """
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    plot_type = options.get("plot_type", "scatter")
    palette = options.get("palette", "husl")
    dpi = 300
    verbose and logger.info(f"Starting PCA with {n_components} components")

    PCA = None
    PCA_Image = None
    PCA_Components = pd.DataFrame()
    PCA_Summary = pd.DataFrame()
    PCA_Loadings = pd.DataFrame()
    Scaler = None
    plot_components = None
    plot_loadings = None
    plot_explained_var = None
    hue_data = None
    hue_col = None
    df = None

    # Normalize every supported input type to a pandas DataFrame.
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info("Converting polars DataFrame to pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info("Converting Arrow table to pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info("Input is already pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e

    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )

    # Validate the ID column and resolve the feature columns.
    if id_column and id_column.strip():
        if id_column not in df.columns:
            raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=["number"]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID column")

    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info("Caching is active")
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
        X_hash = data_hasher.hash_data(df[feature_cols]).hash
        # The key covers library versions, the data, and every option stored in the
        # cached results, including the ID column (it is a column of PCA_Components).
        all_hash_base_text = (
            "HASH BASE TEXT PCA"
            f"Pandas Version {pd.__version__}"
            f"POLARS Version {pl.__version__}"
            f"Numpy Version {np.__version__}"
            f"Scikit Learn Version {sklearn.__version__}"
            f"Scipy Version {scipy.__version__}"
            f"{X_hash}{n_components}{scale_data}{random_state}"
            f"{sorted(feature_cols)}{id_column}"
        )
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        pca_components_path = cache_folder / f"pca_components_{all_hash}.parquet"
        pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
        pca_loadings_path = cache_folder / f"pca_loadings_{all_hash}.parquet"
        scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
        pca_model_path = cache_folder / f"pca_model_{all_hash}.joblib"
        if (
            pca_components_path.is_file()
            and pca_summary_path.is_file()
            and pca_loadings_path.is_file()
            and scaler_path.is_file()
            and pca_model_path.is_file()
        ):
            verbose and logger.info("Cache hit! Loading results.")
            try:
                PCA_Components = pd.read_parquet(pca_components_path)
                PCA_Summary = pd.read_parquet(pca_summary_path)
                PCA_Loadings = pd.read_parquet(pca_loadings_path)
                Scaler = joblib.load(scaler_path)
                PCA = joblib.load(pca_model_path)
                # Rebuild the arrays the plotting helpers expect.
                plot_explained_var = PCA_Summary["Variance_Explained"].values
                loading_cols = [
                    c for c in PCA_Summary.columns if c.startswith("Loading_")
                ]
                plot_loadings = PCA_Summary[loading_cols].values
                comps_for_plot = PCA_Components.copy()
                if id_column and id_column in comps_for_plot.columns:
                    comps_for_plot = comps_for_plot.drop(columns=[id_column])
                plot_components = comps_for_plot.values
                skip_computation = True
            except Exception as e:
                # An unreadable cache entry falls through to a fresh fit.
                verbose and logger.warning(f"Cache load failed, recomputing: {e}")
                Scaler = None
                skip_computation = False

    if not skip_computation:
        X = df[feature_cols]
        # Drop rows that have a missing value in any feature column.
        missing_mask = X.isna().any(axis=1)
        if missing_mask.any():
            verbose and logger.warning("Removing rows with missing values")
            X = X.loc[~missing_mask]
        df_indices = X.index
        if X.shape[0] == 0:
            raise ValueError("No valid rows after removing missing values")
        if scale_data:
            verbose and logger.info("Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info("Fitting PCA model")
        PCA = _PCA(n_components=n_components, random_state=random_state)
        plot_components = PCA.fit_transform(X_transformed)
        explained_var = PCA.explained_variance_ratio_
        eigenvalues = PCA.explained_variance_
        cumulative_var = np.cumsum(explained_var)
        plot_explained_var = explained_var
        plot_loadings = PCA.components_
        pc_columns = [f"PC{i + 1}" for i in range(n_components)]
        PCA_Components = pd.DataFrame(
            plot_components, columns=pc_columns, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            PCA_Components.insert(0, id_column, id_data)
        summary_data = {
            "Component": pc_columns,
            "Eigenvalue": eigenvalues,
            "Variance_Explained": explained_var,
            "Cumulative_Variance": cumulative_var,
        }
        for idx, feature in enumerate(feature_cols):
            summary_data[f"Loading_{feature}"] = PCA.components_[:, idx]
        PCA_Summary = pd.DataFrame(summary_data)
        # One row per feature, sorted by absolute contribution to PC1.
        PCA_Loadings = pd.DataFrame(PCA.components_.T, columns=pc_columns)
        PCA_Loadings.insert(0, "Feature", feature_cols)
        PCA_Loadings = PCA_Loadings.loc[
            PCA_Loadings["PC1"].abs().sort_values(ascending=False).index
        ]
        PCA_Loadings = PCA_Loadings.reset_index(drop=True)
        if activate_caching and cache_folder and all_hash:
            try:
                pca_components_path = (
                    cache_folder / f"pca_components_{all_hash}.parquet"
                )
                pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
                pca_loadings_path = cache_folder / f"pca_loadings_{all_hash}.parquet"
                scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
                pca_model_path = cache_folder / f"pca_model_{all_hash}.joblib"
                PCA_Components.to_parquet(pca_components_path)
                PCA_Summary.to_parquet(pca_summary_path)
                PCA_Loadings.to_parquet(pca_loadings_path)
                # Always write the scaler file (it holds None when scaling is off):
                # the cache-hit check requires all five files to exist.
                joblib.dump(Scaler, scaler_path)
                joblib.dump(PCA, pca_model_path)
                verbose and logger.info("Results saved to cache")
            except Exception as e:
                verbose and logger.warning(f"Failed to save cache: {e}")

    # Render the requested plot and package it according to output_type.
    try:
        if hue and hue.strip():
            if hue not in df.columns:
                raise ValueError(f"Hue column '{hue}' not found in DataFrame")
            hue_col = hue
            hue_data = df.loc[PCA_Components.index, hue].values
            verbose and logger.info(f"Using hue column: '{hue}'")
        verbose and logger.info(f"Creating {plot_type} visualization")
        if plot_components is None or plot_explained_var is None:
            raise RuntimeError("PCA plot data is missing (logic error)")
        if plot_type == "all":
            (fig, axes) = plt.subplots(2, 2, figsize=(16, 14))
            ax_list = axes.flatten()
            _plot_scatter(
                ax_list[0],
                plot_components,
                hue_data,
                hue_col,
                palette,
                plot_explained_var,
                pc_x,
                pc_y,
            )
            _plot_variance(ax_list[1], plot_explained_var)
            _plot_biplot(
                ax_list[2],
                plot_components,
                plot_loadings,
                feature_cols,
                plot_explained_var,
                hue_data,
                hue_col,
                palette,
                pc_x,
                pc_y,
            )
            _plot_correlation_circle(
                ax_list[3], plot_loadings, feature_cols, plot_explained_var, pc_x, pc_y
            )
            plt.tight_layout()
        else:
            (fig, ax) = plt.subplots(figsize=(10, 8))
            if plot_type == "scatter":
                _plot_scatter(
                    ax,
                    plot_components,
                    hue_data,
                    hue_col,
                    palette,
                    plot_explained_var,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "biplot":
                _plot_biplot(
                    ax,
                    plot_components,
                    plot_loadings,
                    feature_cols,
                    plot_explained_var,
                    hue_data,
                    hue_col,
                    palette,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "variance":
                _plot_variance(ax, plot_explained_var)
            elif plot_type == "correlation_circle":
                _plot_correlation_circle(
                    ax, plot_loadings, feature_cols, plot_explained_var, pc_x, pc_y
                )
            else:
                raise ValueError(f"Invalid plot_type: '{plot_type}'")
        verbose and logger.info(f"Rendering to {output_type} format with DPI={dpi}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            PCA_Image = buf
        elif output_type == "bytes":
            PCA_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            PCA_Image = Image.open(buf)
            PCA_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            PCA_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except (ValueError, RuntimeError):
        plt.close("all")
        raise
    except Exception as e:
        error_msg = f"Failed to perform PCA: {e}"
        verbose and logger.error(f"{error_msg}")
        plt.close("all")
        raise RuntimeError(error_msg) from e

    if PCA_Image is None:
        raise RuntimeError("PCA analysis returned empty result")
    verbose and logger.info("Successfully completed PCA analysis")
    return (PCA, Scaler, PCA_Image, PCA_Components, PCA_Summary, PCA_Loadings)
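The last step of `pca_analysis` renders the figure to an in-memory PNG and only then packages it according to Output Type. The sketch below reproduces that packaging outside the brick, assuming a small placeholder figure instead of a real PCA plot; it shows that the bytes, PIL, and array outputs carry the same PNG image in different containers.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend, as in the brick
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Render a small placeholder figure to PNG bytes in memory.
fig, ax = plt.subplots(figsize=(2, 2))
ax.plot([0, 1], [0, 1])
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100)
plt.close(fig)
png_bytes = buf.getvalue()  # what output_type="bytes" would hold

# Decode the bytes into a PIL image ("pil") and a NumPy array ("array").
pil_image = Image.open(io.BytesIO(png_bytes))
pil_image.load()
pixels = np.array(pil_image)
print(pil_image.size, pixels.shape)
```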
Brick Info
- scikit-learn
- pandas
- pyarrow
- polars[pyarrow]
- numpy
- joblib
- matplotlib
- seaborn
- pillow
- xxhash