Cla. XGBoost

Train an XGBoost classification model.

Cla. XGBoost

Processing

This brick trains an XGBoost Classification model, one of the most powerful and popular machine learning algorithms. It works by building a series of "decision trees," where each new tree corrects the errors of the previous ones.

Beyond standard training, this brick includes advanced features:

  • Auto-Balancing: Automatically detects if one class (e.g., "Fraud") is much rarer than another and adjusts weights so the model doesn't ignore the minority class.
  • Hyperparameter Optimization (HPO): Can automatically run multiple experiments to find the best possible settings for your data.
  • Explainability: Can generate SHAP values to explain why the model made specific predictions.
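
For intuition, the sketch below shows the core idea behind the auto-balancing step for a binary target: the majority/minority ratio is passed to XGBoost as scale_pos_weight when the imbalance is meaningful. This is a simplified view; the brick applies it automatically and switches to per-sample weights for multiclass problems.

from collections import Counter

import numpy as np
from xgboost import XGBClassifier

# Hypothetical imbalanced binary target: 950 "legit" rows vs 50 "fraud" rows
y = np.array([0] * 950 + [1] * 50)

counts = Counter(y)
imbalance = max(counts.values()) / min(counts.values())  # 19.0 here

# Only re-weight when the imbalance is non-trivial (the brick uses a 1.5x threshold)
model = XGBClassifier(scale_pos_weight=imbalance if imbalance >= 1.5 else 1.0)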

Important Notice for Apple Silicon Users

If you're using a Mac with Apple Silicon (M1, M2, M3, M4, or later), you must install the OpenMP library before using this brick, or XGBoost will crash or fail to load.

Required Setup:

Open your Terminal and run:

brew install libomp

If you already have Homebrew installed but still encounter errors, try:

brew reinstall libomp

Common Error Messages (without libomp):

  • OMP: Error #15: ...
  • dlopen(...libomp.dylib...): image not found
  • Symbol not found: _omp_init_lock

After installation: Restart your terminal, IDE, or kernel before running the brick again.
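
A quick way to confirm the fix is to import XGBoost in a fresh Python session; if libomp is still missing, the import itself raises one of the errors above.

import xgboost
print(xgboost.__version__)  # importing without an OMP/dlopen error means libomp was found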

Inputs

X
The dataset containing the features (independent variables) you want to use for prediction. These are the characteristics (like Age, Price, Dimensions) the model analyzes.
y
The column containing the target labels (dependent variable) you want to predict. This contains the categories (like "Spam/Not Spam" or "High/Medium/Low").

Inputs Types

Input   Types
X       DataFrame
y       DataSeries, NDArray, List

You can check the list of supported types here: Available Type Hints.
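
For example, a typical X/y pair (hypothetical column names and labels) could be built like this:

import pandas as pd

X = pd.DataFrame(
    {
        "Age": [23, 45, 31, 52],
        "Price": [19.9, 54.0, 32.5, 80.1],
        "Dimensions": [1.2, 3.4, 2.2, 4.1],
    }
)
y = pd.Series(["Low", "High", "Medium", "High"], name="Priority")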

Outputs

Model
The trained XGBoost classifier object. You can pass this to other bricks to make predictions on new, unseen data.
Model Classes
A reference table that maps the internal numerical IDs used by the model back to your original class names (e.g., 0 = "Cat", 1 = "Dog").
SHAP
The SHAP explainer object. This is used by visualization bricks to show which features had the biggest impact on the model's decisions.
Label Encoder
The encoder used to transform your text labels into numbers. Useful if you need to map encoded predictions back to the original class names later.
Metrics
A summary of how well the model performed on the test set. Contains scores like Accuracy, Precision, Recall, F1-Score, and ROC AUC.
CV Metrics
If Cross-Validation is enabled, this provides performance statistics averaged across multiple test "folds," giving a more robust estimate of model reliability.
Features Importance
A list ranking your input columns by how useful they were to the model. Higher importance means the feature had a greater impact on the prediction.
Prediction Set
A copy of the test dataset including the model's predictions, the actual true values, and the calculated probabilities (confidence).
HPO Trials
If Hyperparameter Optimization was used, this log details every experiment run, showing which settings were tried and the resulting score.
HPO Best
A dictionary containing the specific combination of settings (hyperparameters) that yielded the best results during optimization.
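
As an illustration, the Model and Label Encoder outputs can be combined to score new data and translate the numeric predictions back to the original class names (variable and column names here are illustrative):

import pandas as pd

# `model` and `label_encoder` are the brick's Model and Label Encoder outputs
X_new = pd.DataFrame({"Age": [29], "Price": [42.0], "Dimensions": [2.8]})

encoded = model.predict(X_new)                     # numeric class IDs
labels = label_encoder.inverse_transform(encoded)  # original class names
probabilities = model.predict_proba(X_new)         # per-class confidence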

The Prediction Set output contains the following specific data fields:

  • feature_...: The original feature columns from your input X.
  • proba: (For binary) The probability score of the positive class.
  • proba_{class_name}: (For multiclass) The probability score for each specific class.
  • y_true: The actual label from the test data.
  • y_pred: The label predicted by the model.
  • is_false_prediction: Boolean (True/False) indicating if the model made a mistake on this row.
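
For example, the misclassified rows can be isolated directly from these columns (here prediction_set stands for the Prediction Set output):

errors = prediction_set[prediction_set["is_false_prediction"]]
print(errors[["y_true", "y_pred"]].head())

# For a binary problem, sort mistakes by the positive-class probability
if "proba" in prediction_set.columns:
    print(errors.sort_values("proba").head())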

The Metrics output contains the following specific data fields:

  • Accuracy: The percentage of correct predictions.
  • Precision: The accuracy of positive predictions.
  • Recall: The percentage of actual positives correctly identified.
  • F1-Score: The harmonic mean of precision and recall.
  • ROC AUC: The area under the ROC curve (performance across all classification thresholds).
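
These scores follow the standard scikit-learn definitions and can be reproduced from a Prediction Set if needed. The averaging used for multiclass problems depends on the Average Strategy option; a "weighted" average is assumed in this sketch.

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = prediction_set["y_true"]
y_pred = prediction_set["y_pred"]

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)  # harmonic mean of precision and recall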

Outputs Types

Output                Types
Model                 Any
Model Classes         DataFrame
SHAP                  Any
Label Encoder         Any
Metrics               DataFrame, Dict
CV Metrics            DataFrame
Features Importance   DataFrame
Prediction Set        DataFrame
HPO Trials            DataFrame
HPO Best              Dict

You can check the list of supported types here: Available Type Hints.

Options

The Cla. XGBoost brick exposes the following configurable options (a usage sketch follows the list):

Max Number of Trees
Controls the maximum number of decision trees to build. More trees can learn more complex patterns but take longer to train.
Enable Early Stopping
If enabled, training will stop automatically once the model's performance stops improving. This prevents "overfitting" (memorizing the data) and saves time.
Early Stopping Rounds
The number of rounds to wait for an improvement before stopping. For example, if set to 10, the model stops if the score hasn't improved in the last 10 trees.
Use Dropout
Enables DART (Dropout Additive Regression Trees). This randomly drops trees during training, which can help prevent overfitting in very complex datasets.
Max Depth
How deep each tree can grow. Deeper trees capture more complex interactions but are more likely to overfit.
  • Low values (1-3): Simple models, safe but potentially less accurate.
  • High values (10+): Complex models, high accuracy but risk of overfitting.
Learning Rate (Eta)
Controls how much the model learns from each tree.
  • Lower values: Slower training, but often results in a more generalized/better model.
  • Higher values: Faster training, but might miss the optimal solution.
L2 Regularization (Lambda)
A penalty term that discourages large weights. Increasing this makes the model more conservative.
L1 Regularization (Alpha)
Another penalty term. Increasing this encourages the model to ignore irrelevant features (setting their weights to zero).
Gamma (Min Split Loss)
The minimum loss reduction required to make a further partition on a leaf node of the tree. Higher values make the algorithm more conservative.
Min Child Weight
Controls the minimum amount of data required to create a new node in the tree. Higher values prevent the model from learning extremely specific rules based on just a few data points.
Subsample Ratio
The fraction of the training data randomly sampled to train each tree. Setting this below 1.0 (e.g., 0.8) helps prevent overfitting.
Colsample by Tree
The fraction of columns (features) randomly sampled for each tree. Similar to "Subsample Ratio," this adds randomness to make the model more robust.
Auto Split Data
If enabled, the brick automatically splits your X and y data into Training and Test sets based on the dataset size.
Shuffle Split
Whether to shuffle the data randomly before splitting. Recommended to ensure the test set represents the whole dataset.
Stratify Split
Ensures that the Training and Test sets have the same proportion of class labels (e.g., if 10% of data is "Fraud", the test set will also be 10% "Fraud"). Highly recommended for classification.
Test/Validation Set %
The percentage of data to hold back for testing (only used if "Auto Split Data" is disabled).
Retrain On Full Data
If enabled, after testing and calculating metrics, the model is re-trained on 100% of the data. Use this when you are ready to deploy the model to production.
Average Strategy
How to calculate metrics (like Precision/Recall) for multiclass problems.
  • auto: Automatically selects based on class balance.
  • binary: Only for two classes (e.g., Yes/No).
  • micro: Calculate metrics globally by counting total true positives, false negatives and false positives.
  • macro: Calculate metrics for each label, and find their unweighted mean.
  • weighted: Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label).
Enable Cross-Validation
Instead of one train/test split, the model is trained and tested multiple times on different "folds" of data. Provides a more reliable accuracy estimate but takes longer.
Number of CV Folds
The number of groups to split the data into for Cross-Validation (e.g., 5 means 5 training runs).
Hyperparameter Optim.
If enabled, the brick will ignore the manual settings above and run an automated search to find the best configuration.
Optimization Metric
The specific score the optimizer should try to maximize.
  • F1 Score: Balances precision and recall (good for uneven classes).
  • Accuracy: Overall correctness.
  • Precision: Minimizing false positives.
  • Recall: Minimizing false negatives.
  • ROC-AUC: Ability to distinguish between classes.
Optimization Method
The algorithm used to search for the best parameters.
  • Tree-structured Parzen: A Bayesian optimization method that models good vs bad parameter regions using probability distributions and prioritizes sampling where success is statistically more likely.
  • Gaussian Process: Uses a probabilistic regression model (Gaussian Process) to estimate performance uncertainty and selects new trials using acquisition functions.
  • CMA-ES: An evolutionary strategy that adapts the covariance matrix of a multivariate normal distribution to efficiently search complex, non-linear, non-convex spaces.
  • Random Sobol Search: Uses low-discrepancy quasi-random sequences to ensure uniform coverage of the parameter space, avoiding clustering and gaps.
  • Random Search: Uniform random sampling of parameter configurations without learning or feedback between iterations.
Optimization Iterations
How many different combinations of settings to try. More iterations take longer but may find a better model.
Positive Label (Binary Only)
If you have a binary problem (e.g., "Yes"/"No"), specify which one is the "Positive" class (e.g., "Yes"). Required for accurate precision/recall calculations.
Metrics as
Choose the output format for the Metrics variable.
SHAP Explainer
If enabled, generates the SHAP output object for explainability.
SHAP Sampler
Uses a subset of data to calculate SHAP backgrounds. Faster for large datasets.
SHAP Feature Perturbation
Technical method for calculating SHAP values. "Interventional" is generally more accurate but slower; "Tree Path Dependent" is faster.
Number of Jobs
The number of CPU cores to use. "All" uses the full power of the machine.
Random State
A seed number. Keeping this constant ensures that if you run the brick again with the same data, you get the exact same result.
Brick Caching
If enabled, saves the results to a temporary cache. If you re-run the flow without changing inputs, it loads the result instantly instead of re-training.
Verbose Logging
Prints detailed progress updates to the logs during training.
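
Putting several options together, a direct call to the brick's training function might look like the sketch below. The dictionary keys mirror the defaults read by the implementation that follows; how options are set in your flow environment may differ, and the outputs are assumed to be returned in the order of the Outputs table above.

options = {
    "n_estimators": 300,                      # Max Number of Trees
    "early_stopping": True,                   # Enable Early Stopping
    "early_stopping_rounds": 10,
    "max_depth": 6,
    "learning_rate": 0.1,
    "stratify_split": True,
    "use_cross_validation": False,
    "use_hyperparameter_optimization": False,
    "return_shap_explainer": True,            # SHAP Explainer
    "random_state": 42,
}

(model, model_classes, shap_explainer, label_encoder, metrics, cv_metrics,
 features_importance, prediction_set, hpo_trials, hpo_best) = train_cla_xgboost(
    X, y, options=options
)
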
import logging
import warnings
import shap
import json
import xxhash
import hashlib
import tempfile
import sklearn
import scipy
import joblib
import numpy as np
import pandas as pd
import polars as pl
from xgboost import XGBClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from pathlib import Path
from scipy import sparse
from collections import Counter
from optuna.samplers import (
    TPESampler,
    RandomSampler,
    GPSampler,
    CmaEsSampler,
    QMCSampler,
)
import optuna
from optuna import Study
from optuna.trial import FrozenTrial
from optuna.pruners import HyperbandPruner
from optuna import create_study
from sklearn.model_selection import train_test_split, cross_validate, StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    make_scorer,
)
from dataclasses import dataclass
from datetime import datetime
from coded_flows.types import (
    Union,
    Dict,
    List,
    Tuple,
    NDArray,
    DataFrame,
    DataSeries,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Cla. XGBoost", level=logging.INFO)
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Polars hashing using native Rust execution.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed. Falling back."
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """
        Optimized NumPy hashing using Buffer Protocol (Zero-Copy).
        """
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """
        Optimized sparse hashing. Hashes underlying data arrays directly.
        """
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema: Dict[str, Any]):
        """Compact schema hashing."""
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        """Calculate indices for sampling without generating full range lists."""
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


class _AutoBalanceXGBClassifier(XGBClassifier):

    def fit(self, X, y, **kwargs):
        """
        Overridden fit method that detects imbalance and applies:
        - scale_pos_weight (Binary)
        - sample_weight (Multiclass)
        """
        decision = self._xgb_class_weight_decision(y)
        self.imbalance_meta_ = decision
        if decision["strategy"] == "scale_pos_weight":
            self.scale_pos_weight = decision["value"]
            if "sample_weight" in kwargs:
                pass
        elif decision["strategy"] == "sample_weight":
            self.scale_pos_weight = 1
            weights = compute_sample_weight(class_weight="balanced", y=y)
            kwargs["sample_weight"] = weights
        else:
            self.scale_pos_weight = 1
        return super().fit(X, y, **kwargs)

    @staticmethod
    def _xgb_class_weight_decision(y):
        """
        Determine the class-imbalance handling strategy for the target labels.
        """
        y = y.values if isinstance(y, pd.Series) else np.asarray(y)
        classes = np.unique(y)
        n_classes = len(classes)
        counts = Counter(y)
        if n_classes < 2:
            return {
                "problem": "mono",
                "imbalance": 0,
                "strategy": "none",
                "value": None,
            }
        max_c = max(counts.values())
        min_c = min(counts.values())
        if min_c == 0:
            imbalance = float("inf")
        else:
            imbalance = round(max_c / min_c, 3)
        if imbalance < 1.5:
            return {
                "problem": "binary" if n_classes == 2 else "multiclass",
                "imbalance": imbalance,
                "strategy": "none",
                "value": None,
            }
        if n_classes == 2:
            ratio = max_c / min_c
            return {
                "problem": "binary",
                "imbalance": imbalance,
                "strategy": "scale_pos_weight",
                "value": round(ratio, 3),
            }
        return {
            "problem": "multiclass",
            "imbalance": imbalance,
            "strategy": "sample_weight",
            "value": None,
        }


def _normalize_hpo_df(df):
    df = df.copy()
    param_cols = [c for c in df.columns if c.startswith("params_")]
    df[param_cols] = df[param_cols].astype("string[pyarrow]")
    return df


def _get_shape_and_sparsity(X: Any) -> Tuple[int, int, float, bool]:
    """
    Efficiently extracts shape and estimates sparsity without converting
    the entire dataset to numpy.
    """
    (n_samples, n_features) = (0, 0)
    is_sparse = False
    sparsity = 0.0
    if hasattr(X, "nnz") and hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        is_sparse = True
        sparsity = 1.0 - X.nnz / (n_samples * n_features)
        return (n_samples, n_features, sparsity, is_sparse)
    if hasattr(X, "height") and hasattr(X, "width"):
        (n_samples, n_features) = (X.height, X.width)
        return (n_samples, n_features, 0.0, False)
    if hasattr(X, "shape") and hasattr(X, "iloc"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    if isinstance(X, list):
        X = np.array(X)
    if hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    raise ValueError("Unsupported data type")


def _smart_split(
    n_samples,
    X,
    y,
    *,
    random_state=42,
    shuffle=True,
    stratify=None,
    fixed_test_split=None,
    verbose=True,
):
    """
    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset (len(X) or len(y))
    X : array-like
        Features
    y : array-like
        Target
    random_state : int
    shuffle : bool
    stratify : array-like or None
        For stratified splitting (recommended for classification)

    Returns
    -------
    X_train, X_test, y_train, y_test, val_size_in_train
        The train/test splits plus the validation fraction to carve out of the
        training set later (e.g. for early stopping).
    """
    if fixed_test_split:
        test_ratio = fixed_test_split
        val_ratio = fixed_test_split
    elif n_samples <= 1000:
        test_ratio = 0.2
        val_ratio = 0.1
    elif n_samples < 10000:
        test_ratio = 0.15
        val_ratio = 0.15
    elif n_samples < 100000:
        test_ratio = 0.1
        val_ratio = 0.1
    elif n_samples < 1000000:
        test_ratio = 0.05
        val_ratio = 0.05
    else:
        test_ratio = 0.01
        val_ratio = 0.01
    (X_train, X_test, y_train, y_test) = train_test_split(
        X,
        y,
        test_size=test_ratio,
        random_state=random_state,
        shuffle=shuffle,
        stratify=stratify,
    )
    val_size_in_train = val_ratio / (1 - test_ratio)
    verbose and logger.info(
        f"Split → Train: {1 - test_ratio:.2%} | Test: {test_ratio:.2%} (no validation set)"
    )
    return (X_train, X_test, y_train, y_test, val_size_in_train)


def _get_best_metric_average_strategy(y_true, balance_threshold: float = 0.5) -> str:
    """
    Analyzes y_true to determine the best averaging strategy.

    Args:
        y_true: Input array (Numpy array, Pandas Series, or Polars Series).
        balance_threshold: Float (0 to 1). If min_class_count / max_class_count
                           is below this, the data is considered imbalanced.

    Returns:
        str: 'binary', 'weighted', or 'macro'
    """
    counts = None
    if hasattr(y_true, "value_counts") and hasattr(y_true, "values"):
        counts = y_true.value_counts().values
    elif hasattr(y_true, "value_counts") and hasattr(y_true, "to_numpy"):
        vc = y_true.value_counts()
        if "count" in vc.columns:
            counts = vc["count"].to_numpy()
        else:
            counts = vc[:, 1].to_numpy()
    elif isinstance(y_true, np.ndarray):
        (_, counts) = np.unique(y_true, return_counts=True)
    else:
        (_, counts) = np.unique(np.array(y_true), return_counts=True)
    if counts is None or len(counts) == 0:
        raise ValueError("Input y_true appears to be empty.")
    n_classes = len(counts)
    if n_classes <= 2:
        return "binary"
    min_c = np.min(counts)
    max_c = np.max(counts)
    ratio = min_c / max_c
    if ratio < balance_threshold:
        return "weighted"
    else:
        return "macro"


def _ensure_feature_names(X, feature_names=None):
    if isinstance(X, pd.DataFrame):
        return list(X.columns)
    if isinstance(X, np.ndarray):
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        return feature_names
    raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def _perform_cross_validation(
    model,
    X,
    y,
    cv_folds,
    average_strategy,
    shuffle,
    random_state,
    n_jobs,
    verbose,
    pos_label=None,
) -> pd.DataFrame:
    """Perform cross-validation on the model."""
    verbose and logger.info(f"Performing {cv_folds}-fold cross-validation...")
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=shuffle, random_state=random_state)
    if average_strategy == "binary":
        scoring = {
            "accuracy": "accuracy",
            "precision": make_scorer(
                precision_score, average="binary", pos_label=pos_label
            ),
            "recall": make_scorer(recall_score, average="binary", pos_label=pos_label),
            "f1": make_scorer(f1_score, average="binary", pos_label=pos_label),
            "roc_auc": "roc_auc",
        }
    else:
        average_strategy_suffix = f"_{average_strategy}"
        roc_average_strategy_suffix = (
            f"_{average_strategy}" if average_strategy == "weighted" else ""
        )
        roc_auc_ovr_suffix = "_ovr"
        scoring = (
            f"f1{average_strategy_suffix}",
            "accuracy",
            f"precision{average_strategy_suffix}",
            f"recall{average_strategy_suffix}",
            f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}",
        )
    cv_results = cross_validate(
        model, X, y, cv=cv, scoring=scoring, return_train_score=True, n_jobs=n_jobs
    )

    def get_score_mean_std(metric_key):
        if metric_key in cv_results:
            return (cv_results[metric_key].mean(), cv_results[metric_key].std())
        return (0.0, 0.0)

    if average_strategy == "binary":
        (accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
        (precision_mean, precision_std) = get_score_mean_std("test_precision")
        (recall_mean, recall_std) = get_score_mean_std("test_recall")
        (f1_mean, f1_std) = get_score_mean_std("test_f1")
        (roc_auc_mean, roc_auc_std) = get_score_mean_std("test_roc_auc")
    else:
        (accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
        (precision_mean, precision_std) = get_score_mean_std(
            f"test_precision{average_strategy_suffix}"
        )
        (recall_mean, recall_std) = get_score_mean_std(
            f"test_recall{average_strategy_suffix}"
        )
        (f1_mean, f1_std) = get_score_mean_std(f"test_f1{average_strategy_suffix}")
        roc_key = f"test_roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"
        (roc_auc_mean, roc_auc_std) = get_score_mean_std(roc_key)
    verbose and logger.info(
        f"CV Accuracy  : {accuracy_mean:.4f} (+/- {accuracy_std:.4f})"
    )
    verbose and logger.info(
        f"CV Precision : {precision_mean:.4f} (+/- {precision_std:.4f})"
    )
    verbose and logger.info(f"CV Recall    : {recall_mean:.4f} (+/- {recall_std:.4f})")
    verbose and logger.info(f"CV F1 Score  : {f1_mean:.4f} (+/- {f1_std:.4f})")
    verbose and logger.info(
        f"CV ROC-AUC   : {roc_auc_mean:.4f} (+/- {roc_auc_std:.4f})"
    )
    CV_metrics = pd.DataFrame(
        {
            "Metric": ["Accuracy", "Precision", "Recall", "F1-Score", "ROC AUC"],
            "Mean": [accuracy_mean, precision_mean, recall_mean, f1_mean, roc_auc_mean],
            "Std": [accuracy_std, precision_std, recall_std, f1_std, roc_auc_std],
        }
    )
    return CV_metrics


def _compute_score(model, X, y, metric, average_strategy, pos_label=None):
    score_params = {"average": average_strategy, "zero_division": 0}
    y_pred = model.predict(X)
    if average_strategy != "binary":
        y_score = model.predict_proba(X)
    else:
        score_params["pos_label"] = pos_label
        if pos_label is not None:
            classes = list(model.classes_)
            try:
                pos_idx = classes.index(pos_label)
            except ValueError:
                pos_idx = 1 if len(classes) > 1 else 0
            y_score = model.predict_proba(X)[:, pos_idx]
        else:
            y_score = model.predict_proba(X)[:, 1]
    if metric == "Accuracy":
        score = accuracy_score(y, y_pred)
    elif metric == "Precision":
        score = precision_score(y, y_pred, **score_params)
    elif metric == "Recall":
        score = recall_score(y, y_pred, **score_params)
    elif metric == "F1 Score":
        score = f1_score(y, y_pred, **score_params)
    elif metric == "ROC-AUC":
        if average_strategy != "binary":
            score = roc_auc_score(
                y, y_score, multi_class="ovr", average=average_strategy
            )
        else:
            score = roc_auc_score(y, y_score)
    return score


def _get_cv_scoring_object(metric, average_strategy, pos_label=None):
    """
    Returns a scoring object (string or callable) suitable for cross_validate.
    Used during HPO.
    """
    if average_strategy == "binary":
        if metric == "F1 Score":
            return make_scorer(f1_score, average="binary", pos_label=pos_label)
        elif metric == "Accuracy":
            return "accuracy"
        elif metric == "Precision":
            return make_scorer(precision_score, average="binary", pos_label=pos_label)
        elif metric == "Recall":
            return make_scorer(recall_score, average="binary", pos_label=pos_label)
        elif metric == "ROC-AUC":
            return "roc_auc"
    else:
        average_strategy_suffix = f"_{average_strategy}"
        roc_auc_ovr_suffix = "_ovr"
        roc_average_strategy_suffix = (
            f"_{average_strategy}" if average_strategy == "weighted" else ""
        )
        if metric == "F1 Score":
            return f"f1{average_strategy_suffix}"
        elif metric == "Accuracy":
            return "accuracy"
        elif metric == "Precision":
            return f"precision{average_strategy_suffix}"
        elif metric == "Recall":
            return f"recall{average_strategy_suffix}"
        elif metric == "ROC-AUC":
            return f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"


def _hyperparameters_optimization(
    X,
    y,
    constant_hyperparameters,
    optimization_metric,
    metric_average_strategy,
    val_ratio,
    shuffle_split,
    stratify_split,
    use_cross_val,
    cv_folds,
    n_trials=50,
    strategy="maximize",
    sampler="Tree-structured Parzen",
    seed=None,
    n_jobs=-1,
    verbose=False,
    pos_label=None,
):
    direction = "maximize" if strategy.lower() == "maximize" else "minimize"
    sampler_map = {
        "Tree-structured Parzen": TPESampler(seed=seed),
        "Gaussian Process": GPSampler(seed=seed),
        "CMA-ES": CmaEsSampler(seed=seed),
        "Random Search": RandomSampler(seed=seed),
        "Random Sobol Search": QMCSampler(seed=seed),
    }
    if sampler in sampler_map:
        chosen_sampler = sampler_map[sampler]
    else:
        logger.warning(f"Sampler '{sampler}' not recognized → falling back to TPE")
        chosen_sampler = TPESampler(seed=seed)
    chosen_pruner = HyperbandPruner()
    if use_cross_val:
        cv = StratifiedKFold(
            n_splits=cv_folds, shuffle=shuffle_split, random_state=seed
        )
        cv_score_obj = _get_cv_scoring_object(
            optimization_metric, metric_average_strategy, pos_label
        )
    else:
        (X_train, X_val, y_train, y_val) = train_test_split(
            X,
            y,
            test_size=val_ratio,
            random_state=seed,
            stratify=y if stratify_split else None,
            shuffle=shuffle_split,
        )

    def logging_callback(study: Study, trial: FrozenTrial):
        """Callback function to log trial progress"""
        verbose and logger.info(
            f"Trial {trial.number} finished with value: {trial.value} and parameters: {trial.params}"
        )
        try:
            verbose and logger.info(f"Best value so far: {study.best_value}")
            verbose and logger.info(f"Best parameters so far: {study.best_params}")
        except ValueError:
            verbose and logger.info(f"No successful trials completed yet")
        verbose and logger.info(f"" + "-" * 50)

    def objective(trial):
        try:
            booster = constant_hyperparameters.get("booster")
            feature_types = constant_hyperparameters.get("feature_types")
            params = {}
            params["n_estimators"] = trial.suggest_int("n_estimators", 50, 1000)
            params["max_depth"] = trial.suggest_int("max_depth", 1, 15)
            params["learning_rate"] = trial.suggest_float(
                "learning_rate", 0.0001, 1.0, log=True
            )
            params["reg_lambda"] = trial.suggest_float(
                "reg_lambda", 1e-08, 100.0, log=True
            )
            params["reg_alpha"] = trial.suggest_float(
                "reg_alpha", 1e-08, 100.0, log=True
            )
            params["gamma"] = trial.suggest_float("gamma", 0.0, 5.0)
            params["min_child_weight"] = trial.suggest_float(
                "min_child_weight", 0.0, 10.0
            )
            params["subsample"] = trial.suggest_float("subsample", 0.1, 1.0)
            params["colsample_bytree"] = trial.suggest_float(
                "colsample_bytree", 0.1, 1.0
            )
            model = _AutoBalanceXGBClassifier(
                **params,
                feature_types=feature_types,
                tree_method="hist",
                enable_categorical=True,
                booster=booster,
                random_state=seed,
                n_jobs=n_jobs,
            )
            if use_cross_val:
                scores = cross_validate(
                    model, X, y, cv=cv, n_jobs=n_jobs, scoring=cv_score_obj
                )
                return scores["test_score"].mean()
            else:
                model.fit(X_train, y_train)
                score = _compute_score(
                    model,
                    X_val,
                    y_val,
                    optimization_metric,
                    metric_average_strategy,
                    pos_label,
                )
                return score
        except Exception as e:
            verbose and logger.error(
                f"Trial {trial.number} failed with error: {str(e)}"
            )
            raise

    study = create_study(
        direction=direction, sampler=chosen_sampler, pruner=chosen_pruner
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        catch=(Exception,),
        n_jobs=n_jobs,
        callbacks=[logging_callback],
    )
    verbose and logger.info(f"Optimization completed!")
    verbose and logger.info(
        f"   Best Number of Trees       : {study.best_params['n_estimators']}"
    )
    verbose and logger.info(
        f"   Best Max Depth             : {study.best_params['max_depth']}"
    )
    verbose and logger.info(
        f"   Best Learning Rate         : {study.best_params['learning_rate']}"
    )
    verbose and logger.info(
        f"   Best L1 Regularization     : {study.best_params['reg_alpha']}"
    )
    verbose and logger.info(
        f"   Best L2 Regularization     : {study.best_params['reg_lambda']}"
    )
    verbose and logger.info(
        f"   Best Gamma                 : {study.best_params['gamma']}"
    )
    verbose and logger.info(
        f"   Best Min Child Weight      : {study.best_params['min_child_weight']}"
    )
    verbose and logger.info(
        f"   Best Subsample Ratio       : {study.best_params['subsample']}"
    )
    verbose and logger.info(
        f"   Best Colsample by Tree     : {study.best_params['colsample_bytree']}"
    )
    verbose and logger.info(
        f"   Best {optimization_metric:<22}: {study.best_value:.4f}"
    )
    verbose and logger.info(f"   Sampler used               : {sampler}")
    verbose and logger.info(f"   Direction                  : {direction}")
    if use_cross_val:
        verbose and logger.info(f"   Cross-validation           : {cv_folds}-fold")
    else:
        verbose and logger.info(
            f"   Validation                 : single train/val split"
        )
    trials = study.trials_dataframe()
    trials["best_value"] = trials["value"].cummax()
    cols = list(trials.columns)
    value_idx = cols.index("value")
    cols = [c for c in cols if c != "best_value"]
    new_order = cols[: value_idx + 1] + ["best_value"] + cols[value_idx + 1 :]
    trials = trials[new_order]
    return (study.best_params, trials)


def _combine_test_data(
    X_test, y_true, y_pred, y_proba, class_names, features_names=None
):
    """
    Combine X_test, y_true, y_pred, and y_proba into a single DataFrame.

    Parameters:
    -----------
    X_test : pandas/polars DataFrame, numpy array, or scipy sparse matrix
        Test features
    y_true : pandas/polars Series, numpy array, or list
        True labels
    y_pred : pandas/polars Series, numpy array, or list
        Predicted labels
    y_proba : pandas/polars Series/DataFrame, numpy array (1D or 2D), or list
        Prediction probabilities - can be:
        - 1D array for binary classification (probability of positive class)
        - 2D array for multiclass (probabilities for each class)
    class_names : list or array-like
        Names of the classes in order.
        For binary classification with 1D y_proba, only the positive class name is needed.

    Returns:
    --------
    pandas.DataFrame
        Combined DataFrame with features, probabilities, y_true, and y_pred
    """
    if sparse.issparse(X_test):
        X_df = pd.DataFrame(X_test.toarray())
    elif isinstance(X_test, np.ndarray):
        X_df = pd.DataFrame(X_test)
    elif hasattr(X_test, "to_pandas"):
        X_df = X_test.to_pandas()
    elif isinstance(X_test, pd.DataFrame):
        X_df = X_test.copy()
    else:
        raise TypeError(f"Unsupported type for X_test: {type(X_test)}")
    if X_df.columns.tolist() == list(range(len(X_df.columns))):
        X_df.columns = (
            [f"feature_{i}" for i in range(len(X_df.columns))]
            if features_names is None
            else features_names
        )
    if isinstance(y_true, list):
        y_true_series = pd.Series(y_true, name="y_true")
    elif isinstance(y_true, np.ndarray):
        y_true_series = pd.Series(y_true, name="y_true")
    elif hasattr(y_true, "to_pandas"):
        y_true_series = y_true.to_pandas()
        y_true_series.name = "y_true"
    elif isinstance(y_true, pd.Series):
        y_true_series = y_true.copy()
        y_true_series.name = "y_true"
    else:
        raise TypeError(f"Unsupported type for y_true: {type(y_true)}")
    if isinstance(y_pred, list):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif isinstance(y_pred, np.ndarray):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif hasattr(y_pred, "to_pandas"):
        y_pred_series = y_pred.to_pandas()
        y_pred_series.name = "y_pred"
    elif isinstance(y_pred, pd.Series):
        y_pred_series = y_pred.copy()
        y_pred_series.name = "y_pred"
    else:
        raise TypeError(f"Unsupported type for y_pred: {type(y_pred)}")
    if isinstance(y_proba, list):
        y_proba_array = np.array(y_proba)
    elif isinstance(y_proba, np.ndarray):
        y_proba_array = y_proba
    elif hasattr(y_proba, "to_pandas"):
        y_proba_array = y_proba.to_pandas().values
    elif isinstance(y_proba, pd.Series):
        y_proba_array = y_proba.values
    elif isinstance(y_proba, pd.DataFrame):
        y_proba_array = y_proba.values
    else:
        raise TypeError(f"Unsupported type for y_proba: {type(y_proba)}")

    def sanitize_class_name(class_name):
        """Convert class name to valid column name by replacing spaces and special chars"""
        return str(class_name).replace(" ", "_").replace("-", "_")

    if y_proba_array.ndim == 1:
        y_proba_df = pd.DataFrame({"proba": y_proba_array})
    else:
        n_classes = y_proba_array.shape[1]
        if len(class_names) == n_classes:
            proba_columns = [f"proba_{sanitize_class_name(cls)}" for cls in class_names]
        else:
            proba_columns = [f"proba_{i}" for i in range(n_classes)]
        y_proba_df = pd.DataFrame(y_proba_array, columns=proba_columns)
    y_proba_df = y_proba_df.reset_index(drop=True)
    X_df = X_df.reset_index(drop=True)
    y_true_series = y_true_series.reset_index(drop=True)
    y_pred_series = y_pred_series.reset_index(drop=True)
    is_false_prediction = pd.Series(
        y_true_series != y_pred_series, name="is_false_prediction"
    ).reset_index(drop=True)
    result_df = pd.concat(
        [X_df, y_proba_df, y_true_series, y_pred_series, is_false_prediction], axis=1
    )
    return result_df


def _get_feature_importance(model, feature_names=None, sort=True, top_n=None):
    """
    Extract feature importance from a fitted XGBoost model.

    Parameters:
    -----------
    model : Fitted model
    feature_names : list or array-like, optional
        Names of features. If None, uses generic names like 'feature_0', 'feature_1', etc.
    sort : bool, default=True
        Whether to sort features by importance (descending)
    top_n : int, optional
        If specified, returns only the top N most important features

    Returns:
    --------
    pd.DataFrame
        DataFrame with columns: 'feature', 'importance'
        Importance values come from the model's feature_importances_ attribute (gain-based by default for XGBoost)
    """
    importances = model.feature_importances_
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(importances))]
    importance_df = pd.DataFrame({"feature": feature_names, "importance": importances})
    if sort:
        importance_df = importance_df.sort_values("importance", ascending=False)
    importance_df = importance_df.reset_index(drop=True)
    if top_n is not None:
        importance_df = importance_df.head(top_n)
    return importance_df


def _smart_shap_background(
    X: Union[np.ndarray, pd.DataFrame],
    model_type: str = "tree",
    seed: int = 42,
    verbose: bool = True,
) -> Union[np.ndarray, pd.DataFrame, object]:
    """
    Intelligently prepares a background dataset for SHAP based on model type.

    Strategies:
    - Tree: Higher sample cap (1000), uses Random Sampling (preserves data structure).
    - Other: Lower sample cap (100), uses K-Means (maximizes info density).
    """
    (n_rows, n_features) = X.shape
    if model_type == "tree":
        max_samples = 1000
        use_kmeans = False
    else:
        max_samples = 100
        use_kmeans = True
    if n_rows <= max_samples:
        verbose and logger.info(
            f"✓ Dataset small ({n_rows} <= {max_samples}). Using full data."
        )
        return X
    verbose and logger.info(
        f"⚡ Large dataset detected ({n_rows} rows). Optimization Strategy: {('K-Means' if use_kmeans else 'Random Sampling')}"
    )
    if use_kmeans:
        try:
            verbose and logger.info(
                f"   Summarizing to {max_samples} weighted centroids..."
            )
            return shap.kmeans(X, max_samples)
        except Exception as e:
            logger.warning(
                f"   K-Means failed ({str(e)}). Falling back to random sampling."
            )
            return shap.sample(X, max_samples, random_state=seed)
    else:
        verbose and logger.info(f"   Sampling {max_samples} random rows...")
        return shap.sample(X, max_samples, random_state=seed)


def _get_class_mapping_from_encoder(encoder):
    """
    Returns a DataFrame showing the encoding:
      index → original class label

    Parameters
    ----------
    encoder : LabelEncoder
        The fitted LabelEncoder object (must have .classes_)
    """
    if encoder is None:
        return pd.DataFrame(columns=["index", "class"])
    if not hasattr(encoder, "classes_"):
        raise ValueError(
            "The provided encoder does not have .classes_ attribute. Was it fitted with LabelEncoder.fit() or fit_transform()?"
        )
    classes = encoder.classes_
    return pd.DataFrame({"index": range(len(classes)), "class": classes})


def _get_xgb_feature_types(X):
    """
    Generate feature_types array for XGBClassifier based on data types in X.

    Parameters:
    -----------
    X : pandas.DataFrame or numpy.ndarray
        Feature matrix for XGBClassifier

    Returns:
    --------
    list : List of feature types ('q' for quantitative, 'c' for categorical)

    Notes:
    ------
    - Numeric types (int, float) and boolean are treated as quantitative ('q')
    - Object and category types are treated as categorical ('c')
    - If X is a numpy array, features are assumed quantitative, unless the array dtype is boolean, in which case all features are marked categorical ('c')
    """
    if isinstance(X, pd.DataFrame):
        feature_types = []
        for col in X.columns:
            dtype = X[col].dtype
            if dtype == "bool":
                feature_types.append("q")
            elif dtype == "object" or isinstance(dtype, pd.CategoricalDtype):
                feature_types.append("c")
            elif pd.api.types.is_numeric_dtype(dtype):
                feature_types.append("q")
            else:
                feature_types.append("q")
        return feature_types
    elif isinstance(X, np.ndarray):
        if X.dtype == bool:
            return ["c"] * X.shape[1]
        else:
            return ["q"] * X.shape[1]
    else:
        raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def train_cla_xgboost(
    X: DataFrame, y: Union[DataSeries, NDArray, List], options=None
) -> Tuple[
    Any,
    DataFrame,
    Any,
    Any,
    Union[DataFrame, Dict],
    DataFrame,
    DataFrame,
    DataFrame,
    DataFrame,
    Dict,
]:
    options = options or {}
    n_estimators = options.get("n_estimators", 100)
    early_stopping = options.get("early_stopping", True)
    early_stopping_rounds = options.get("early_stopping_rounds", 10)
    use_dart = options.get("use_dart", False)
    max_depth = options.get("max_depth", 6)
    learning_rate = options.get("learning_rate", 0.3)
    reg_lambda = options.get("reg_lambda", 1.0)
    reg_alpha = options.get("reg_alpha", 0.0)
    gamma = options.get("gamma", 0.0)
    min_child_weight = options.get("min_child_weight", 1.0)
    subsample = options.get("subsample", 1.0)
    colsample_bytree = options.get("colsample_bytree", 1.0)
    auto_split = options.get("auto_split", True)
    test_val_size = options.get("test_val_size", 15) / 100
    shuffle_split = options.get("shuffle_split", True)
    stratify_split = options.get("stratify_split", True)
    retrain_on_full = options.get("retrain_on_full", False)
    custom_average_strategy = options.get("custom_average_strategy", "auto")
    use_cross_validation = options.get("use_cross_validation", False)
    cv_folds = options.get("cv_folds", 5)
    use_hpo = options.get("use_hyperparameter_optimization", False)
    optimization_metric = options.get("optimization_metric", "F1 Score")
    optimization_method = options.get("optimization_method", "Tree-structured Parzen")
    optimization_iterations = options.get("optimization_iterations", 50)
    pos_label_option = options.get("pos_label", "").strip()
    if pos_label_option == "":
        pos_label_option = None
    return_shap_explainer = options.get("return_shap_explainer", False)
    use_shap_sampler = options.get("use_shap_sampler", False)
    shap_feature_perturbation = options.get(
        "shap_feature_perturbation", "Interventional"
    )
    metrics_as = options.get("metrics_as", "Dataframe")
    n_jobs_str = options.get("n_jobs", "1")
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    verbose = options.get("verbose", True)
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    skip_computation = False
    Model = None
    Metrics = pd.DataFrame()
    CV_Metrics = pd.DataFrame()
    Features_Importance = pd.DataFrame()
    Label_Encoder = None
    SHAP = None
    HPO_Trials = pd.DataFrame()
    HPO_Best = None
    accuracy = None
    precision = None
    recall = None
    f1 = None
    roc_auc = None
    (n_samples, n_features, sparsity, is_sparse) = _get_shape_and_sparsity(X)
    if activate_caching:
        verbose and logger.info(f"Caching is activate")
        data_hasher = _UniversalDatasetHasher(n_samples, verbose=verbose)
        X_hash = data_hasher.hash_data(X).hash
        y_hash = data_hasher.hash_data(y).hash
        all_hash_base_text = f"HASH BASE TEXTPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{('SHAP Version ' + shap.__version__ if return_shap_explainer else 'NO SHAP Version')}{X_hash}{y_hash}{n_estimators}{early_stopping}{early_stopping_rounds}{max_depth}{learning_rate}{reg_lambda}{reg_alpha}{gamma}{min_child_weight}{subsample}{colsample_bytree}{use_dart}{('Use HPO' if use_hpo else 'No HPO')}{(optimization_metric if use_hpo else 'No HPO Metric')}{(optimization_method if use_hpo else 'No HPO Method')}{(optimization_iterations if use_hpo else 'No HPO Iter')}{(cv_folds if use_cross_validation else 'No CV')}{('Auto Split' if auto_split else test_val_size)}{shuffle_split}{stratify_split}{return_shap_explainer}{shap_feature_perturbation}{use_shap_sampler}{random_state}{(pos_label_option if pos_label_option else 'default_pos')}{custom_average_strategy}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        model_path = cache_folder / f"{all_hash}.json"
        metrics_dict_path = cache_folder / f"metrics_{all_hash}.json"
        metrics_df_path = cache_folder / f"metrics_{all_hash}.parquet"
        cv_metrics_path = cache_folder / f"cv_metrics_{all_hash}.parquet"
        hpo_trials_path = cache_folder / f"hpo_trials_{all_hash}.parquet"
        hpo_best_params_path = cache_folder / f"hpo_best_params_{all_hash}.json"
        features_importance_path = (
            cache_folder / f"features_importance_{all_hash}.parquet"
        )
        prediction_set_path = cache_folder / f"prediction_set_{all_hash}.parquet"
        shap_path = cache_folder / f"{all_hash}.shap"
        label_encoder_path = cache_folder / f"{all_hash}.encoder"
        skip_computation = model_path.is_file()
    if not skip_computation:
        if hasattr(y, "nunique"):
            n_classes = y.nunique()
        elif hasattr(y, "n_unique"):
            n_classes = y.n_unique()
        else:
            n_classes = len(np.unique(y))
        is_multiclass = n_classes > 2
        features_names = X.columns if hasattr(X, "columns") else None
        shap_feature_names = _ensure_feature_names(X)
        Label_Encoder = LabelEncoder()
        y = Label_Encoder.fit_transform(y)
        booster = "dart" if use_dart else "gbtree"
        eval_metric = "mlogloss" if is_multiclass else "logloss"
        es_objective = "multi:softprob" if is_multiclass else "binary:logistic"
        feature_types = _get_xgb_feature_types(X)
        fixed_test_split = None if auto_split else test_val_size
        (X_train, X_test, y_train, y_test, val_ratio) = _smart_split(
            n_samples,
            X,
            y,
            random_state=random_state,
            shuffle=shuffle_split,
            stratify=y if stratify_split else None,
            fixed_test_split=fixed_test_split,
            verbose=verbose,
        )
        if custom_average_strategy == "auto":
            metric_average_strategy = _get_best_metric_average_strategy(y_test)
        else:
            metric_average_strategy = custom_average_strategy
        effective_pos_label = None
        if metric_average_strategy == "binary":
            unique_classes = np.unique(y_train)
            if pos_label_option is not None:
                if pos_label_option in unique_classes:
                    effective_pos_label = pos_label_option
                else:
                    try:
                        as_int = int(pos_label_option)
                        if as_int in unique_classes:
                            effective_pos_label = as_int
                    except ValueError:
                        pass
            elif 1 in unique_classes:
                effective_pos_label = 1
            elif "1" in unique_classes:
                effective_pos_label = "1"
            if effective_pos_label is not None:
                verbose and logger.info(f"Using positive label: {effective_pos_label}")
            else:
                error_message = 'The target appears to be binary, but no positive label was provided and no "1" class exists in the label set.'
                verbose and logger.error(error_message)
                raise ValueError(error_message)
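        # Optional hyperparameter search; the best trial overrides the manually supplied settings below.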
        if use_hpo:
            verbose and logger.info(f"Performing Hyperparameters Optimization")
            constant_hyperparameters = {
                "booster": booster,
                "feature_types": feature_types,
            }
            (HPO_Best, HPO_Trials) = _hyperparameters_optimization(
                X_train,
                y_train,
                constant_hyperparameters,
                optimization_metric,
                metric_average_strategy,
                val_ratio,
                shuffle_split,
                stratify_split,
                use_cross_validation,
                cv_folds,
                optimization_iterations,
                "maximize",
                optimization_method,
                random_state,
                n_jobs_int,
                verbose=verbose,
                pos_label=effective_pos_label,
            )
            HPO_Trials = _normalize_hpo_df(HPO_Trials)
            n_estimators = HPO_Best["n_estimators"]
            max_depth = HPO_Best["max_depth"]
            learning_rate = HPO_Best["learning_rate"]
            reg_lambda = HPO_Best["reg_lambda"]
            reg_alpha = HPO_Best["reg_alpha"]
            gamma = HPO_Best["gamma"]
            min_child_weight = HPO_Best["min_child_weight"]
            subsample = HPO_Best["subsample"]
            colsample_bytree = HPO_Best["colsample_bytree"]
        model_params = {}
        model_params["n_estimators"] = n_estimators
        model_params["max_depth"] = max_depth
        model_params["learning_rate"] = learning_rate
        model_params["reg_lambda"] = reg_lambda
        model_params["reg_alpha"] = reg_alpha
        model_params["gamma"] = gamma
        model_params["min_child_weight"] = min_child_weight
        model_params["subsample"] = subsample
        model_params["colsample_bytree"] = colsample_bytree
        if early_stopping and (not use_hpo):
            model_params["early_stopping_rounds"] = early_stopping_rounds
            model_params["objective"] = es_objective
            model_params["eval_metric"] = eval_metric
        Model = _AutoBalanceXGBClassifier(
            **model_params,
            feature_types=feature_types,
            tree_method="hist",
            enable_categorical=True,
            booster=booster,
            random_state=random_state,
            n_jobs=n_jobs_int,
        )
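        # With early stopping, hold out a validation set from the training data and record the best iteration count.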
        if early_stopping and (not use_hpo):
            (X_train, X_val, y_train, y_val) = train_test_split(
                X_train,
                y_train,
                test_size=val_ratio,
                random_state=random_state,
                stratify=y_train if stratify_split else None,
                shuffle=shuffle_split,
            )
            Model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
            model_params["n_estimators"] = Model.best_iteration + 1
        else:
            Model.fit(X_train, y_train, verbose=False)
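        # Cross-validation metrics use a fresh estimator stripped of the early-stopping arguments.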
        if use_cross_validation and (not use_hpo):
            verbose and logger.info(
                f"Using Cross-Validation to measure performance metrics"
            )
            cv_params = model_params.copy()
            cv_params.pop("early_stopping_rounds", None)
            cv_params.pop("objective", None)
            cv_params.pop("eval_metric", None)
            CV_Model = _AutoBalanceXGBClassifier(
                **cv_params,
                feature_types=feature_types,
                tree_method="hist",
                enable_categorical=True,
                booster=booster,
                random_state=random_state,
                n_jobs=n_jobs_int,
            )
            CV_Metrics = _perform_cross_validation(
                CV_Model,
                X_train,
                y_train,
                cv_folds,
                metric_average_strategy,
                shuffle_split,
                random_state,
                n_jobs_int,
                verbose,
                pos_label=effective_pos_label,
            )
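        # Evaluate on the held-out test set; for binary targets keep only the positive-class probability column.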
        y_pred = Model.predict(X_test)
        if is_multiclass:
            y_score = Model.predict_proba(X_test)
        elif effective_pos_label is not None:
            try:
                pos_idx = list(Model.classes_).index(effective_pos_label)
                y_score = Model.predict_proba(X_test)[:, pos_idx]
            except ValueError:
                y_score = Model.predict_proba(X_test)[:, 1]
        else:
            y_score = Model.predict_proba(X_test)[:, 1]
        score_params = {"average": metric_average_strategy, "zero_division": 0}
        if effective_pos_label is not None:
            score_params["pos_label"] = effective_pos_label
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, **score_params)
        recall = recall_score(y_test, y_pred, **score_params)
        f1 = f1_score(y_test, y_pred, **score_params)
        if is_multiclass:
            roc_auc = roc_auc_score(
                y_test, y_score, multi_class="ovr", average=metric_average_strategy
            )
        else:
            roc_auc = roc_auc_score(y_test, y_score)
        if metrics_as == "Dataframe":
            Metrics = pd.DataFrame(
                {
                    "Metric": [
                        "Accuracy",
                        "Precision",
                        "Recall",
                        "F1-Score",
                        "ROC AUC",
                    ],
                    "Value": [accuracy, precision, recall, f1, roc_auc],
                }
            )
        else:
            Metrics = {
                "Accuracy": accuracy,
                "Precision": precision,
                "Recall": recall,
                "F1-Score": f1,
                "ROC AUC": roc_auc,
            }
        verbose and logger.info(f"Accuracy  : {accuracy:.4f}")
        verbose and logger.info(f"Precision : {precision:.4f}")
        verbose and logger.info(f"Recall    : {recall:.4f}")
        verbose and logger.info(f"F1-Score  : {f1:.4f}")
        verbose and logger.info(f"ROC-AUC   : {roc_auc:.4f}")
        Prediction_Set = _combine_test_data(
            X_test, y_test, y_pred, y_score, Model.classes_, features_names
        )
        verbose and logger.info(f"Prediction Set created")
        if retrain_on_full:
            verbose and logger.info(
                "Retraining model on full dataset for production deployment"
            )
            if early_stopping:
                model_params.pop("early_stopping_rounds", None)
                model_params.pop("objective", None)
                model_params.pop("eval_metric", None)
                Model = _AutoBalanceXGBClassifier(
                    **model_params,
                    feature_types=feature_types,
                    tree_method="hist",
                    enable_categorical=True,
                    booster=booster,
                    random_state=random_state,
                    n_jobs=n_jobs_int,
                )
            Model.fit(X, y, verbose=False)
            verbose and logger.info(
                "Model successfully retrained on full dataset. Reported metrics remain from original held-out test set."
            )
        Features_Importance = _get_feature_importance(Model, features_names)
        verbose and logger.info(f"Features Importance computed")
        if return_shap_explainer:
            if shap_feature_perturbation == "Interventional":
                SHAP = shap.TreeExplainer(
                    Model,
                    (
                        _smart_shap_background(
                            X if retrain_on_full else X_train,
                            model_type="tree",
                            seed=random_state,
                            verbose=verbose,
                        )
                        if use_shap_sampler
                        else X if retrain_on_full else X_train
                    ),
                    feature_names=shap_feature_names,
                )
            else:
                SHAP = shap.TreeExplainer(
                    Model,
                    feature_names=shap_feature_names,
                    feature_perturbation="tree_path_dependent",
                )
            verbose and logger.info(f"SHAP explainer generated")
        if activate_caching:
            verbose and logger.info(f"Caching output elements")
            Model.save_model(model_path)
            if isinstance(Metrics, dict):
                with metrics_dict_path.open("w", encoding="utf-8") as f:
                    json.dump(Metrics, f, ensure_ascii=False, indent=4)
            else:
                Metrics.to_parquet(metrics_df_path)
            if use_cross_validation and (not use_hpo):
                CV_Metrics.to_parquet(cv_metrics_path)
            if use_hpo:
                HPO_Trials.to_parquet(hpo_trials_path)
                with hpo_best_params_path.open("w", encoding="utf-8") as f:
                    json.dump(HPO_Best, f, ensure_ascii=False, indent=4)
            Features_Importance.to_parquet(features_importance_path)
            Prediction_Set.to_parquet(prediction_set_path)
            if return_shap_explainer:
                with shap_path.open("wb") as f:
                    joblib.dump(SHAP, f)
            joblib.dump(Label_Encoder, label_encoder_path)
            verbose and logger.info(f"Caching done")
    else:
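        # Cache hit: reload all previously computed artifacts instead of retraining.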
        verbose and logger.info(f"Skipping computations and loading cached elements")
        Model = _AutoBalanceXGBClassifier()
        Model.load_model(model_path)
        verbose and logger.info(f"Model loaded")
        if metrics_dict_path.is_file():
            with metrics_dict_path.open("r", encoding="utf-8") as f:
                Metrics = json.load(f)
        else:
            Metrics = pd.read_parquet(metrics_df_path)
        verbose and logger.info(f"Metrics loaded")
        if use_cross_validation and (not use_hpo):
            CV_Metrics = pd.read_parquet(cv_metrics_path)
            verbose and logger.info(f"Cross Validation metrics loaded")
        if use_hpo:
            HPO_Trials = pd.read_parquet(hpo_trials_path)
            with hpo_best_params_path.open("r", encoding="utf-8") as f:
                HPO_Best = json.load(f)
            verbose and logger.info(
                "Hyperparameter Optimization trials and best params loaded"
            )
        Features_Importance = pd.read_parquet(features_importance_path)
        verbose and logger.info(f"Features Importance loaded")
        Prediction_Set = pd.read_parquet(prediction_set_path)
        verbose and logger.info(f"Prediction Set loaded")
        if return_shap_explainer:
            with shap_path.open("rb") as f:
                SHAP = joblib.load(f)
            verbose and logger.info(f"SHAP Explainer loaded")
        Label_Encoder = joblib.load(label_encoder_path)
        verbose and logger.info(f"Label Encoder loaded")
    Model_Classes = _get_class_mapping_from_encoder(Label_Encoder)
    return (
        Model,
        Model_Classes,
        SHAP,
        Label_Encoder,
        Metrics,
        CV_Metrics,
        Features_Importance,
        Prediction_Set,
        HPO_Trials,
        HPO_Best,
    )

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • numpy
  • torch
  • numba>=0.56.0
  • xgboost
  • cmaes
  • optuna
  • scipy
  • polars
  • xxhash