Cla. K Nearest Neighbors
Train a K-Nearest Neighbors (KNN) classification model.
Processing
This brick trains a classification model using the K-Nearest Neighbors (KNN) algorithm. It predicts the category of a new data point by looking at its "neighbors"—the most similar data points in your training set.
Conceptually, it operates on the principle that similar things exist in close proximity. If the majority of a data point's closest neighbors belong to a specific category (e.g., "Fraud"), the algorithm classifies the new point as "Fraud" as well.
This brick handles the entire machine learning pipeline, including:
- Data Scaling: Standardizing data so that features with large numbers (like Salary) don't overpower features with small numbers (like Age).
- Splitting: Automatically dividing data into training and testing sets.
- Optimization: Optionally finding the best number of neighbors (k) and distance metrics automatically.
- Evaluation: Calculating performance scores like Accuracy and F1-Score.
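For intuition, the sketch below is a minimal scikit-learn version of the kind of workflow the brick automates (scaling, splitting, fitting, scoring). It is an illustration only; the actual implementation listed further below additionally handles cross-validation, hyperparameter optimization, SHAP, and caching.

```python
# Illustration only: a minimal scikit-learn version of the workflow this brick automates.
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score

# Toy data: rows are examples, columns are features (see Inputs below).
X = pd.DataFrame(
    {
        "age": [25, 47, 31, 52, 36, 29, 44, 58],
        "amount": [120.0, 890.5, 45.2, 300.0, 150.3, 80.1, 610.7, 720.9],
    }
)
y = ["Active", "Churned", "Active", "Churned", "Active", "Active", "Churned", "Churned"]

# Split, scale, fit KNN, evaluate — the steps the brick performs for you.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)
scaler = StandardScaler()
model = KNeighborsClassifier(n_neighbors=3, weights="distance")
model.fit(scaler.fit_transform(X_train), y_train)
y_pred = model.predict(scaler.transform(X_test))
print(accuracy_score(y_test, y_pred), f1_score(y_test, y_pred, pos_label="Churned"))
```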
Inputs
- X
- The features (independent variables) used to make predictions. This should be a table where rows are examples and columns are attributes (e.g., customer age, transaction amount).
- y
- The target (dependent variable) you want to predict. This is a list or column containing the known categories or labels (e.g., "Churned", "Active").
Inputs Types
| Input | Types |
|---|---|
| X | DataFrame |
| y | DataSeries, NDArray, List |
You can check the list of supported types here: Available Type Hints.
Outputs
- Model
- The trained K-Nearest Neighbors model object. This can be passed to other bricks to make predictions on new, unseen data.
- Model Classes
- A reference table linking the internal numerical IDs used by the model to the actual class names (labels) found in your data.
- SHAP
- The SHAP explainer object (if enabled in options). This is used to interpret why the model made specific predictions.
- Scaler
- The fitted scaler object used to normalize the data. This is required to process future data in the exact same way as the training data.
- Metrics
- A summary of the model's performance on the test set, containing scores like Accuracy, Precision, Recall, and ROC-AUC.
- CV Metrics
- Detailed performance statistics from Cross-Validation (if enabled), showing how stable the model is across different subsets of data.
- Prediction Set
- A dataframe combining the test data features with the model's actual predictions and probabilities. This allows you to inspect specific rows where the model succeeded or failed.
- HPO Trials
- A history of all attempts made during Hyperparameter Optimization (if enabled), showing which settings were tried and how well they performed.
- HPO Best
- A dictionary containing the specific configuration (e.g., exact number of neighbors) that resulted in the best performing model.
The Prediction_Set output contains the following specific data fields:
- feature_{name}: The input feature columns used for the prediction.
- proba_{class}: The calculated probability for each specific class. For binary classification, this might appear as a single proba column representing the positive class.
- y_true: The actual known label for the row.
- y_pred: The label predicted by the model.
- is_false_prediction: A boolean (True/False) indicating if the model guessed wrong.
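For example, a quick way to audit mistakes is to filter on is_false_prediction. The snippet below uses hypothetical data, with column names as documented above.

```python
# Hypothetical Prediction Set contents (column names as documented above).
import pandas as pd

prediction_set = pd.DataFrame(
    {
        "feature_age": [25, 47, 31],
        "proba_Churned": [0.12, 0.81, 0.55],
        "y_true": ["Active", "Churned", "Active"],
        "y_pred": ["Active", "Churned", "Churned"],
        "is_false_prediction": [False, False, True],
    }
)

# Inspect only the rows the model got wrong, most confident mistakes first.
errors = prediction_set[prediction_set["is_false_prediction"]]
print(errors.sort_values("proba_Churned", ascending=False))
```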
Outputs Types
| Output | Types |
|---|---|
| Model | Any |
| Model Classes | DataFrame |
| SHAP | Any |
| Scaler | Any |
| Metrics | DataFrame, Dict |
| CV Metrics | DataFrame |
| Prediction Set | DataFrame |
| HPO Trials | DataFrame |
| HPO Best | Dict |
You can check the list of supported types here: Available Type Hints.
Options
The Cla. K Nearest Neighbors brick contains some changeable options:
- Number of Neighbors
- The 'K' in KNN. This controls how many closest neighbors are consulted to vote on the class of a new data point.
- Low values (e.g., 1-3): The model is very sensitive to local patterns but may capture noise (overfitting).
- High values: The model becomes smoother and more stable but may miss finer details.
- Distance Weighting
- Determines how the vote is counted.
- False (Uniform): Democracy. Every neighbor gets 1 vote, regardless of how close they are.
- True (Distance): Weighted. Neighbors that are closer to the data point have a stronger influence on the result than those further away.
- Neighbor Algorithm
- The technical method used to search for neighbors.
- Auto: Automatically selects the best method based on the data structure. (Recommended)
- Ball-tree: Efficient for high-dimensional data.
- KD-tree: Efficient for low-dimensional data.
- Brute Force: Computes distances to all pairs. Exact but slow for large datasets.
- Metric
- The mathematical formula used to measure "distance" or similarity between two points.
- Minkowski: The standard geometric distance (includes Euclidean and Manhattan).
- Cosine: Measures the angle between points. Often better for text data or high-dimensional sparse data.
- Haversine: Used for calculating distances between latitude/longitude points on a sphere.
- Minkowski Exponent
- Only applies when Metric is "Minkowski".
- 1: Manhattan Distance (measured along grid lines, like city blocks).
- 2: Euclidean Distance (straight-line distance).
- Standard Scaling
- If enabled, automatically scales all numerical features to have a mean of 0 and variance of 1. This is highly recommended for KNN, as the algorithm is sensitive to the scale of data (e.g., Income vs Age).
- Auto Split Data
- If enabled, the brick automatically calculates the best split ratio for training and testing based on your dataset size.
- Shuffle Split
- Randomly shuffles the data before splitting. This ensures the training and test sets are representative of the whole dataset.
- Stratify Split
- Ensures the distribution of classes (e.g., 70% 'No', 30% 'Yes') is preserved in both the training and testing sets. Critical for imbalanced data.
- Test/Validation Set %
- The percentage of data held back for testing the model's performance. (Ignored if "Auto Split Data" is enabled).
- Retrain On Full Data
- If enabled, after evaluating the model on the test set, the brick retrains the model on 100% of the available data. Use this when you are ready to deploy the final model for production.
- Average Strategy
- How to average metrics (Precision, Recall, F1) for multiclass problems.
- auto: Automatically selects based on class balance.
- binary: For two-class problems only.
- micro: Calculate metrics globally by counting total true positives, false negatives and false positives.
- macro: Calculate metrics for each label, and find their unweighted mean. (Treats all classes equally).
- weighted: Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label).
- Enable Cross-Validation
- If enabled, splits the data into multiple folds (e.g., 5) and trains/tests the model once per fold. This provides a more robust estimate of model performance but takes longer.
- Number of CV Folds
- The number of groups to split the data into for Cross-Validation (typically 5 or 10).
- Hyperparameter Optim.
- If enabled, the brick will run multiple trials to automatically find the best settings for your specific data.
- Optimization Metric
- The score the optimization process tries to maximize (e.g., maximize "F1 Score" to balance precision and recall).
- Optimization Method
- The algorithm used to search for the best parameters. "Tree-structured Parzen" is generally the most efficient.
- Tree-structured Parzen: A Bayesian optimization method that models good vs bad parameter regions using probability distributions and prioritizes sampling where success is statistically more likely.
- Gaussian Process: Uses a probabilistic regression model (Gaussian Process) to estimate performance uncertainty and selects new trials using acquisition functions.
- CMA-ES: An evolutionary strategy that adapts the covariance matrix of a multivariate normal distribution to efficiently search complex, non-linear, non-convex spaces.
- Random Sobol Search: Uses low-discrepancy quasi-random sequences to ensure uniform coverage of the parameter space, avoiding clustering and gaps.
- Random Search: Uniform random sampling of parameter configurations without learning or feedback between iterations.
- Optimization Iterations
- The number of different parameter combinations to try during optimization.
- Positive Label (Binary Only)
- Explicitly define which class is "Positive" (e.g., "1", "churn", "fraud"). If left empty, the brick attempts to infer it.
- Metrics as
- Choose whether the Metrics output is returned as a Dataframe (table) or a Dictionary (JSON-like).
- SHAP Explainer
- If enabled, generates a SHAP explainer object. Note: computing SHAP values for KNN can be computationally expensive.
- SHAP Sampler
- If enabled, uses a smaller sample of the background data for the SHAP explainer to speed up calculation.
- Number of Jobs
- Controls how many CPU cores are used for training and cross-validation.
- 1: Sequential processing (slower, uses less system resources).
- 2, 4, 8: Specific number of cores.
- All: Uses all available cores on the machine for maximum speed.
- Random State
- A seed number (integer) that ensures reproducibility. Using the same seed with the same data ensures you get the exact same split and model results every time you run the brick.
- Brick Caching
- If enabled, the brick saves the results to a temporary cache based on the inputs. If you run the flow again with the exact same inputs and options, it loads the results immediately instead of retraining, saving significant time.
- Verbose Logging
- If enabled, prints detailed progress updates, metric calculations, and optimization steps to the execution logs. Useful for debugging or monitoring long-running tasks.
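For reference, these option names correspond to the keys read by the implementation listed below. The snippet is a sketch only; in normal use the values are set from the brick's options panel rather than in code.

```python
# Sketch of an options dictionary using the keys read by train_cla_knn (see the code below).
options = {
    "n_neighbors": 7,                  # Number of Neighbors
    "distance_weighting": True,        # Distance Weighting
    "neighbor_algorithm": "Auto",      # Neighbor Algorithm
    "metric": "Minkowski",             # Metric
    "p": 2,                            # Minkowski Exponent
    "standard_scaling": True,          # Standard Scaling
    "auto_split": False,               # Auto Split Data
    "test_val_size": 20,               # Test/Validation Set % (used when auto_split is False)
    "stratify_split": True,            # Stratify Split
    "use_cross_validation": True,      # Enable Cross-Validation
    "cv_folds": 5,                     # Number of CV Folds
    "custom_average_strategy": "auto", # Average Strategy
    "metrics_as": "Dataframe",         # Metrics as
    "n_jobs": "All",                   # Number of Jobs
    "random_state": 42,                # Random State
    "verbose": True,                   # Verbose Logging
}
```

The brick's full implementation is listed below for reference.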
import logging
import warnings
import shap
import json
import xxhash
import hashlib
import tempfile
import sklearn
import scipy
import joblib
import numpy as np
import pandas as pd
import polars as pl
from pathlib import Path
from scipy import sparse
from optuna.samplers import (
TPESampler,
RandomSampler,
GPSampler,
CmaEsSampler,
QMCSampler,
)
import optuna
from optuna import Study
from optuna.trial import FrozenTrial
from optuna.pruners import HyperbandPruner
from optuna import create_study
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_validate, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
roc_auc_score,
make_scorer,
)
from dataclasses import dataclass
from datetime import datetime
from coded_flows.types import (
Union,
Dict,
List,
Tuple,
NDArray,
DataFrame,
DataSeries,
Any,
Tuple,
)
from coded_flows.utils import CodedFlowsLogger
logger = CodedFlowsLogger(name="Cla. K Nearest Neighbors", level=logging.INFO)
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
DataType = Union[
pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]
@dataclass
class _DatasetFingerprint:
"""Lightweight fingerprint of a dataset."""
hash: str
shape: tuple
computed_at: str
data_type: str
method: str
class _UniversalDatasetHasher:
"""
High-performance dataset hasher optimizing for zero-copy operations
and native backend execution (C/Rust).
"""
def __init__(
self,
data_size: int,
method: str = "auto",
sample_size: int = 100000,
verbose: bool = False,
):
self.method = method
self.sample_size = sample_size
self.data_size = data_size
self.verbose = verbose
def hash_data(self, data: DataType) -> _DatasetFingerprint:
"""
Main entry point: hash any supported data format.
Auto-detects format and applies optimal strategy.
"""
if isinstance(data, pd.DataFrame):
return self._hash_pandas(data)
elif isinstance(data, pl.DataFrame):
return self._hash_polars(data)
elif isinstance(data, pd.Series):
return self._hash_pandas_series(data)
elif isinstance(data, pl.Series):
return self._hash_polars_series(data)
elif isinstance(data, np.ndarray):
return self._hash_numpy(data)
elif sparse.issparse(data):
return self._hash_sparse(data)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
"""
Optimized Pandas hashing using pd.util.hash_pandas_object.
Avoids object-to-string conversion overhead.
"""
method = self._determine_method(self.data_size, self.method)
self.verbose and logger.info(
f"Hashing Pandas: {self.data_size:,} rows - {method}"
)
target_df = df
if method == "sampled":
target_df = self._get_pandas_sample(df)
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"columns": df.columns.tolist(),
"dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
"shape": df.shape,
},
)
try:
row_hashes = pd.util.hash_pandas_object(target_df, index=False)
hasher.update(memoryview(row_hashes.values))
except Exception as e:
self.verbose and logger.warning(
f"Fast hash failed, falling back to slow hash: {e}"
)
self._hash_pandas_fallback(hasher, target_df)
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=df.shape,
computed_at=datetime.now().isoformat(),
data_type="pandas",
method=method,
)
def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
"""Deterministic slicing for sampling (Zero randomness)."""
if self.data_size <= self.sample_size:
return df
chunk = self.sample_size // 3
head = df.iloc[:chunk]
mid_idx = self.data_size // 2
mid = df.iloc[mid_idx : mid_idx + chunk]
tail = df.iloc[-chunk:]
return pd.concat([head, mid, tail])
def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
"""Legacy fallback for complex object types."""
for col in df.columns:
val = df[col].astype(str).values
hasher.update(val.astype(np.bytes_).tobytes())
def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
"""
Optimized Polars hashing using native Rust execution.
"""
method = self._determine_method(self.data_size, self.method)
self.verbose and logger.info(
f"Hashing Polars: {self.data_size:,} rows - {method}"
)
target_df = df
if method == "sampled" and self.data_size > self.sample_size:
indices = self._get_sample_indices(self.data_size, self.sample_size)
target_df = df.gather(indices)
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"columns": df.columns,
"dtypes": [str(t) for t in df.dtypes],
"shape": df.shape,
},
)
row_hashes = target_df.hash_rows()
hasher.update(memoryview(row_hashes.to_numpy()))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=df.shape,
computed_at=datetime.now().isoformat(),
data_type="polars",
method=method,
)
def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
"""Hash Pandas Series using the fastest vectorized method."""
self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{
"name": series.name if series.name else "None",
"dtype": str(series.dtype),
"shape": series.shape,
},
)
try:
row_hashes = pd.util.hash_pandas_object(series, index=False)
hasher.update(memoryview(row_hashes.values))
except Exception as e:
self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
hasher.update(memoryview(series.astype(str).values.tobytes()))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=series.shape,
computed_at=datetime.now().isoformat(),
data_type="pandas_series",
method="full",
)
def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
"""Hash Polars Series using native Polars expressions."""
self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
)
try:
row_hashes = series.hash()
hasher.update(memoryview(row_hashes.to_numpy()))
except Exception as e:
self.verbose and logger.warning(
f"Polars series native hash failed. Falling back."
)
hasher.update(str(series.to_list()).encode())
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=series.shape,
computed_at=datetime.now().isoformat(),
data_type="polars_series",
method="full",
)
def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
"""
Optimized NumPy hashing using Buffer Protocol (Zero-Copy).
"""
hasher = xxhash.xxh128()
self._hash_schema(
hasher,
{"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
)
if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
hasher.update(memoryview(arr))
else:
hasher.update(memoryview(np.ascontiguousarray(arr)))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=arr.shape,
computed_at=datetime.now().isoformat(),
data_type="numpy",
method="full",
)
def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
"""
Optimized sparse hashing. Hashes underlying data arrays directly.
"""
if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
matrix = matrix.tocsr()
hasher = xxhash.xxh128()
self._hash_schema(
hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
)
hasher.update(memoryview(matrix.data))
hasher.update(memoryview(matrix.indices))
hasher.update(memoryview(matrix.indptr))
return _DatasetFingerprint(
hash=hasher.hexdigest(),
shape=matrix.shape,
computed_at=datetime.now().isoformat(),
data_type=f"sparse_{matrix.format}",
method="sparse",
)
def _determine_method(self, rows: int, requested: str) -> str:
if requested != "auto":
return requested
if rows < 5000000:
return "full"
return "sampled"
def _hash_schema(self, hasher, schema: Dict[str, Any]):
"""Compact schema hashing."""
hasher.update(
json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
)
def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
"""Calculate indices for sampling without generating full range lists."""
chunk = sample_size // 3
indices = list(range(min(chunk, total_rows)))
mid_start = max(0, total_rows // 2 - chunk // 2)
mid_end = min(mid_start + chunk, total_rows)
indices.extend(range(mid_start, mid_end))
last_start = max(0, total_rows - chunk)
indices.extend(range(last_start, total_rows))
return sorted(list(set(indices)))
def _normalize_hpo_df(df):
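    """Cast Optuna trial parameter columns ('params_*') to a pyarrow-backed string
    dtype so the HPO trials dataframe has a uniform type for serialization."""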
df = df.copy()
param_cols = [c for c in df.columns if c.startswith("params_")]
df[param_cols] = df[param_cols].astype("string[pyarrow]")
return df
def _validate_numerical_data(data):
"""
Validates if the input data (NumPy array, Pandas DataFrame/Series,
Polars DataFrame/Series, or SciPy sparse matrix) contains only
numerical (integer, float) or boolean values.
Args:
data: The input data structure to check.
Raises:
TypeError: If the input data contains non-numerical and non-boolean types.
ValueError: If the input data is of an unsupported type.
"""
if sparse.issparse(data):
if not (
np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
):
raise TypeError(
f"Sparse matrix contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
)
return
elif isinstance(data, np.ndarray):
if not (
np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
):
raise TypeError(
f"NumPy array contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
)
return
elif isinstance(data, (pd.DataFrame, pd.Series)):
d_types = data.dtypes.apply(lambda x: x.kind)
non_numerical_mask = ~d_types.isin(["i", "f", "b"])
if non_numerical_mask.any():
non_numerical_columns = (
data.columns[non_numerical_mask].tolist()
if isinstance(data, pd.DataFrame)
else [data.name]
)
raise TypeError(
f"Pandas {('DataFrame' if isinstance(data, pd.DataFrame) else 'Series')} contains non-numerical/boolean data. Offending column(s) and types: {data.dtypes[non_numerical_mask].to_dict()}"
)
return
elif isinstance(data, (pl.DataFrame, pl.Series)):
pl_numerical_types = [
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
pl.Boolean,
]
if isinstance(data, pl.DataFrame):
for col, dtype in data.schema.items():
if dtype not in pl_numerical_types:
raise TypeError(
f"Polars DataFrame column '{col}' has unsupported data type: {dtype}. Only numerical or boolean types are allowed."
)
elif isinstance(data, pl.Series):
if data.dtype not in pl_numerical_types:
raise TypeError(
f"Polars Series has unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
)
return
else:
raise ValueError(
f"Unsupported data type provided: {type(data)}. Function supports NumPy, Pandas, Polars, and SciPy sparse matrices."
)
def _get_shape_and_sparsity(X: Any) -> Tuple[int, int, float, bool]:
"""
Efficiently extracts shape and estimates sparsity without converting
the entire dataset to numpy.
"""
(n_samples, n_features) = (0, 0)
is_sparse = False
sparsity = 0.0
if hasattr(X, "nnz") and hasattr(X, "shape"):
(n_samples, n_features) = X.shape
is_sparse = True
sparsity = 1.0 - X.nnz / (n_samples * n_features)
return (n_samples, n_features, sparsity, is_sparse)
if hasattr(X, "height") and hasattr(X, "width"):
(n_samples, n_features) = (X.height, X.width)
return (n_samples, n_features, 0.0, False)
if hasattr(X, "shape") and hasattr(X, "iloc"):
(n_samples, n_features) = X.shape
return (n_samples, n_features, 0.0, False)
if isinstance(X, list):
X = np.array(X)
if hasattr(X, "shape"):
(n_samples, n_features) = X.shape
return (n_samples, n_features, 0.0, False)
raise ValueError("Unsupported data type")
def _smart_split(
n_samples,
X,
y,
*,
random_state=42,
shuffle=True,
stratify=None,
fixed_test_split=None,
verbose=True,
):
"""
Parameters
----------
n_samples : int
Number of samples in the dataset (len(X) or len(y))
X : array-like
Features
y : array-like
Target
random_state : int
shuffle : bool
stratify : array-like or None
For stratified splitting (recommended for classification)
Returns
-------
    X_train, X_test, y_train, y_test, val_size_in_train
        The train/test split plus the validation ratio (relative to the
        training portion) to use if a separate validation split is needed.
"""
if fixed_test_split:
test_ratio = fixed_test_split
val_ratio = fixed_test_split
elif n_samples <= 1000:
test_ratio = 0.2
val_ratio = 0.1
elif n_samples < 10000:
test_ratio = 0.15
val_ratio = 0.15
elif n_samples < 100000:
test_ratio = 0.1
val_ratio = 0.1
elif n_samples < 1000000:
test_ratio = 0.05
val_ratio = 0.05
else:
test_ratio = 0.01
val_ratio = 0.01
(X_train, X_test, y_train, y_test) = train_test_split(
X,
y,
test_size=test_ratio,
random_state=random_state,
shuffle=shuffle,
stratify=stratify,
)
val_size_in_train = val_ratio / (1 - test_ratio)
verbose and logger.info(
f"Split → Train: {1 - test_ratio:.2%} | Test: {test_ratio:.2%} (no validation set)"
)
return (X_train, X_test, y_train, y_test, val_size_in_train)
def _get_best_metric_average_strategy(y_true, balance_threshold: float = 0.5) -> str:
"""
Analyzes y_true to determine the best averaging strategy.
Args:
y_true: Input array (Numpy array, Pandas Series, or Polars Series).
balance_threshold: Float (0 to 1). If min_class_count / max_class_count
is below this, the data is considered imbalanced.
Returns:
str: 'binary', 'weighted', or 'macro'
"""
counts = None
if hasattr(y_true, "value_counts") and hasattr(y_true, "values"):
counts = y_true.value_counts().values
elif hasattr(y_true, "value_counts") and hasattr(y_true, "to_numpy"):
vc = y_true.value_counts()
if "count" in vc.columns:
counts = vc["count"].to_numpy()
else:
counts = vc[:, 1].to_numpy()
elif isinstance(y_true, np.ndarray):
(_, counts) = np.unique(y_true, return_counts=True)
else:
(_, counts) = np.unique(np.array(y_true), return_counts=True)
if counts is None or len(counts) == 0:
raise ValueError("Input y_true appears to be empty.")
n_classes = len(counts)
if n_classes <= 2:
return "binary"
min_c = np.min(counts)
max_c = np.max(counts)
ratio = min_c / max_c
if ratio < balance_threshold:
return "weighted"
else:
return "macro"
def _ensure_feature_names(X, feature_names=None):
if isinstance(X, pd.DataFrame):
return list(X.columns)
if isinstance(X, np.ndarray):
if feature_names is None:
feature_names = [f"feature_{i}" for i in range(X.shape[1])]
return feature_names
raise TypeError("X must be a pandas DataFrame or numpy ndarray")
def _perform_cross_validation(
model,
X,
y,
cv_folds,
average_strategy,
shuffle,
random_state,
n_jobs,
verbose,
pos_label=None,
) -> dict[str, Any]:
"""Perform cross-validation on the model."""
verbose and logger.info(f"Performing {cv_folds}-fold cross-validation...")
cv = StratifiedKFold(n_splits=cv_folds, shuffle=shuffle, random_state=random_state)
if average_strategy == "binary":
scoring = {
"accuracy": "accuracy",
"precision": make_scorer(
precision_score, average="binary", pos_label=pos_label
),
"recall": make_scorer(recall_score, average="binary", pos_label=pos_label),
"f1": make_scorer(f1_score, average="binary", pos_label=pos_label),
"roc_auc": "roc_auc",
}
else:
average_strategy_suffix = f"_{average_strategy}"
roc_average_strategy_suffix = (
f"_{average_strategy}" if average_strategy == "weighted" else ""
)
roc_auc_ovr_suffix = "_ovr"
scoring = (
f"f1{average_strategy_suffix}",
"accuracy",
f"precision{average_strategy_suffix}",
f"recall{average_strategy_suffix}",
f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}",
)
cv_results = cross_validate(
model, X, y, cv=cv, scoring=scoring, return_train_score=True, n_jobs=n_jobs
)
def get_score_mean_std(metric_key):
if metric_key in cv_results:
return (cv_results[metric_key].mean(), cv_results[metric_key].std())
return (0.0, 0.0)
if average_strategy == "binary":
(accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
(precision_mean, precision_std) = get_score_mean_std("test_precision")
(recall_mean, recall_std) = get_score_mean_std("test_recall")
(f1_mean, f1_std) = get_score_mean_std("test_f1")
(roc_auc_mean, roc_auc_std) = get_score_mean_std("test_roc_auc")
else:
(accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
(precision_mean, precision_std) = get_score_mean_std(
f"test_precision{average_strategy_suffix}"
)
(recall_mean, recall_std) = get_score_mean_std(
f"test_recall{average_strategy_suffix}"
)
(f1_mean, f1_std) = get_score_mean_std(f"test_f1{average_strategy_suffix}")
roc_key = f"test_roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"
(roc_auc_mean, roc_auc_std) = get_score_mean_std(roc_key)
verbose and logger.info(
f"CV Accuracy : {accuracy_mean:.4f} (+/- {accuracy_std:.4f})"
)
verbose and logger.info(
f"CV Precision : {precision_mean:.4f} (+/- {precision_std:.4f})"
)
verbose and logger.info(f"CV Recall : {recall_mean:.4f} (+/- {recall_std:.4f})")
verbose and logger.info(f"CV F1 Score : {f1_mean:.4f} (+/- {f1_std:.4f})")
verbose and logger.info(
f"CV ROC-AUC : {roc_auc_mean:.4f} (+/- {roc_auc_std:.4f})"
)
CV_metrics = pd.DataFrame(
{
"Metric": ["Accuracy", "Precision", "Recall", "F1-Score", "ROC AUC"],
"Mean": [accuracy_mean, precision_mean, recall_mean, f1_mean, roc_auc_mean],
"Std": [accuracy_std, precision_std, recall_std, f1_std, roc_auc_std],
}
)
return CV_metrics
def _compute_score(model, X, y, metric, average_strategy, pos_label=None):
score_params = {"average": average_strategy, "zero_division": 0}
y_pred = model.predict(X)
if average_strategy != "binary":
y_score = model.predict_proba(X)
else:
score_params["pos_label"] = pos_label
if pos_label is not None:
classes = list(model.classes_)
try:
pos_idx = classes.index(pos_label)
except ValueError:
pos_idx = 1 if len(classes) > 1 else 0
y_score = model.predict_proba(X)[:, pos_idx]
else:
y_score = model.predict_proba(X)[:, 1]
if metric == "Accuracy":
score = accuracy_score(y, y_pred)
elif metric == "Precision":
score = precision_score(y, y_pred, **score_params)
elif metric == "Recall":
score = recall_score(y, y_pred, **score_params)
elif metric == "F1 Score":
score = f1_score(y, y_pred, **score_params)
elif metric == "ROC-AUC":
if average_strategy != "binary":
score = roc_auc_score(
y, y_score, multi_class="ovr", average=average_strategy
)
else:
score = roc_auc_score(y, y_score)
return score
def _get_cv_scoring_object(metric, average_strategy, pos_label=None):
"""
Returns a scoring object (string or callable) suitable for cross_validate.
Used during HPO.
"""
if average_strategy == "binary":
if metric == "F1 Score":
return make_scorer(f1_score, average="binary", pos_label=pos_label)
elif metric == "Accuracy":
return "accuracy"
elif metric == "Precision":
return make_scorer(precision_score, average="binary", pos_label=pos_label)
elif metric == "Recall":
return make_scorer(recall_score, average="binary", pos_label=pos_label)
elif metric == "ROC-AUC":
return "roc_auc"
else:
average_strategy_suffix = f"_{average_strategy}"
roc_auc_ovr_suffix = "_ovr"
roc_average_strategy_suffix = (
f"_{average_strategy}" if average_strategy == "weighted" else ""
)
if metric == "F1 Score":
return f"f1{average_strategy_suffix}"
elif metric == "Accuracy":
return "accuracy"
elif metric == "Precision":
return f"precision{average_strategy_suffix}"
elif metric == "Recall":
return f"recall{average_strategy_suffix}"
elif metric == "ROC-AUC":
return f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"
def _hyperparameters_optimization(
X,
y,
constant_hyperparameters,
optimization_metric,
metric_average_strategy,
val_ratio,
shuffle_split,
stratify_split,
use_cross_val,
cv_folds,
n_trials=50,
strategy="maximize",
sampler="Tree-structured Parzen",
standard_scaling=False,
seed=None,
n_jobs=-1,
verbose=False,
pos_label=None,
):
direction = "maximize" if strategy.lower() == "maximize" else "minimize"
sampler_map = {
"Tree-structured Parzen": TPESampler(seed=seed),
"Gaussian Process": GPSampler(seed=seed),
"CMA-ES": CmaEsSampler(seed=seed),
"Random Search": RandomSampler(seed=seed),
"Random Sobol Search": QMCSampler(seed=seed),
}
if sampler in sampler_map:
chosen_sampler = sampler_map[sampler]
else:
logger.warning(f"Sampler '{sampler}' not recognized → falling back to TPE")
chosen_sampler = TPESampler(seed=seed)
chosen_pruner = HyperbandPruner()
if use_cross_val:
cv = StratifiedKFold(
n_splits=cv_folds, shuffle=shuffle_split, random_state=seed
)
cv_score_obj = _get_cv_scoring_object(
optimization_metric, metric_average_strategy, pos_label
)
else:
(X_train, X_val, y_train, y_val) = train_test_split(
X,
y,
test_size=val_ratio,
random_state=seed,
stratify=y if stratify_split else None,
shuffle=shuffle_split,
)
def logging_callback(study: Study, trial: FrozenTrial):
"""Callback function to log trial progress"""
verbose and logger.info(
f"Trial {trial.number} finished with value: {trial.value} and parameters: {trial.params}"
)
try:
verbose and logger.info(f"Best value so far: {study.best_value}")
verbose and logger.info(f"Best parameters so far: {study.best_params}")
except ValueError:
verbose and logger.info(f"No successful trials completed yet")
verbose and logger.info(f"" + "-" * 50)
def objective(trial):
try:
n_neighbors = trial.suggest_int("n_neighbors", 1, 50)
weights = trial.suggest_categorical("weights", ["uniform", "distance"])
metric = trial.suggest_categorical("metric", ["minkowski", "cosine"])
p = trial.suggest_int("p", 1, 2)
model = KNeighborsClassifier(
n_neighbors=n_neighbors,
weights=weights,
metric=metric,
algorithm="auto",
p=p,
n_jobs=n_jobs,
)
if standard_scaling:
model = Pipeline([("scaler", StandardScaler()), ("knn", model)])
if use_cross_val:
scores = cross_validate(
model, X, y, cv=cv, n_jobs=n_jobs, scoring=cv_score_obj
)
return scores["test_score"].mean()
else:
model.fit(X_train, y_train)
score = _compute_score(
model,
X_val,
y_val,
optimization_metric,
metric_average_strategy,
pos_label,
)
return score
except Exception as e:
verbose and logger.error(
f"Trial {trial.number} failed with error: {str(e)}"
)
raise
study = create_study(
direction=direction, sampler=chosen_sampler, pruner=chosen_pruner
)
study.optimize(
objective,
n_trials=n_trials,
catch=(Exception,),
n_jobs=n_jobs,
callbacks=[logging_callback],
)
verbose and logger.info(f"Optimization completed!")
verbose and logger.info(
f" Best Number of Neighbors : {study.best_params['n_neighbors']}"
)
verbose and logger.info(
f" Best Weighting : {study.best_params['weights']}"
)
verbose and logger.info(
f" Best Metric : {study.best_params['metric']}"
)
verbose and logger.info(
f" Best Minkowski Exponent : {study.best_params['p']}"
)
verbose and logger.info(
f" Best {optimization_metric:<23}: {study.best_value:.4f}"
)
verbose and logger.info(f" Sampler used : {sampler}")
verbose and logger.info(f" Direction : {direction}")
if use_cross_val:
verbose and logger.info(f" Cross-validation : {cv_folds}-fold")
else:
verbose and logger.info(
f" Validation : single train/val split"
)
trials = study.trials_dataframe()
trials["best_value"] = trials["value"].cummax()
cols = list(trials.columns)
value_idx = cols.index("value")
cols = [c for c in cols if c != "best_value"]
new_order = cols[: value_idx + 1] + ["best_value"] + cols[value_idx + 1 :]
trials = trials[new_order]
return (study.best_params, trials)
def _combine_test_data(
X_test, y_true, y_pred, y_proba, class_names, features_names=None
):
"""
Combine X_test, y_true, y_pred, and y_proba into a single DataFrame.
Parameters:
-----------
X_test : pandas/polars DataFrame, numpy array, or scipy sparse matrix
Test features
y_true : pandas/polars Series, numpy array, or list
True labels
y_pred : pandas/polars Series, numpy array, or list
Predicted labels
y_proba : pandas/polars Series/DataFrame, numpy array (1D or 2D), or list
Prediction probabilities - can be:
- 1D array for binary classification (probability of positive class)
- 2D array for multiclass (probabilities for each class)
class_names : list or array-like
Names of the classes in order.
For binary classification with 1D y_proba, only the positive class name is needed.
Returns:
--------
pandas.DataFrame
Combined DataFrame with features, probabilities, y_true, and y_pred
"""
if sparse.issparse(X_test):
X_df = pd.DataFrame(X_test.toarray())
elif isinstance(X_test, np.ndarray):
X_df = pd.DataFrame(X_test)
elif hasattr(X_test, "to_pandas"):
X_df = X_test.to_pandas()
elif isinstance(X_test, pd.DataFrame):
X_df = X_test.copy()
else:
raise TypeError(f"Unsupported type for X_test: {type(X_test)}")
if X_df.columns.tolist() == list(range(len(X_df.columns))):
X_df.columns = (
[f"feature_{i}" for i in range(len(X_df.columns))]
if features_names is None
else features_names
)
if isinstance(y_true, list):
y_true_series = pd.Series(y_true, name="y_true")
elif isinstance(y_true, np.ndarray):
y_true_series = pd.Series(y_true, name="y_true")
elif hasattr(y_true, "to_pandas"):
y_true_series = y_true.to_pandas()
y_true_series.name = "y_true"
elif isinstance(y_true, pd.Series):
y_true_series = y_true.copy()
y_true_series.name = "y_true"
else:
raise TypeError(f"Unsupported type for y_true: {type(y_true)}")
if isinstance(y_pred, list):
y_pred_series = pd.Series(y_pred, name="y_pred")
elif isinstance(y_pred, np.ndarray):
y_pred_series = pd.Series(y_pred, name="y_pred")
elif hasattr(y_pred, "to_pandas"):
y_pred_series = y_pred.to_pandas()
y_pred_series.name = "y_pred"
elif isinstance(y_pred, pd.Series):
y_pred_series = y_pred.copy()
y_pred_series.name = "y_pred"
else:
raise TypeError(f"Unsupported type for y_pred: {type(y_pred)}")
if isinstance(y_proba, list):
y_proba_array = np.array(y_proba)
elif isinstance(y_proba, np.ndarray):
y_proba_array = y_proba
elif hasattr(y_proba, "to_pandas"):
y_proba_pd = y_proba.to_pandas()
if isinstance(y_proba_pd, pd.Series):
y_proba_array = y_proba_pd.values
else:
y_proba_array = y_proba_pd.values
elif isinstance(y_proba, pd.Series):
y_proba_array = y_proba.values
elif isinstance(y_proba, pd.DataFrame):
y_proba_array = y_proba.values
else:
raise TypeError(f"Unsupported type for y_proba: {type(y_proba)}")
def sanitize_class_name(class_name):
"""Convert class name to valid column name by replacing spaces and special chars"""
return str(class_name).replace(" ", "_").replace("-", "_")
if y_proba_array.ndim == 1:
y_proba_df = pd.DataFrame({"proba": y_proba_array})
else:
n_classes = y_proba_array.shape[1]
if len(class_names) == n_classes:
proba_columns = [f"proba_{sanitize_class_name(cls)}" for cls in class_names]
else:
proba_columns = [f"proba_{i}" for i in range(n_classes)]
y_proba_df = pd.DataFrame(y_proba_array, columns=proba_columns)
y_proba_df = y_proba_df.reset_index(drop=True)
X_df = X_df.reset_index(drop=True)
y_true_series = y_true_series.reset_index(drop=True)
y_pred_series = y_pred_series.reset_index(drop=True)
is_false_prediction = pd.Series(
y_true_series != y_pred_series, name="is_false_prediction"
).reset_index(drop=True)
result_df = pd.concat(
[X_df, y_proba_df, y_true_series, y_pred_series, is_false_prediction], axis=1
)
return result_df
def _smart_shap_background(
X: Union[np.ndarray, pd.DataFrame],
model_type: str = "tree",
seed: int = 42,
verbose: bool = True,
) -> Union[np.ndarray, pd.DataFrame, object]:
"""
Intelligently prepares a background dataset for SHAP based on model type.
Strategies:
- Tree: Higher sample cap (1000), uses Random Sampling (preserves data structure).
- Other: Lower sample cap (100), uses K-Means (maximizes info density).
"""
(n_rows, n_features) = X.shape
if model_type == "tree":
max_samples = 1000
use_kmeans = False
else:
max_samples = 100
use_kmeans = True
if n_rows <= max_samples:
verbose and logger.info(
f"✓ Dataset small ({n_rows} <= {max_samples}). Using full data."
)
return X
verbose and logger.info(
f"⚡ Large dataset detected ({n_rows} rows). Optimization Strategy: {('K-Means' if use_kmeans else 'Random Sampling')}"
)
if use_kmeans:
try:
verbose and logger.info(
f" Summarizing to {max_samples} weighted centroids..."
)
return shap.kmeans(X, max_samples)
except Exception as e:
logger.warning(
f" K-Means failed ({str(e)}). Falling back to random sampling."
)
return shap.sample(X, max_samples, random_state=seed)
else:
verbose and logger.info(f" Sampling {max_samples} random rows...")
return shap.sample(X, max_samples, random_state=seed)
def _class_index_df(model):
columns = {"index": pd.Series(dtype="int64"), "class": pd.Series(dtype="object")}
if model is None:
return pd.DataFrame(columns)
classes = getattr(model, "classes_", None)
if classes is None:
return pd.DataFrame(columns)
return pd.DataFrame({"index": range(len(classes)), "class": classes})
def train_cla_knn(
X: DataFrame, y: Union[DataSeries, NDArray, List], options=None
) -> Tuple[
Any,
DataFrame,
Any,
Any,
Union[DataFrame, Dict],
DataFrame,
DataFrame,
DataFrame,
Dict,
]:
options = options or {}
n_neighbors = options.get("n_neighbors", 5)
distance_weighting = options.get("distance_weighting", False)
neighbor_algorithm = options.get("neighbor_algorithm", "Auto")
metric = options.get("metric", "Minkowski").lower()
p = options.get("p", 2)
weights = "distance" if distance_weighting else "uniform"
if neighbor_algorithm == "Auto":
neighbor_algorithm = "auto"
elif neighbor_algorithm == "Ball-tree":
neighbor_algorithm = "ball_tree"
elif neighbor_algorithm == "KD-tree":
neighbor_algorithm = "kd_tree"
elif neighbor_algorithm == "Brute Force":
neighbor_algorithm = "brute"
else:
neighbor_algorithm = "auto"
    if neighbor_algorithm == "kd_tree" and metric != "minkowski":
        metric = "minkowski"
        # KD-tree does not support metrics such as cosine; force Minkowski and warn.
        options.get("verbose", True) and logger.warning(
            "When using the KD-tree algorithm, the metric is set by default to Minkowski."
        )
standard_scaling = options.get("standard_scaling", True)
auto_split = options.get("auto_split", True)
test_val_size = options.get("test_val_size", 15) / 100
shuffle_split = options.get("shuffle_split", True)
stratify_split = options.get("stratify_split", True)
retrain_on_full = options.get("retrain_on_full", False)
custom_average_strategy = options.get("custom_average_strategy", "auto")
use_cross_validation = options.get("use_cross_validation", False)
cv_folds = options.get("cv_folds", 5)
use_hpo = options.get("use_hyperparameter_optimization", False)
optimization_metric = options.get("optimization_metric", "F1 Score")
optimization_method = options.get("optimization_method", "Tree-structured Parzen")
optimization_iterations = options.get("optimization_iterations", 50)
pos_label_option = options.get("pos_label", "").strip()
if pos_label_option == "":
pos_label_option = None
return_shap_explainer = options.get("return_shap_explainer", False)
use_shap_sampler = options.get("use_shap_sampler", False)
metrics_as = options.get("metrics_as", "Dataframe")
n_jobs_str = options.get("n_jobs", "1")
random_state = options.get("random_state", 42)
activate_caching = options.get("activate_caching", False)
verbose = options.get("verbose", True)
n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
skip_computation = False
Scaler = None
Model = None
Metrics = pd.DataFrame()
CV_Metrics = pd.DataFrame()
SHAP = None
HPO_Trials = pd.DataFrame()
HPO_Best = None
accuracy = None
precision = None
recall = None
f1 = None
roc_auc = None
(n_samples, n_features, sparsity, is_sparse) = _get_shape_and_sparsity(X)
shap_feature_names = _ensure_feature_names(X)
if standard_scaling:
verbose and logger.info("Standard scaling is activated")
if activate_caching:
verbose and logger.info(f"Caching is activate")
data_hasher = _UniversalDatasetHasher(n_samples, verbose=verbose)
X_hash = data_hasher.hash_data(X).hash
y_hash = data_hasher.hash_data(y).hash
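        # Build a deterministic cache key from library versions, input data hashes,
        # and every option that affects the trained model or its reported outputs.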
all_hash_base_text = f"HASH BASE TEXTPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{('SHAP Version ' + shap.__version__ if return_shap_explainer else 'NO SHAP Version')}{X_hash}{y_hash}{n_neighbors}{neighbor_algorithm}{metric}{p}{weights}{('Use HPO' if use_hpo else 'No HPO')}{(optimization_metric if use_hpo else 'No HPO Metric')}{(optimization_method if use_hpo else 'No HPO Method')}{(optimization_iterations if use_hpo else 'No HPO Iter')}{(cv_folds if use_cross_validation else 'No CV')}{standard_scaling}{('Auto Split' if auto_split else test_val_size)}{shuffle_split}{stratify_split}{return_shap_explainer}{use_shap_sampler}{random_state}{(pos_label_option if pos_label_option else 'default_pos')}{custom_average_strategy}"
all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
verbose and logger.info(f"Hash was computed: {all_hash}")
temp_folder = Path(tempfile.gettempdir())
cache_folder = temp_folder / "coded-flows-cache"
cache_folder.mkdir(parents=True, exist_ok=True)
model_path = cache_folder / f"{all_hash}.model"
metrics_dict_path = cache_folder / f"metrics_{all_hash}.json"
metrics_df_path = cache_folder / f"metrics_{all_hash}.parquet"
cv_metrics_path = cache_folder / f"cv_metrics_{all_hash}.parquet"
hpo_trials_path = cache_folder / f"hpo_trials_{all_hash}.parquet"
hpo_best_params_path = cache_folder / f"hpo_best_params_{all_hash}.json"
prediction_set_path = cache_folder / f"prediction_set_{all_hash}.parquet"
shap_path = cache_folder / f"{all_hash}.shap"
scaler_path = cache_folder / f"{all_hash}.scaler"
skip_computation = model_path.is_file()
if not skip_computation:
try:
_validate_numerical_data(X)
except Exception as e:
verbose and logger.error(
f"Only numerical or boolean types are allowed for 'X' input!"
)
raise
if hasattr(y, "nunique"):
n_classes = y.nunique()
elif hasattr(y, "n_unique"):
n_classes = y.n_unique()
else:
n_classes = len(np.unique(y))
is_multiclass = n_classes > 2
features_names = X.columns if hasattr(X, "columns") else None
fixed_test_split = None if auto_split else test_val_size
(X_train, X_test, y_train, y_test, val_ratio) = _smart_split(
n_samples,
X,
y,
random_state=random_state,
shuffle=shuffle_split,
stratify=y if stratify_split else None,
fixed_test_split=fixed_test_split,
verbose=verbose,
)
if custom_average_strategy == "auto":
metric_average_strategy = _get_best_metric_average_strategy(y_test)
else:
metric_average_strategy = custom_average_strategy
effective_pos_label = None
if metric_average_strategy == "binary":
unique_classes = np.unique(y_train)
if pos_label_option is not None:
if pos_label_option in unique_classes:
effective_pos_label = pos_label_option
else:
try:
as_int = int(pos_label_option)
if as_int in unique_classes:
effective_pos_label = as_int
except ValueError:
pass
elif 1 in unique_classes:
effective_pos_label = 1
elif "1" in unique_classes:
effective_pos_label = "1"
if effective_pos_label is not None:
verbose and logger.info(f"Using positive label: {effective_pos_label}")
elif effective_pos_label is None and metric_average_strategy == "binary":
error_message = 'The target appears to be binary, but no positive label was provided and no "1" class exists in the label set.'
verbose and logger.error(error_message)
raise ValueError(error_message)
if use_hpo:
verbose and logger.info(f"Performing Hyperparameters Optimization")
constant_hyperparameters = {}
(HPO_Best, HPO_Trials) = _hyperparameters_optimization(
X_train,
y_train,
constant_hyperparameters,
optimization_metric,
metric_average_strategy,
val_ratio,
shuffle_split,
stratify_split,
use_cross_validation,
cv_folds,
optimization_iterations,
"maximize",
optimization_method,
standard_scaling,
random_state,
n_jobs_int,
verbose=verbose,
pos_label=effective_pos_label,
)
HPO_Trials = _normalize_hpo_df(HPO_Trials)
n_neighbors = HPO_Best["n_neighbors"]
neighbor_algorithm = "auto"
metric = HPO_Best["metric"]
p = HPO_Best["p"]
weights = HPO_Best["weights"]
if standard_scaling:
Scaler = StandardScaler().set_output(transform="pandas")
X_train = Scaler.fit_transform(X_train)
Model = KNeighborsClassifier(
n_neighbors=n_neighbors,
algorithm=neighbor_algorithm,
metric=metric,
p=p,
weights=weights,
n_jobs=n_jobs_int,
)
if use_cross_validation and (not use_hpo):
verbose and logger.info(
f"Using Cross-Validation to measure performance metrics"
)
CV_Metrics = _perform_cross_validation(
Model,
X_train,
y_train,
cv_folds,
metric_average_strategy,
shuffle_split,
random_state,
n_jobs_int,
verbose,
pos_label=effective_pos_label,
)
Model.fit(X_train, y_train)
y_pred = Model.predict(Scaler.transform(X_test) if standard_scaling else X_test)
if is_multiclass:
y_score = Model.predict_proba(
Scaler.transform(X_test) if standard_scaling else X_test
)
elif effective_pos_label is not None:
try:
pos_idx = list(Model.classes_).index(effective_pos_label)
y_score = Model.predict_proba(
Scaler.transform(X_test) if standard_scaling else X_test
)[:, pos_idx]
except ValueError:
y_score = Model.predict_proba(
Scaler.transform(X_test) if standard_scaling else X_test
)[:, 1]
else:
y_score = Model.predict_proba(
Scaler.transform(X_test) if standard_scaling else X_test
)[:, 1]
score_params = {"average": metric_average_strategy, "zero_division": 0}
if effective_pos_label:
score_params["pos_label"] = effective_pos_label
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, **score_params)
recall = recall_score(y_test, y_pred, **score_params)
f1 = f1_score(y_test, y_pred, **score_params)
if is_multiclass:
roc_auc = roc_auc_score(
y_test, y_score, multi_class="ovr", average=metric_average_strategy
)
else:
roc_auc = roc_auc_score(y_test, y_score)
if metrics_as == "Dataframe":
Metrics = pd.DataFrame(
{
"Metric": [
"Accuracy",
"Precision",
"Recall",
"F1-Score",
"ROC AUC",
],
"Value": [accuracy, precision, recall, f1, roc_auc],
}
)
else:
Metrics = {
"Accuracy": accuracy,
"Precision": precision,
"Recall": recall,
"F1-Score": f1,
"ROC AUC": roc_auc,
}
verbose and logger.info(f"Accuracy : {accuracy:.4f}")
verbose and logger.info(f"Precision : {precision:.4f}")
verbose and logger.info(f"Recall : {recall:.4f}")
verbose and logger.info(f"F1-Score : {f1:.4f}")
verbose and logger.info(f"ROC-AUC : {roc_auc:.4f}")
Prediction_Set = _combine_test_data(
X_test, y_test, y_pred, y_score, Model.classes_, features_names
)
verbose and logger.info(f"Prediction Set created")
if retrain_on_full:
verbose and logger.info(
"Retraining model on full dataset for production deployment"
)
if standard_scaling:
Scaler = StandardScaler().set_output(transform="pandas")
X = Scaler.fit_transform(X)
Model.fit(X, y)
verbose and logger.info(
"Model successfully retrained on full dataset. Reported metrics remain from original held-out test set."
)
if return_shap_explainer:
SHAP = shap.KernelExplainer(
Model.predict_proba,
(
_smart_shap_background(
X if retrain_on_full else X_train,
model_type="other",
seed=random_state,
verbose=verbose,
)
if use_shap_sampler
else X if retrain_on_full else X_train
),
feature_names=shap_feature_names,
link="logit",
)
verbose and logger.info(f"SHAP explainer generated")
if activate_caching:
verbose and logger.info(f"Caching output elements")
joblib.dump(Model, model_path)
if isinstance(Metrics, dict):
with metrics_dict_path.open("w", encoding="utf-8") as f:
json.dump(Metrics, f, ensure_ascii=False, indent=4)
else:
Metrics.to_parquet(metrics_df_path)
if use_cross_validation and (not use_hpo):
CV_Metrics.to_parquet(cv_metrics_path)
if use_hpo:
HPO_Trials.to_parquet(hpo_trials_path)
with hpo_best_params_path.open("w", encoding="utf-8") as f:
json.dump(HPO_Best, f, ensure_ascii=False, indent=4)
Prediction_Set.to_parquet(prediction_set_path)
if return_shap_explainer:
with shap_path.open("wb") as f:
joblib.dump(SHAP, f)
joblib.dump(Scaler, scaler_path)
verbose and logger.info(f"Caching done")
else:
verbose and logger.info(f"Skipping computations and loading cached elements")
Model = joblib.load(model_path)
verbose and logger.info(f"Model loaded")
if metrics_dict_path.is_file():
with metrics_dict_path.open("r", encoding="utf-8") as f:
Metrics = json.load(f)
else:
Metrics = pd.read_parquet(metrics_df_path)
verbose and logger.info(f"Metrics loaded")
if use_cross_validation and (not use_hpo):
CV_Metrics = pd.read_parquet(cv_metrics_path)
verbose and logger.info(f"Cross Validation metrics loaded")
if use_hpo:
HPO_Trials = pd.read_parquet(hpo_trials_path)
with hpo_best_params_path.open("r", encoding="utf-8") as f:
HPO_Best = json.load(f)
verbose and logger.info(
f"Hyperparameters Optimization trials and best params loaded"
)
Prediction_Set = pd.read_parquet(prediction_set_path)
verbose and logger.info(f"Prediction Set loaded")
if return_shap_explainer:
with shap_path.open("rb") as f:
SHAP = joblib.load(f)
verbose and logger.info(f"SHAP Explainer loaded")
Scaler = joblib.load(scaler_path)
verbose and logger.info(f"Standard Scaler loaded")
Model_Classes = _class_index_df(Model)
return (
Model,
Model_Classes,
SHAP,
Scaler,
Metrics,
CV_Metrics,
Prediction_Set,
HPO_Trials,
HPO_Best,
)
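As a hedged illustration of the outputs, the snippet below calls the entry point defined above directly. Normally the Coded Flows runtime wires the brick's inputs, options, and outputs; the dataset and option values here are placeholders.

```python
# Hypothetical direct call to train_cla_knn, to show how the nine outputs unpack.
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer(as_frame=True)

(model, model_classes, shap_explainer, scaler, metrics,
 cv_metrics, prediction_set, hpo_trials, hpo_best) = train_cla_knn(
    data.data, data.target, options={"n_neighbors": 7, "verbose": False}
)

print(metrics)        # Accuracy / Precision / Recall / F1-Score / ROC AUC on the test split
print(model_classes)  # mapping from internal class index to class label
print(prediction_set[prediction_set["is_false_prediction"]].head())
```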
Brick Info
- shap>=0.47.0
- scikit-learn
- pandas
- numpy
- torch
- numba>=0.56.0
- cmaes
- optuna
- scipy
- polars
- xxhash