ML Datasets

Load popular ML datasets (Sklearn & SHAP). Automatically maps numeric targets to class names where possible.

Processing

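This brick loads standard datasets commonly used for machine learning practice, benchmarking, and testing. Instead of importing external files, you select one from a built-in list of well-known datasets provided by scikit-learn and SHAP (such as the Iris flower dataset, Adult Census income data, or California Housing Prices). Some loaders, for example scikit-learn's fetch_california_housing and fetch_covtype, download the data on first use and therefore need network access.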

Inputs

This brick acts as a data source. It does not require any input connections from previous steps.

Outputs

data
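A table containing the full dataset: all feature columns (e.g., age, height, pixel values) plus the column or columns to be predicted. The label column is usually named "target"; a few scikit-learn datasets keep their own target column name instead (for example "MedHouseVal" for California Housing Prices). For classification datasets, numeric class IDs in the "target" column are replaced with class names when the loader provides them.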
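The mapping of numeric targets to class names can be sketched on a toy table. This is a minimal illustration, not the brick itself: the column values and the `target_names` list below are made up for the example, and only the `map` plus `fillna` pattern is taken from the brick's source.

```python
import pandas as pd

# Toy stand-in for a loaded classification dataset (values are illustrative).
frame = pd.DataFrame(
    {
        "petal_length": [1.4, 4.7, 6.0, 1.3],
        "target": [0, 1, 2, 0],
    }
)
target_names = ["setosa", "versicolor", "virginica"]  # assumed class names

# Build an id -> name lookup from the position of each name in the list.
name_map = dict(enumerate(target_names))

# Replace ids with names; fillna mirrors the brick's fallback for ids
# that have no matching name (there are none in this toy table).
frame["target"] = frame["target"].map(name_map).fillna(frame["target"])

print(frame["target"].tolist())
```

After the mapping, the "target" column holds readable labels, which is convenient for plots and reports but means a model that expects numeric labels needs them encoded again.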

Output Types

data: DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The ML Datasets brick has the following configurable options:

Choose Dataset
Selects which dataset to load. Defaults to Iris.
  • Adult (Census): A classic dataset for Classification. Predicts whether a person makes over $50K a year based on census data.
  • Breast Cancer: Classification data used to predict whether a tumor is malignant or benign based on cell characteristics.
  • California Housing Prices: A standard Regression dataset. Predicts the median house value for California districts.
  • Communities & Crime: Regression data used to predict violent crime rates in US communities.
  • Correlated Groups 60: A synthetic dataset with 60 features arranged in correlated groups, primarily used for analyzing feature correlations.
  • Covertype (Forest Cover): Classification data to predict the forest cover type from cartographic variables.
  • Diabetes: Regression data. Measures disease progression one year after baseline.
  • Handwritten Digits: Image Classification data. Contains 8x8 pixel grids representing handwritten digits (0-9).
  • IMDB Reviews: Text/NLP dataset. Contains 25,000 movie reviews labeled by sentiment (positive/negative).
  • Iris: The most famous Classification dataset. Predicts the species of Iris flowers based on petal and sepal measurements.
  • Linnerud: Multi-output Regression. Contains physiological and exercise data.
  • Wine Recognition: Classification data. Distinguishes between three different cultivars of wine grown in Italy.
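Verbose Output
Controls whether the brick writes log messages. When enabled (the default), it logs the dataset being loaded, any target-name mapping that is applied, and the final row and column counts. When disabled, these messages and the error log entry are suppressed; loading errors are still raised.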
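How the two options are read can be sketched as follows. The option keys "dataset_name" and "verbose" and their defaults come from the brick's source; the helper name `resolve_options` and the abbreviated `DATASET_KEYS` table are hypothetical and exist only for this illustration.

```python
# Abbreviated lookup table; the brick itself supports twelve dataset names.
DATASET_KEYS = {
    "Iris": "sklearn_iris",
    "Diabetes": "sklearn_diabetes",
    "Adult (Census)": "shap_adult",
}


def resolve_options(options=None):
    """Hypothetical helper: return (dataset_key, verbose) using the brick's defaults."""
    options = options or {}
    verbose = options.get("verbose", True)            # logging is on by default
    human_name = options.get("dataset_name", "Iris")  # Iris is the default dataset
    dataset_key = DATASET_KEYS.get(human_name)
    if not dataset_key:
        # Same failure mode as the brick: unknown names are rejected early.
        raise ValueError(f"Unknown dataset selected: {human_name}")
    return dataset_key, verbose


print(resolve_options())
print(resolve_options({"dataset_name": "Adult (Census)", "verbose": False}))
```

The dataset name must match one of the listed labels exactly, including punctuation such as the parentheses in "Adult (Census)".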
import logging
import pandas as pd
import shap
from sklearn import datasets
from coded_flows.types import DataFrame, Dict
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="ML Datasets", level=logging.INFO)


def load_ml_datasets(options=None) -> DataFrame:
    """
    Loads selected dataset, normalizes it to a single DataFrame, and maps target IDs to names if available.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    human_name = options.get("dataset_name", "Iris")
    data = pd.DataFrame()
    human_to_key = {
        "Adult (Census)": "shap_adult",
        "Breast Cancer": "sklearn_breast_cancer",
        "California Housing Prices": "sklearn_california",
        "Communities & Crime": "shap_communities",
        "Correlated Groups 60": "shap_corrgroups",
        "Covertype (Forest Cover)": "sklearn_covtype",
        "Diabetes": "sklearn_diabetes",
        "Handwritten Digits": "sklearn_digits",
        "IMDB Reviews": "shap_imdb",
        "Iris": "sklearn_iris",
        "Linnerud": "sklearn_linnerud",
        "Wine Recognition": "sklearn_wine",
    }
    dataset_key = human_to_key.get(human_name)
    try:
        verbose and logger.info(f"Preparing to load dataset: '{human_name}'")
        if not dataset_key:
            raise ValueError(f"Unknown dataset selected: {human_name}")
        if dataset_key.startswith("sklearn_"):
            sklearn_registry = {
                "sklearn_iris": datasets.load_iris,
                "sklearn_digits": datasets.load_digits,
                "sklearn_wine": datasets.load_wine,
                "sklearn_breast_cancer": datasets.load_breast_cancer,
                "sklearn_diabetes": datasets.load_diabetes,
                "sklearn_linnerud": datasets.load_linnerud,
                "sklearn_california": datasets.fetch_california_housing,
                "sklearn_covtype": datasets.fetch_covtype,
            }
            loader = sklearn_registry[dataset_key]
            bunch = loader(as_frame=True)
            data = bunch.frame.copy()
            # Replace numeric class IDs with class names when the loader provides them.
            # This only applies when the frame has a numeric "target" column.
            target_names = getattr(bunch, "target_names", None)
            if (
                target_names is not None
                and "target" in data.columns
                and pd.api.types.is_numeric_dtype(data["target"])
            ):
                verbose and logger.info(
                    f"Mapping numeric targets to names: {target_names}"
                )
                name_map = dict(enumerate(target_names))
                # IDs without a matching name keep their original value.
                data["target"] = data["target"].map(name_map).fillna(data["target"])
        elif dataset_key.startswith("shap_"):
            (X, y) = (None, None)
            if dataset_key == "shap_adult":
                (X, y) = shap.datasets.adult(display=True)
            elif dataset_key == "shap_communities":
                (X, y) = shap.datasets.communitiesandcrime()
            elif dataset_key == "shap_corrgroups":
                (X, y) = shap.datasets.corrgroups60()
            elif dataset_key == "shap_imdb":
                (raw_X, y) = shap.datasets.imdb()
                X = pd.DataFrame(raw_X, columns=["text"])
            if not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X)
            data = X.copy()
            target_col_name = "target"
            if isinstance(y, (pd.Series, pd.DataFrame)):
                data[target_col_name] = y.values
            else:
                data[target_col_name] = y
        verbose and logger.info(
            f"Loaded '{human_name}'. Shape: {data.shape[0]} rows, {data.shape[1]} cols."
        )
    except Exception as e:
        verbose and logger.error(f"Failed to load dataset '{human_name}'")
        raise RuntimeError(f"Error loading dataset: {str(e)}") from e
    return data
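The SHAP branch of the function above joins features and labels into a single table. The sketch below reproduces that step on made-up values rather than a real SHAP dataset; the sample texts, labels, and the variable names `y_series` and `indexed` are assumptions for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-ins for what a loader returns (values are illustrative).
raw_X = ["great film", "dull plot"]  # text data arrives as a plain list
y = np.array([True, False])          # labels as a NumPy array

# Wrap the list in a one-column frame, then attach the labels.
X = pd.DataFrame(raw_X, columns=["text"])
data = X.copy()
data["target"] = y.values if isinstance(y, (pd.Series, pd.DataFrame)) else y

# Labels that carry their own index are attached by position via .values,
# so a mismatched index cannot introduce missing values.
y_series = pd.Series([True, False], index=[10, 11])
indexed = X.copy()
indexed["target"] = y_series.values

print(data.shape)
print(indexed["target"].tolist())
```

Assigning `y_series` directly instead of `y_series.values` would align on the index labels, and because 10 and 11 do not occur in the index of `X`, every row would end up with a missing target.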

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • numba>=0.56.0