Dim. Reduction UMAP
Visualize high-dimensional data using Uniform Manifold Approximation and Projection.
Processing
This brick simplifies complex, high-dimensional datasets into a 2D or 3D representation using the Uniform Manifold Approximation and Projection (UMAP) algorithm.
In simple terms, it takes data with many characteristics (columns) and "flattens" it into a visual map. It is designed to keep similar items close together and dissimilar items far apart. This is widely used for:
- Clustering: Seeing natural groups in data (e.g., customer segments, biological species).
- Anomaly Detection: Spotting points that don't fit into any group.
- Visualization: Making datasets with dozens of columns understandable in a simple X/Y scatter plot.
The brick handles data scaling, dimensionality reduction, and automatic visualization generation.
Inputs
- data
- The dataset you want to analyze. This should contain the numeric features (columns) you want to use for clustering or visualization.
Input Types
| Input | Types |
|---|---|
| data | DataFrame, ArrowTable |
You can check the list of supported types here: Available Type Hints.
Outputs
- UMAP
- The trained UMAP model object. This contains the mathematical rules learned from your data and can be used in custom Python scripts to transform new data later.
- Scaler
- The scaling model (a scikit-learn StandardScaler) used to standardize the data before processing. Use it to transform new data in exactly the same way. This output is empty (None) when Scale Data is disabled.
- UMAP_Image
- A visualization of the data. This is a scatter plot showing the data points projected onto the selected dimensions (e.g., Dimension 1 vs. Dimension 2).
- UMAP_Projections
- The processed data containing the new coordinates. This is a DataFrame where each row corresponds to the original data, but with new columns representing the reduced dimensions.
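Because UMAP and Scaler are returned as fitted objects, new rows can be projected into an existing map. The sketch below assumes both objects follow the usual `transform` interface (scikit-learn's `StandardScaler.transform` and umap-learn's `UMAP.transform`). It also assumes you pass the same feature columns, in the same order, that the brick used. `project_new_rows` is a hypothetical helper, not part of the brick.

```python
from typing import Any, Optional, Sequence

import numpy as np
import pandas as pd


def project_new_rows(
    umap_model: Any,
    scaler: Optional[Any],
    new_df: pd.DataFrame,
    feature_cols: Sequence[str],
) -> pd.DataFrame:
    """Hypothetical helper: project unseen rows with the brick's fitted outputs.

    Assumes `feature_cols` matches the columns (and order) used at fit time.
    """
    X = new_df[list(feature_cols)]
    # Apply the same standardization as during training; the Scaler output
    # is None when Scale Data was disabled.
    X_in = scaler.transform(X) if scaler is not None else X
    coords = np.asarray(umap_model.transform(X_in))
    cols = [f"UMAP{i + 1}" for i in range(coords.shape[1])]
    # Keep the new rows' index so results can be matched back to them.
    return pd.DataFrame(coords, columns=cols, index=new_df.index)
```

For example, `project_new_rows(UMAP, Scaler, new_df, feature_cols)` would return UMAP1, UMAP2, ... columns aligned with `new_df`'s index. Scaling must come before the UMAP step, because the model was fitted on standardized values.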
The UMAP_Projections output contains the following specific data fields:
- {ID Column}: If an ID column was specified in the options, it appears here to help identify rows.
- UMAP1: The coordinate for the first reduced dimension.
- UMAP2: The coordinate for the second reduced dimension.
- UMAP3: (If "Number of Components" is set to 3) The coordinate for the third dimension.
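In the brick's source, UMAP_Projections keeps the row index of the input after conversion to pandas. This means the coordinates can be joined back onto the original rows. The pandas sketch below shows one way to do this; `attach_projections` is a hypothetical helper. Rows that the brick dropped because of missing values end up with NaN coordinates.

```python
import pandas as pd


def attach_projections(
    original: pd.DataFrame, projections: pd.DataFrame, id_column: str = ""
) -> pd.DataFrame:
    """Hypothetical helper: left-join UMAP coordinates onto the original rows."""
    # The ID column already exists in `original`; drop it to avoid a clash.
    coords = projections.drop(columns=[id_column]) if id_column else projections
    # Index-based left join: rows missing from `projections` get NaN.
    return original.join(coords, how="left")
```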
Output Types
| Output | Types |
|---|---|
| UMAP | Any |
| Scaler | Any |
| UMAP_Image | MediaData, PILImage |
| UMAP_Projections | DataFrame |
You can check the list of supported types here: Available Type Hints.
Options
The Dim. Reduction UMAP brick exposes the following configurable options:
- Columns for UMAP
- Select which columns to use for the calculation. If left empty, the brick automatically selects all numeric and boolean columns.
- Number of Components
- The number of dimensions to reduce the data down to (default: 2). Use 2 or 3 for visualization.
- Component for X-axis
- Selects which of the calculated dimensions to plot on the horizontal (X) axis of the output image. Usually set to 1.
- Component for Y-axis
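- Selects which of the calculated dimensions to plot on the vertical (Y) axis of the output image. Usually set to 2. For both axes, a value larger than Number of Components is clamped to the highest available component.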
- Number of Neighbors
- Controls how the algorithm balances local detail vs. global structure (default: 15). If the dataset has fewer rows than this value, the brick lowers it automatically.
- Low values (e.g., 5-10): Focuses on local structure. Good for finding small clusters, but might lose the big picture.
- High values (e.g., 50-100): Focuses on global structure. Good for seeing the overall shape of the data, but might obscure fine details.
- Minimum Distance
- Controls how tightly the points are allowed to pack together in the output (default: 0.1).
- Low values (e.g., 0.1): Points clump tightly together. Good for distinct clustering.
- High values (e.g., 0.5): Points are more evenly distributed. Good for preserving the topological structure.
- Distance Metric
- The mathematical rule used to calculate the "distance" between two data points.
- Euclidean: The standard "straight line" distance. Works well for most general data.
- Cosine: Measures the angle between vectors. Excellent for text data or word embeddings.
- Manhattan: (Taxicab geometry) Sums the absolute differences along each axis.
- Correlation, Chebyshev, Canberra, Braycurtis: Specialized metrics for specific statistical use cases.
- Scale Data
- If enabled, each column is standardized with scikit-learn's StandardScaler (zero mean, unit variance) before processing. This is highly recommended so that columns with large values (e.g., "Salary") don't dominate columns with small values (e.g., "Age").
- ID Column
- The name of a column in your input data that acts as a unique identifier (e.g., "Product_ID", "Email"). This column is excluded from the calculations but added back to the UMAP_Projections output so you can identify your rows.
- Hue Column (for scatter)
- The name of a column to use for coloring the dots in the output image. For example, if you set this to "Species", points will be colored based on their species.
- Exclude Hue
- Determines if the column used for coloring (the Hue) should be part of the mathematical calculation. If the hue column is not numerical or boolean, it is automatically excluded.
- True (Active): The Hue column is excluded from the dimensionality reduction. The algorithm ignores this data when calculating positions/clusters, using it only to assign colors in the final result. This is best when the Hue represents a label, outcome, or category (e.g., "Customer Type") that you want to visualize but don't want influencing the clustering logic.
- False (Inactive): The Hue column is included in the dimensionality reduction. The values in this column will actively influence where the data points are positioned on the graph.
- Color Palette
- The seaborn color palette used to color the points when a Hue Column is set: husl (default), deep, muted, bright, pastel, dark, or colorblind.
- Output Type
- Determines the format of the UMAP_Image output (a PNG rendered at 300 DPI).
- array (default): Returns the image as a NumPy array of pixel values (standard for image processing).
- pil: Returns a PIL Image object.
- bytes: Returns the raw PNG bytes (useful for saving directly to disk).
- bytesio: Returns a BytesIO stream containing the PNG.
- Number of Jobs
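- How many CPU cores to use for parallel processing (default: 1; choose All to use every available core). More cores speed up training but use more system resources.
- Random State
- A seed for the random number generator (default: 42). With the same seed, data, and settings, running the brick twice produces the same result. The seed is only applied when Number of Jobs is 1; with parallel processing, no seed is passed to UMAP and results can differ between runs.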
- Brick Caching
- If enabled, the results (projections, UMAP model, and scaler) are saved to a cache folder in the system's temporary directory. If you run the workflow again with the exact same data and settings, the brick loads the result from the cache instead of recalculating, which is significantly faster.
- Verbose
- If enabled, detailed logs about the progress (scaling, fitting, plotting) will be printed to the console.
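Columns for UMAP, ID Column, Hue Column, and Exclude Hue interact when the brick decides which columns feed the algorithm. The pandas sketch below mirrors the selection rules in the brick's source code. `select_feature_columns` is a hypothetical name used only for illustration.

```python
from typing import List, Optional, Sequence

import pandas as pd


def select_feature_columns(
    df: pd.DataFrame,
    columns: Optional[Sequence[str]] = None,
    id_column: str = "",
    hue: str = "",
    exclude_hue: bool = True,
) -> List[str]:
    """Hypothetical helper reproducing the brick's feature-column rules."""
    if columns:
        missing = [c for c in columns if c not in df.columns]
        if missing:
            raise ValueError(f"Columns not found: {missing}")
        feats = list(columns)
    else:
        # Default: every numeric or boolean column, in DataFrame order.
        feats = df.select_dtypes(include=["number", "bool"]).columns.tolist()
    # The ID column never takes part in the calculation.
    if id_column and id_column in feats:
        feats.remove(id_column)
    # The hue column is only removed when Exclude Hue is active.
    if exclude_hue and hue and hue in feats:
        feats.remove(hue)
    if not feats:
        raise ValueError("No feature columns left")
    return feats
```

A non-numeric hue such as "Species" is never selected by the default rule, so it only ever colors the points.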
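Scale Data relies on scikit-learn's StandardScaler, which subtracts each column's mean and divides by its population standard deviation. The NumPy sketch below reproduces that arithmetic. It uses a hypothetical helper, `standardize`, and shows how a salary column stops dominating an age column once both are expressed in standard deviations.

```python
import numpy as np


def standardize(X):
    """Hypothetical helper: z-score each column, like StandardScaler."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)  # population std (ddof=0)
    # Constant columns keep a scale of 1 instead of dividing by zero.
    std[std == 0] = 1.0
    return (X - mean) / std, mean, std


# Salary (large numbers) next to age (small numbers).
Z, mean, std = standardize([[30000, 25], [60000, 35], [90000, 45]])
```

After scaling, both columns span the same range, so neither dominates the distances UMAP computes. New data must be transformed with the stored `mean` and `std`, which is what the Scaler output is for.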
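UMAP starts from a k-nearest-neighbor graph, and Number of Neighbors is that k. The sketch below illustrates two things under simplifying assumptions. First, it shows the brick's fallback when the dataset has fewer rows than n_neighbors, mirroring `max(2, rows - 1)` in its source. Second, it runs a brute-force Euclidean neighbor search for intuition. umap-learn itself uses an approximate search (NN-descent) on larger data, so `knn_indices` and `effective_n_neighbors` are hypothetical illustrations, not its implementation.

```python
import numpy as np


def effective_n_neighbors(n_samples: int, n_neighbors: int) -> int:
    """Mirror of the brick's adjustment when there are too few rows."""
    if n_samples < n_neighbors:
        return max(2, n_samples - 1)
    return n_neighbors


def knn_indices(X: np.ndarray, k: int) -> np.ndarray:
    """Brute-force k nearest neighbors (Euclidean), excluding each point itself."""
    X = np.asarray(X, dtype=float)
    diff = X[:, None, :] - X[None, :, :]
    d = np.sqrt((diff ** 2).sum(axis=-1))
    # A point is not its own neighbor in this illustration.
    np.fill_diagonal(d, np.inf)
    return np.argsort(d, axis=1, kind="stable")[:, :k]
```

With a small k, each point only "sees" its immediate surroundings (local structure). With a large k, the graph links distant regions and the layout reflects global structure.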
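Our understanding of how umap-learn uses min_dist is as follows. It defines a target similarity curve that stays at 1 for embedding distances below min_dist and decays exponentially beyond it. The decay is controlled by umap-learn's `spread` parameter (default 1.0), which this brick does not expose. umap-learn then fits the parameters `a` and `b` of the smooth curve 1 / (1 + a·d^(2b)) to that target. The NumPy sketch below only evaluates both curves; it does not perform the fit, and the helper names are hypothetical.

```python
import numpy as np


def min_dist_target_curve(d, min_dist: float = 0.1, spread: float = 1.0) -> np.ndarray:
    """Target similarity vs. embedding distance: flat up to min_dist, then decays."""
    d = np.asarray(d, dtype=float)
    return np.where(d < min_dist, 1.0, np.exp(-(d - min_dist) / spread))


def smooth_curve(d, a: float, b: float):
    """Differentiable form whose a, b are fitted to the target curve."""
    return 1.0 / (1.0 + a * np.asarray(d, dtype=float) ** (2 * b))
```

A larger min_dist widens the flat region, so nearby points gain nothing by moving closer. This is why clusters look looser and more evenly spread.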
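The brick lowercases the selected Distance Metric name (e.g., "Braycurtis" becomes "braycurtis") and passes it to umap-learn's `metric` parameter. The NumPy sketch below computes each listed metric for two vectors, following SciPy's definitions. `pairwise_metrics` is a hypothetical helper; umap-learn uses its own compiled implementations.

```python
import numpy as np


def pairwise_metrics(u, v) -> dict:
    """Hypothetical helper: the listed distance metrics for two 1-D vectors."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    diff = u - v
    uc, vc = u - u.mean(), v - v.mean()
    denom = np.abs(u) + np.abs(v)
    return {
        "euclidean": float(np.sqrt(np.sum(diff ** 2))),
        "manhattan": float(np.sum(np.abs(diff))),
        "chebyshev": float(np.max(np.abs(diff))),
        # 1 - cosine of the angle: ignores vector length.
        "cosine": float(1 - u @ v / (np.linalg.norm(u) * np.linalg.norm(v))),
        # Cosine distance after centering each vector on its mean.
        "correlation": float(1 - uc @ vc / (np.linalg.norm(uc) * np.linalg.norm(vc))),
        # Terms where both entries are 0 contribute 0.
        "canberra": float(np.sum(np.divide(np.abs(diff), denom, out=np.zeros_like(diff), where=denom > 0))),
        "braycurtis": float(np.sum(np.abs(diff)) / np.sum(np.abs(u + v))),
    }
```

For example, `[1, 2]` and `[2, 4]` point in the same direction, so their cosine distance is 0 even though their Euclidean distance is not. This is why cosine suits embeddings, where direction matters more than magnitude.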
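Number of Jobs and Random State interact. In the brick's source, the seed is passed to UMAP only when a single job is used. To our understanding, umap-learn would otherwise override parallelism to honor the seed. The sketch below reproduces that parameter-building logic; `build_umap_params` is a hypothetical helper.

```python
from typing import Any, Dict, Union


def build_umap_params(
    n_jobs: Union[str, int] = "1", random_state: int = 42, **umap_kwargs: Any
) -> Dict[str, Any]:
    """Hypothetical helper mirroring how the brick assembles UMAP arguments."""
    # "All" means every available core, which is -1 in joblib-style APIs.
    n_jobs_int = -1 if n_jobs == "All" else int(n_jobs)
    params = dict(umap_kwargs, n_jobs=n_jobs_int)
    # Seed only the single-threaded run; parallel runs stay unseeded.
    if n_jobs_int == 1:
        params["random_state"] = random_state
    return params
```

If reproducibility matters more than speed, keep Number of Jobs at 1.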
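Brick Caching keys its cached files on a fingerprint of the data combined with the settings. The brick itself uses xxhash for the data fingerprint and SHA-256 for the combined key. To stay dependency-light, the sketch below swaps in `hashlib.sha256` for both, and `cache_key` is a hypothetical helper. Joining settings with explicit separators matters. Without them, `random_state=4, n_neighbors=215` and `random_state=42, n_neighbors=15` would concatenate to the same text.

```python
import hashlib

import pandas as pd


def cache_key(df: pd.DataFrame, feature_cols, **params) -> str:
    """Hypothetical helper: stable key for (data, settings)."""
    # One 64-bit hash per row of the feature data (row order matters,
    # the index is ignored).
    row_hashes = pd.util.hash_pandas_object(df[list(feature_cols)], index=False)
    h = hashlib.sha256(row_hashes.to_numpy().tobytes())
    # Sorted name=value pairs with separators: argument order is irrelevant
    # and values cannot run into each other.
    settings = "|".join(f"{k}={params[k]}" for k in sorted(params))
    h.update(settings.encode("utf-8"))
    return h.hexdigest()
```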
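For the bytes and bytesio output types, UMAP_Image holds a complete PNG file, because the brick renders with `savefig(format="png")`. It can therefore be written to disk as-is. The stdlib sketch below uses a hypothetical `save_png_output` helper. The array and pil types would need Pillow to re-encode and are not handled here.

```python
import io
from pathlib import Path
from typing import Union

# Every PNG file starts with this 8-byte signature.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_png_output(image: Union[bytes, io.BytesIO], path: Union[str, Path]) -> Path:
    """Hypothetical helper: write a `bytes` or `bytesio` UMAP_Image to disk."""
    raw = image.getvalue() if isinstance(image, io.BytesIO) else bytes(image)
    if not raw.startswith(PNG_SIGNATURE):
        raise ValueError("Expected PNG data")
    path = Path(path)
    path.write_bytes(raw)
    return path
```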
```python
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import joblib
from pathlib import Path

# Non-interactive backend: figures are rendered to buffers, never shown.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import umap
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Dim. Reduction UMAP", level=logging.INFO)

DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations.
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        # Dispatch on the concrete container type.
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        # The schema (columns, dtypes, shape) is always part of the hash.
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        # Sample equally sized head, middle and tail chunks.
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception:
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        # The hasher needs a C-contiguous buffer; copy only when necessary
        # (a Fortran-ordered array is not C-contiguous).
        if arr.flags["C_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        # Hash everything below 5M rows, otherwise sample.
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def _plot_umap_scatter(
    ax, x_data, y_data, x_label, y_label, hue_data, hue_col, palette
):
    """Plot UMAP scatter plot with selected axes."""
    if hue_data is not None:
        scatter_df = pd.DataFrame({x_label: x_data, y_label: y_data, hue_col: hue_data})
        sns.scatterplot(
            data=scatter_df,
            x=x_label,
            y=y_label,
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(x_data, y_data, alpha=0.7, s=50)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("UMAP Projection")
    ax.grid(True, alpha=0.3)


def umap_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Any, Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    exclude_hue = options.get("exclude_hue", True)
    palette = options.get("palette", "husl")
    n_jobs_str = options.get("n_jobs", "1")
    # "All" maps to -1, i.e. every available core.
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    n_neighbors = options.get("n_neighbors", 15)
    min_dist = options.get("min_dist", 0.1)
    metric_raw = options.get("metric", "Euclidean")
    metric = metric_raw.lower() if metric_raw else "euclidean"
    dpi = 300
    verbose and logger.info(
        f"Starting UMAP with {n_components} components. Displaying Dim{pc_x} vs Dim{pc_y}"
    )
    verbose and logger.info(
        f"Config: neighbors={n_neighbors}, min_dist={min_dist}, metric={metric}"
    )
    # Clamp the plotted axes to the valid range 1..n_components.
    if not (1 <= pc_x <= n_components and 1 <= pc_y <= n_components):
        verbose and logger.warning(
            f"Selected axes (X:{pc_x}, Y:{pc_y}) are outside 1..{n_components}. Clamping to valid range."
        )
        pc_x = min(max(pc_x, 1), n_components)
        pc_y = min(max(pc_y, 1), n_components)
    UMAP = None
    UMAP_Image = None
    UMAP_Projections = pd.DataFrame()
    Scaler = None
    plot_components = None
    hue_data = None
    hue_col = None
    df = None
    # Normalize the input to a pandas DataFrame.
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info("Converting polars DataFrame to pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info("Converting Arrow table to pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info("Input is already pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )
    if hue and hue.strip():
        if hue not in df.columns:
            raise ValueError(f"Hue column '{hue}' not found")
        hue_col = hue
    if id_column and id_column.strip() and (id_column not in df.columns):
        raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    # Select feature columns: explicit list, or every numeric/boolean column.
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if exclude_hue and hue_col and (hue_col in feature_cols):
        feature_cols.remove(hue_col)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID/hue columns")
    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info("Caching is active")
        # Fingerprint the features and, if set, the ID column, because the
        # cached projections also store the IDs.
        hash_cols = list(feature_cols)
        if id_column and id_column.strip():
            hash_cols.append(id_column)
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
        X_hash = data_hasher.hash_data(df[hash_cols]).hash
        # Separate every setting explicitly so that different parameter
        # combinations cannot produce the same key text.
        key_parts = [
            "HASH BASE TEXT UMAP",
            f"Scikit Learn Version {sklearn.__version__}",
            X_hash,
            n_components,
            scale_data,
            random_state,
            n_neighbors,
            min_dist,
            metric,
            exclude_hue,
            n_jobs_int,
            sorted(feature_cols),
            id_column,
        ]
        all_hash_base_text = "|".join(str(part) for part in key_parts)
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        umap_proj_path = cache_folder / f"umap_proj_{all_hash}.parquet"
        umap_est_path = cache_folder / f"umap_estimator_{all_hash}.joblib"
        scaler_path = cache_folder / f"umap_scaler_{all_hash}.joblib"
        # A scaler file only exists when scaling was enabled.
        scaler_ready = scaler_path.is_file() or not scale_data
        if umap_proj_path.is_file() and umap_est_path.is_file() and scaler_ready:
            verbose and logger.info("Cache hit! Loading results.")
            try:
                UMAP_Projections = pd.read_parquet(umap_proj_path)
                UMAP = joblib.load(umap_est_path)
                Scaler = joblib.load(scaler_path) if scale_data else None
                comps_for_plot_df = UMAP_Projections.copy()
                if id_column and id_column in comps_for_plot_df.columns:
                    comps_for_plot_df = comps_for_plot_df.drop(columns=[id_column])
                plot_components = comps_for_plot_df.values
                skip_computation = True
            except Exception as e:
                verbose and logger.warning(f"Cache load failed, recomputing: {e}")
                skip_computation = False
    if not skip_computation:
        X = df[feature_cols]
        # Drop rows with a missing value in any feature column.
        mask = X.notna().all(axis=1).to_numpy()
        if not mask.all():
            verbose and logger.warning("Removing rows with missing values")
            X = X[mask]
        df_indices = df.index[mask]
        if X.empty:
            raise ValueError("No rows left after removing missing values")
        if X.shape[0] < n_neighbors:
            verbose and logger.warning(
                f"Data samples ({X.shape[0]}) < n_neighbors ({n_neighbors}). Adjusting n_neighbors."
            )
            n_neighbors = max(2, int(X.shape[0] - 1))
        if scale_data:
            verbose and logger.info("Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info(f"Fitting UMAP model (metric={metric})")
        umap_params = {
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "n_components": n_components,
            "metric": metric,
            "n_jobs": n_jobs_int,
        }
        # The seed is only passed for single-threaded runs.
        if n_jobs_int == 1:
            umap_params["random_state"] = random_state
        UMAP = umap.UMAP(**umap_params)
        plot_components = UMAP.fit_transform(X_transformed)
        umap_cols = [f"UMAP{i + 1}" for i in range(n_components)]
        # Keep the original row index so results map back to input rows.
        UMAP_Projections = pd.DataFrame(
            plot_components, columns=umap_cols, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            UMAP_Projections.insert(0, id_column, id_data)
        if activate_caching and cache_folder and all_hash:
            try:
                umap_proj_path = cache_folder / f"umap_proj_{all_hash}.parquet"
                umap_est_path = cache_folder / f"umap_estimator_{all_hash}.joblib"
                scaler_path = cache_folder / f"umap_scaler_{all_hash}.joblib"
                UMAP_Projections.to_parquet(umap_proj_path)
                joblib.dump(UMAP, umap_est_path)
                if Scaler is not None:
                    joblib.dump(Scaler, scaler_path)
                verbose and logger.info("Results saved to cache")
            except Exception as e:
                verbose and logger.warning(f"Failed to save cache: {e}")
    try:
        # Axis options are 1-based; array columns are 0-based.
        idx_x = pc_x - 1
        idx_y = pc_y - 1
        x_values = plot_components[:, idx_x]
        y_values = plot_components[:, idx_y]
        x_label = f"UMAP Dimension {pc_x}"
        y_label = f"UMAP Dimension {pc_y}"
        if hue_col:
            hue_data = df.loc[UMAP_Projections.index, hue_col].values
            verbose and logger.info(f"Using hue column: '{hue_col}'")
        verbose and logger.info(
            f"Creating scatter visualization ({x_label} vs {y_label})"
        )
        (fig, ax) = plt.subplots(figsize=(10, 7))
        _plot_umap_scatter(
            ax, x_values, y_values, x_label, y_label, hue_data, hue_col, palette
        )
        verbose and logger.info(f"Rendering to {output_type}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            UMAP_Image = buf
        elif output_type == "bytes":
            UMAP_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            UMAP_Image = Image.open(buf)
            UMAP_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            UMAP_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except Exception as e:
        verbose and logger.error(f"Error during plotting: {e}")
        plt.close("all")
        raise RuntimeError(f"Plotting failed: {e}") from e
    verbose and logger.info("Successfully completed UMAP analysis")
    return (UMAP, Scaler, UMAP_Image, UMAP_Projections)
```
Brick Info
- shap>=0.47.0
- scikit-learn
- pandas
- pyarrow
- polars[pyarrow]
- numpy
- umap-learn
- numba>=0.56.0
- joblib
- matplotlib
- seaborn
- pillow
- xxhash