Dim. Reduction T-SNE
Visualize high-dimensional data using t-Distributed Stochastic Neighbor Embedding.
Dim. Reduction T-SNE
Processing
This brick reduces high-dimensional data into 2 or 3 dimensions using t-SNE (t-Distributed Stochastic Neighbor Embedding). It helps visualize complex datasets by grouping similar items together in a way that reveals patterns or clusters that might not be visible in the raw data.
The brick processes numeric columns, optionally scales them to a standard range, calculates the projection, and generates two main results: a scatter plot visualization of the data and a dataset containing the new coordinate values.
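For orientation, the sketch below shows the same idea using scikit-learn directly. It is a minimal illustration with made-up column names and values, not the brick's own code:

```python
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Illustrative dataset: three numeric measurements per item
df = pd.DataFrame({
    "height": [1.2, 1.4, 3.1, 3.0, 0.9],
    "width":  [0.8, 0.9, 2.5, 2.7, 0.7],
    "weight": [10.0, 11.5, 40.2, 39.8, 9.1],
})

# Standardize so no single column dominates, then project to 2 dimensions
scaled = StandardScaler().fit_transform(df)
coords = TSNE(n_components=2, perplexity=3, random_state=42).fit_transform(scaled)

projections = pd.DataFrame(coords, columns=["TSNE1", "TSNE2"])
print(projections)
```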
Inputs
- data
- The dataset containing the information you want to analyze. This must contain numeric columns (e.g., measurements, scores, counts) to be processed.
Inputs Types
| Input | Types |
|---|---|
| data | DataFrame, ArrowTable |
You can check the list of supported types here: Available Type Hints.
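For example, either of the following objects (with illustrative column names) would be a valid `data` input:

```python
import pandas as pd
import pyarrow as pa

# A pandas DataFrame with numeric columns...
pdf = pd.DataFrame({"age": [23, 35, 41, 29], "salary": [42000, 58000, 61000, 39000]})

# ...or the equivalent Arrow table
table = pa.Table.from_pandas(pdf)
```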
Outputs
- Scaler
- The scaling model used to normalize the data before processing. It can be passed to subsequent bricks if you need to reverse the scaling or apply the same scaling to new data. If the Scale Data option is disabled, no scaler is produced.
- TSNE Image
- A visual representation of the t-SNE result. This is a scatter plot where similar data points are grouped closer together.
- TSNE Projections
- The processed data containing the newly calculated coordinates (one column per dimension).
The TSNE Projections output contains the following specific data fields:
- TSNE1: The coordinate for the first dimension.
- TSNE2: The coordinate for the second dimension.
- TSNE3: (If 3 components selected) The coordinate for the third dimension.
- {ID Column}: (If provided) The identifier from your original data.
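If an ID Column is provided, those coordinates can be joined back onto your original rows. A small sketch, using illustrative placeholder values rather than real brick output:

```python
import pandas as pd

# Original data and a mock of the TSNE Projections output
# (ID Column = "Product_ID", 2 components; coordinate values are placeholders)
original = pd.DataFrame({"Product_ID": ["A", "B", "C"], "price": [9.9, 4.5, 7.2]})
projections = pd.DataFrame({
    "Product_ID": ["A", "B", "C"],
    "TSNE1": [-1.2, 0.4, 2.1],
    "TSNE2": [0.3, -0.8, 1.5],
})

# Attach the new coordinates to the original items
enriched = original.merge(projections, on="Product_ID", how="left")
print(enriched)
```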
Outputs Types
| Output | Types |
|---|---|
| Scaler | Any |
| TSNE Image | MediaData, PILImage |
| TSNE Projections | DataFrame |
You can check the list of supported types here: Available Type Hints.
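Depending on the Output Type option (see Options below), the TSNE Image output arrives as a NumPy array, a PIL image, raw PNG bytes, or a BytesIO stream. Below is a sketch of saving any of these to disk; the helper name is ours, not part of the brick:

```python
import io
import numpy as np
from PIL import Image

def save_tsne_image(tsne_image, path="tsne.png"):
    """Persist the TSNE Image output regardless of the chosen Output Type."""
    if isinstance(tsne_image, Image.Image):            # output_type = "pil"
        tsne_image.save(path)
    elif isinstance(tsne_image, np.ndarray):           # output_type = "array"
        Image.fromarray(tsne_image).save(path)
    elif isinstance(tsne_image, (bytes, bytearray)):   # output_type = "bytes"
        with open(path, "wb") as f:
            f.write(tsne_image)
    elif isinstance(tsne_image, io.BytesIO):           # output_type = "bytesio"
        with open(path, "wb") as f:
            f.write(tsne_image.getvalue())
```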
Options
The Dim. Reduction T-SNE brick provides the following configurable options:
- Columns for T-SNE
- Select specific numeric columns to use for the analysis. If left empty, all numeric columns in the dataset will be used.
- Number of Components
- The number of dimensions to reduce the data into.
- 2: Reduces data to 2 dimensions (best for standard 2D plots).
- 3: Reduces data to 3 dimensions.
- Component for X-axis
- Selects which of the calculated dimensions (components) to plot on the horizontal X-axis. Usually "1".
- Component for Y-axis
- Selects which of the calculated dimensions (components) to plot on the vertical Y-axis. Usually "2".
- Distance Metric
- The method used to calculate the distance/similarity between data points.
- Euclidean: Standard straight-line distance. Good for general physical data.
- Cosine: Measures the angle between vectors. Good for text data or high-dimensional sparse data.
- Manhattan: Grid-like distance (L1 norm).
- Chebyshev: The greatest difference along any coordinate dimension.
- Perplexity
- Controls how the algorithm balances attention between local and global aspects of your data. It is roughly a guess about the number of close neighbors each point has.
- Lower values (5-30): Focuses on local structure (small, tight groups).
- Higher values (30-100): Focuses on global structure (overall shape).
- Auto Learning Rate
- If enabled, the algorithm automatically calculates the best learning rate based on your data size.
- Learning Rate (if not Auto)
- The step size for the optimization algorithm. If the result looks like a dense "ball" with every point roughly equidistant from its neighbors, the learning rate is likely too high; try decreasing it. If most points collapse into a compressed cloud with only a few outliers, it is likely too low; try increasing it.
- Max Iterations
- The maximum number of optimization iterations used to refine the embedding. Higher values take longer but may produce more stable results.
- Early Exaggeration
- Controls how tight natural clusters in the original space are in the embedded space and how much space will be between them.
- Angle (Barnes-Hut)
- Controls the trade-off between speed and accuracy.
- Lower (e.g., 0.2): More accurate, but slower.
- Higher (e.g., 0.8): Faster, but less accurate.
- Min Gradient Norm
- The optimization stops early once the gradient norm falls below this threshold, i.e., when further changes become insignificant.
- Scale Data
- If enabled, data is standardized (mean=0, variance=1) before processing. This is highly recommended to prevent columns with large numbers (e.g., "Salary") from dominating columns with small numbers (e.g., "Age").
- ID Column
- The name of a column in your input data that serves as a unique identifier (e.g., "Product_ID", "Email"). This column will be preserved in the output dataset so you can map the results back to your original items.
- Hue Column (for scatter)
- The name of a column to use for coloring the points in the plot (e.g., "Category", "Status").
- Exclude Hue
- Determines if the column used for coloring (the Hue) should be part of the mathematical calculation. If the hue column is not numerical or boolean, it is automatically excluded.
- True (Active): The Hue column is excluded from the dimensionality reduction. The algorithm ignores this data when calculating positions/clusters, using it only to assign colors in the final result. This is best when the Hue represents a label, outcome, or category (e.g., "Customer Type") that you want to visualize but don't want influencing the clustering logic.
- False (Inactive): The Hue column is included in the dimensionality reduction. The values in this column will actively influence where the data points are positioned on the graph.
- Color Palette
- The color scheme used for the visualization.
- Output Type
- The format of the resulting image.
- array: Returns a NumPy array representation of the image.
- pil: Returns a PIL Image object.
- bytes: Returns the raw image file bytes.
- bytesio: Returns a BytesIO stream.
- Number of Jobs
- How many CPU cores to use for parallel processing. More cores speed up training but use more system resources.
- Random State
- A seed number to ensure the results are reproducible. Using the same number ensures the plot looks the same every time you run it.
- Brick Caching
- If enabled, the result is saved temporarily. Running the brick again with the exact same inputs will load the result from the cache instead of recalculating, speeding up the workflow.
- Verbose
- If enabled, detailed logs about the processing steps will be generated.
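The option keys the brick reads internally are visible in the `options.get(...)` calls in the source below; based on those, a representative configuration might look like the following (column names are illustrative, and the mapping from UI labels to keys is inferred from the code):

```python
options = {
    "columns": ["age", "income", "score"],  # Columns for T-SNE (illustrative names)
    "n_components": 2,                      # Number of Components
    "pc_x": 1,                              # Component for X-axis
    "pc_y": 2,                              # Component for Y-axis
    "metric": "Euclidean",                  # Distance Metric
    "perplexity": 30.0,                     # Perplexity
    "auto_learning_rate": True,             # Auto Learning Rate
    "learning_rate": 200.0,                 # Learning Rate (used only if auto is off)
    "n_iter": 1000,                         # Max Iterations
    "early_exaggeration": 12.0,             # Early Exaggeration
    "angle": 0.5,                           # Angle (Barnes-Hut)
    "min_grad_norm": 1e-07,                 # Min Gradient Norm
    "scale_data": True,                     # Scale Data
    "id_column": "Product_ID",              # ID Column (illustrative)
    "hue": "Category",                      # Hue Column (illustrative)
    "exclude_hue": True,                    # Exclude Hue
    "palette": "husl",                      # Color Palette
    "output_type": "pil",                   # Output Type
    "n_jobs": "All",                        # Number of Jobs ("All" = every core)
    "random_state": 42,                     # Random State
    "activate_caching": False,              # Brick Caching
    "verbose": True,                        # Verbose
}
```

The complete brick implementation follows for reference.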
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import joblib
from pathlib import Path
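# Use the non-interactive Agg backend so plots render without a display; this must be set before importing pyplot.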
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
Union,
DataFrame,
ArrowTable,
MediaData,
PILImage,
Tuple,
Any,
DataSeries,
Str,
)
from coded_flows.utils import CodedFlowsLogger
logger = CodedFlowsLogger(name="Dim. Reduction T-SNE", level=logging.INFO)
DataType = Union[
pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]
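# Fingerprinting helpers: these hash the input features so the optional brick cache can detect identical runs.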
@dataclass
class _DatasetFingerprint:
"""Lightweight fingerprint of a dataset."""
hash: str
shape: tuple
computed_at: str
data_type: str
method: str
class _UniversalDatasetHasher:
"""
High-performance dataset hasher optimizing for zero-copy operations.
"""
def __init__(
self,
data_size: int,
method: str = "auto",
sample_size: int = 100000,
verbose: bool = False,
):
self.method = method
self.sample_size = sample_size
self.data_size = data_size
self.verbose = verbose
def hash_data(self, data: DataType) -> _DatasetFingerprint:
if isinstance(data, pd.DataFrame):
return self._hash_pandas(data)
elif isinstance(data, pl.DataFrame):
return self._hash_polars(data)
elif isinstance(data, pd.Series):
return self._hash_pandas_series(data)
elif isinstance(data, pl.Series):
return self._hash_polars_series(data)
elif isinstance(data, np.ndarray):
return self._hash_numpy(data)
elif sparse.issparse(data):
return self._hash_sparse(data)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
method = self._determine_method(self.data_size, self.method)
target_df = df
if method == "sampled":
target_df = self._get_pandas_sample(df)
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"columns": df.columns.tolist(),
"dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
"shape": df.shape,
},
)
try:
row_hashes = pd.util.hash_pandas_object(target_df, index=False)
hasher.update(memoryview(row_hashes.values))
except Exception:
self._hash_pandas_fallback(hasher, target_df)
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=df.shape,
computed_at=datetime.now().isoformat(),
data_type="pandas",
method=method,
)
def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
if self.data_size <= self.sample_size:
return df
chunk = self.sample_size // 3
head = df.iloc[:chunk]
mid_idx = self.data_size // 2
mid = df.iloc[mid_idx : mid_idx + chunk]
tail = df.iloc[-chunk:]
return pd.concat([head, mid, tail])
def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
for col in df.columns:
val = df[col].astype(str).values
hasher.update(val.astype(np.bytes_).tobytes())
def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
method = self._determine_method(self.data_size, self.method)
target_df = df
if method == "sampled" and self.data_size > self.sample_size:
indices = self._get_sample_indices(self.data_size, self.sample_size)
target_df = df.gather(indices)
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"columns": df.columns,
"dtypes": [str(t) for t in df.dtypes],
"shape": df.shape,
},
)
row_hashes = target_df.hash_rows()
hasher.update(memoryview(row_hashes.to_numpy()))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=df.shape,
computed_at=datetime.now().isoformat(),
data_type="polars",
method=method,
)
def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"name": series.name if series.name else "None",
"dtype": str(series.dtype),
"shape": series.shape,
},
)
try:
row_hashes = pd.util.hash_pandas_object(series, index=False)
hasher.update(memoryview(row_hashes.values))
except Exception:
hasher.update(memoryview(series.astype(str).values.tobytes()))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=series.shape,
computed_at=datetime.now().isoformat(),
data_type="pandas_series",
method="full",
)
def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
)
try:
row_hashes = series.hash()
hasher.update(memoryview(row_hashes.to_numpy()))
except Exception:
hasher.update(str(series.to_list()).encode())
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=series.shape,
computed_at=datetime.now().isoformat(),
data_type="polars_series",
method="full",
)
def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
)
if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
hasher.update(memoryview(arr))
else:
hasher.update(memoryview(np.ascontiguousarray(arr)))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=arr.shape,
computed_at=datetime.now().isoformat(),
data_type="numpy",
method="full",
)
def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
matrix = matrix.tocsr()
hasher = xxhash.xxh128()
self._hash_schema(
hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
)
hasher.update(memoryview(matrix.data))
hasher.update(memoryview(matrix.indices))
hasher.update(memoryview(matrix.indptr))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=matrix.shape,
computed_at=datetime.now().isoformat(),
data_type=f"sparse_{matrix.format}",
method="sparse",
)
def _determine_method(self, rows: int, requested: str) -> str:
if requested != "auto":
return requested
if rows < 5000000:
return "full"
return "sampled"
def _hash_schema(self, hasher, schema):
hasher.update(
json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
)
def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
chunk = sample_size // 3
indices = list(range(min(chunk, total_rows)))
mid_start = max(0, total_rows // 2 - chunk // 2)
mid_end = min(mid_start + chunk, total_rows)
indices.extend(range(mid_start, mid_end))
last_start = max(0, total_rows - chunk)
indices.extend(range(last_start, total_rows))
return sorted(list(set(indices)))
def _plot_tsne_scatter(
ax, x_data, y_data, x_label, y_label, hue_data, hue_col, palette
):
"""Plot T-SNE scatter plot with selected axes."""
if hue_data is not None:
scatter_df = pd.DataFrame({x_label: x_data, y_label: y_data, hue_col: hue_data})
sns.scatterplot(
data=scatter_df,
x=x_label,
y=y_label,
hue=hue_col,
palette=palette,
ax=ax,
alpha=0.7,
s=50,
)
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
else:
ax.scatter(x_data, y_data, alpha=0.7, s=50)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title("t-SNE Projection")
ax.grid(True, alpha=0.3)
def tsne_analysis(
data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Union[MediaData, PILImage], DataFrame]:
options = options or {}
verbose = options.get("verbose", True)
random_state = options.get("random_state", 42)
activate_caching = options.get("activate_caching", False)
output_type = options.get("output_type", "array")
columns = options.get("columns", None)
n_components = options.get("n_components", 2)
scale_data = options.get("scale_data", True)
id_column = options.get("id_column", "")
hue = options.get("hue", "")
exclude_hue = options.get("exclude_hue", True)
palette = options.get("palette", "husl")
n_jobs_str = options.get("n_jobs", "1")
n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
pc_x = options.get("pc_x", 1)
pc_y = options.get("pc_y", 2)
perplexity = options.get("perplexity", 30.0)
auto_lr = options.get("auto_learning_rate", True)
custom_lr = options.get("learning_rate", 200.0)
learning_rate = "auto" if auto_lr else custom_lr
n_iter = options.get("n_iter", 1000)
early_exaggeration = options.get("early_exaggeration", 12.0)
metric_raw = options.get("metric", "Euclidean")
metric = metric_raw.lower() if metric_raw else "euclidean"
angle = options.get("angle", 0.5)
min_grad_norm = options.get("min_grad_norm", 1e-07)
dpi = 300
verbose and logger.info(
f"Starting T-SNE with {n_components} components. Displaying PC{pc_x} vs PC{pc_y}"
)
verbose and logger.info(
f"Config: metric={metric}, lr={learning_rate}, angle={angle}"
)
if pc_x > n_components or pc_y > n_components:
verbose and logger.warning(
f"Selected PCs (X:{pc_x}, Y:{pc_y}) exceed n_components ({n_components}). Clamping to valid range."
)
pc_x = min(pc_x, n_components)
pc_y = min(pc_y, n_components)
TSNE_Image = None
TSNE_Projections = pd.DataFrame()
Scaler = None
plot_components = None
hue_data = None
hue_col = None
df = None
try:
if isinstance(data, pl.DataFrame):
verbose and logger.info(f"Converting polars DataFrame to pandas")
df = data.to_pandas()
elif isinstance(data, pa.Table):
verbose and logger.info(f"Converting Arrow table to pandas")
df = data.to_pandas()
elif isinstance(data, pd.DataFrame):
verbose and logger.info(f"Input is already pandas DataFrame")
df = data
else:
raise ValueError(f"Unsupported data type: {type(data).__name__}")
except Exception as e:
raise RuntimeError(
f"Failed to convert input data to pandas DataFrame: {e}"
) from e
if df.empty:
raise ValueError("Input DataFrame is empty")
verbose and logger.info(
f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
)
if hue and hue.strip():
if hue not in df.columns:
raise ValueError(f"Hue column '{hue}' not found")
hue_col = hue
if id_column and id_column.strip() and (id_column not in df.columns):
raise ValueError(f"ID column '{id_column}' not found in DataFrame")
if columns and len(columns) > 0:
missing_cols = [col for col in columns if col not in df.columns]
if missing_cols:
raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
feature_cols = list(columns)
else:
feature_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
if not feature_cols:
raise ValueError("No numeric columns found in DataFrame")
if id_column and id_column.strip() and (id_column in feature_cols):
feature_cols.remove(id_column)
if exclude_hue and hue_col and (hue_col in feature_cols):
feature_cols.remove(hue_col)
if not feature_cols:
raise ValueError("No feature columns remaining after excluding ID column")
skip_computation = False
cache_folder = None
all_hash = None
if activate_caching:
verbose and logger.info(f"Caching is active")
data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
X_hash = data_hasher.hash_data(df[feature_cols]).hash
all_hash_base_text = f"HASH BASE TEXT TSNE_V3Scikit Learn Version {sklearn.__version__}{X_hash}_{n_components}_{scale_data}_{random_state}{perplexity}_{learning_rate}_{n_iter}_{early_exaggeration}{metric}_{angle}_{min_grad_norm}{exclude_hue}{sorted(feature_cols)}"
all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
verbose and logger.info(f"Hash was computed: {all_hash}")
temp_folder = Path(tempfile.gettempdir())
cache_folder = temp_folder / "coded-flows-cache"
cache_folder.mkdir(parents=True, exist_ok=True)
tsne_proj_path = cache_folder / f"tsne_proj_{all_hash}.parquet"
tsne_est_path = cache_folder / f"tsne_estimator_{all_hash}.joblib"
scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
if (
tsne_proj_path.is_file()
and tsne_est_path.is_file()
and scaler_path.is_file()
):
verbose and logger.info(f"Cache hit! Loading results.")
try:
TSNE_Projections = pd.read_parquet(tsne_proj_path)
tsne = joblib.load(tsne_est_path)
Scaler = joblib.load(scaler_path)
comps_for_plot_df = TSNE_Projections.copy()
if id_column and id_column in comps_for_plot_df.columns:
comps_for_plot_df = comps_for_plot_df.drop(columns=[id_column])
plot_components = comps_for_plot_df.values
skip_computation = True
except Exception as e:
verbose and logger.warning(f"Cache load failed, recomputing: {e}")
skip_computation = False
if not skip_computation:
X = df[feature_cols]
        if X.isna().any().any():
verbose and logger.warning(f"Removing rows with missing values")
            mask = ~X.isna().any(axis=1)
X = X[mask]
df_indices = df.index[mask]
else:
df_indices = df.index
        if X.shape[0] <= perplexity:
verbose and logger.warning(
f"Data samples ({X.shape[0]}) < Perplexity ({perplexity}). Adjusting perplexity."
)
perplexity = max(5.0, float(X.shape[0] - 1))
if scale_data:
verbose and logger.info(f"Scaling features")
Scaler = StandardScaler()
X_transformed = Scaler.fit_transform(X)
else:
X_transformed = X
verbose and logger.info(f"Fitting T-SNE model (metric={metric})")
tsne = TSNE(
n_components=n_components,
perplexity=perplexity,
learning_rate=learning_rate,
max_iter=n_iter,
early_exaggeration=early_exaggeration,
metric=metric,
angle=angle,
min_grad_norm=min_grad_norm,
random_state=random_state,
n_jobs=n_jobs_int,
)
plot_components = tsne.fit_transform(X_transformed)
tsne_cols = [f"TSNE{i + 1}" for i in range(n_components)]
TSNE_Projections = pd.DataFrame(
plot_components, columns=tsne_cols, index=df_indices
)
if id_column and id_column.strip():
id_data = df.loc[df_indices, id_column].values
TSNE_Projections.insert(0, id_column, id_data)
if activate_caching and cache_folder and all_hash:
try:
tsne_proj_path = cache_folder / f"tsne_proj_{all_hash}.parquet"
tsne_est_path = cache_folder / f"tsne_estimator_{all_hash}.joblib"
scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
TSNE_Projections.to_parquet(tsne_proj_path)
joblib.dump(tsne, tsne_est_path)
if Scaler is not None:
joblib.dump(Scaler, scaler_path)
verbose and logger.info(f"Results saved to cache")
except Exception as e:
verbose and logger.warning(f"Failed to save cache: {e}")
try:
idx_x = pc_x - 1
idx_y = pc_y - 1
x_values = plot_components[:, idx_x]
y_values = plot_components[:, idx_y]
x_label = f"t-SNE Dimension {pc_x}"
y_label = f"t-SNE Dimension {pc_y}"
if hue_col:
hue_data = df.loc[TSNE_Projections.index, hue].values
verbose and logger.info(f"Using hue column: '{hue}'")
verbose and logger.info(
f"Creating scatter visualization ({x_label} vs {y_label})"
)
(fig, ax) = plt.subplots(figsize=(10, 7))
_plot_tsne_scatter(
ax, x_values, y_values, x_label, y_label, hue_data, hue_col, palette
)
verbose and logger.info(f"Rendering to {output_type}")
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
buf.seek(0)
if output_type == "bytesio":
TSNE_Image = buf
elif output_type == "bytes":
TSNE_Image = buf.getvalue()
buf.close()
elif output_type == "pil":
TSNE_Image = Image.open(buf)
TSNE_Image.load()
buf.close()
elif output_type == "array":
img = Image.open(buf)
TSNE_Image = np.array(img)
buf.close()
else:
raise ValueError(f"Invalid output_type: '{output_type}'")
plt.close(fig)
except Exception as e:
verbose and logger.error(f"Error during plotting: {e}")
plt.close("all")
raise RuntimeError(f"Plotting failed: {e}") from e
verbose and logger.info(f"Successfully completed T-SNE analysis")
return (Scaler, TSNE_Image, TSNE_Projections)
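As a rough sanity check outside the platform, the snippet below (appended after the source above, so `np`, `pd`, and `tsne_analysis` are already in scope) runs the brick function on illustrative random data; it assumes the `coded_flows` package is installed, since the module imports from it:

```python
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    demo = pd.DataFrame({
        "item_id": [f"item_{i}" for i in range(60)],
        "f1": rng.normal(size=60),
        "f2": rng.normal(size=60),
        "f3": rng.normal(size=60),
    })

    scaler, image, projections = tsne_analysis(
        demo,
        options={
            "n_components": 2,
            "perplexity": 15.0,
            "id_column": "item_id",
            "output_type": "pil",
            "random_state": 42,
        },
    )
    print(projections.head())  # columns: item_id, TSNE1, TSNE2
```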
Brick Info
- shap>=0.47.0
- scikit-learn
- pandas
- pyarrow
- polars[pyarrow]
- numpy
- numba>=0.56.0
- joblib
- matplotlib
- seaborn
- pillow
- xxhash