Scikit Object Load

Restores a Scikit-learn object from a ZIP archive or directory created by the Save Scikit Object brick.

Scikit Object Load

Processing

This brick restores a Scikit-learn machine learning model or data transformer that was previously saved using the "Save Scikit Object" brick. It reads a ZIP archive or a directory, extracts the necessary files, and reconstructs the object (like a trained classifier or a data scaler) exactly as it was when saved. This allows you to reuse trained models in different workflows without retraining them.

Inputs

path
The location of the saved Scikit-learn object. This must be a path to a .zip file or a directory that contains the required metadata.json and object artifacts (typically created by the "Save Scikit Object" brick).

Input Types

Input Types
path Str, FilePath, DirectoryPath
options Dict

You can check the list of supported types here: Available Type Hints.

Outputs

loaded object
The fully restored Scikit-learn object.

Output Types

Output Types
loaded object Any

You can check the list of supported types here: Available Type Hints.

Options

The Scikit Object Load brick has the following configurable option:

Verbose
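Turns logging on or off during the loading process. When enabled (the default), the brick logs its progress, any warnings, and errors; when disabled, it runs silently and only raises exceptions on failure.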
import json
import joblib
import pickle
import zipfile
import logging
import pathlib
import tempfile
from coded_flows.types import Union, Str, Any, FilePath, DirectoryPath, Dict
from coded_flows.utils import CodedFlowsLogger

# Shared logger for this brick, configured with level logging.INFO.
logger = CodedFlowsLogger(level=logging.INFO, name="Scikit Object Load")


def _validate_scikit_object(obj: Any) -> bool:
    """
    Validates if the loaded object is a recognizable Scikit-learn object
    (either from the sklearn module or inheriting from BaseEstimator).
    """
    obj_type = type(obj)
    module_path = obj_type.__module__
    if "sklearn" in module_path:
        return True
    try:
        from sklearn.base import BaseEstimator

        if isinstance(obj, BaseEstimator):
            return True
    except ImportError:
        pass
    return False


def _read_metadata(work_dir: pathlib.Path) -> dict:
    """
    Read ``work_dir/metadata.json`` and check that its top level is a JSON object.

    Raises:
        FileNotFoundError: if metadata.json is absent.
        ValueError: if the file is not valid JSON (json.JSONDecodeError is a
            ValueError) or its top level is not a JSON object.
    """
    metadata_path = work_dir / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            "metadata.json not found. The provided path is not a valid Scikit Object archive."
        )
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    if not isinstance(metadata, dict):
        raise ValueError("metadata.json must contain a JSON object at its top level.")
    return metadata


def _resolve_artifact_path(work_dir: pathlib.Path, object_filename: Any) -> pathlib.Path:
    """
    Map the artifact name recorded in metadata.json to a path inside ``work_dir``.

    Raises:
        ValueError: if the name is missing or empty, is not a string, is
            absolute or drive-anchored, or contains a ``..`` component.
    """
    if not object_filename:
        raise ValueError("Metadata is missing artifact filename information.")
    if not isinstance(object_filename, str):
        raise ValueError(
            f"Metadata artifact filename must be a string, got {type(object_filename).__name__}."
        )
    relative = pathlib.PurePath(object_filename)
    # `work_dir / "/abs/path"` silently discards work_dir, and ".." can climb out
    # of it. Either would make the loader deserialize a file outside the archive.
    if relative.anchor or ".." in relative.parts:
        raise ValueError(
            f"Object artifact '{object_filename}' must be a relative path inside the archive."
        )
    return work_dir / relative


def _load_from_work_dir(work_dir: pathlib.Path, verbose: bool) -> Any:
    """
    Restore and validate the object stored in an unpacked archive directory.

    ``work_dir`` must hold ``metadata.json`` plus the artifact it names. The
    ``serialization_backend`` entry ("joblib" by default, or "pickle") selects
    the loader.
    """
    metadata = _read_metadata(work_dir)
    class_name = metadata.get("class_name", "Unknown")
    library = metadata.get("library", "Unknown")
    backend = metadata.get("serialization_backend", "joblib")
    # Treat an explicit `"artifacts": null` like a missing entry.
    artifacts = metadata.get("artifacts") or {}
    verbose and logger.info(
        f"Metadata loaded. Library: {library}, Class: {class_name}, Backend: {backend}"
    )
    if library != "scikit-learn":
        verbose and logger.warning(
            f"Metadata indicates library is '{library}', but this brick expects 'scikit-learn'."
        )
    if not isinstance(artifacts, dict):
        raise ValueError("Metadata 'artifacts' entry must be a JSON object.")
    object_filename = artifacts.get("object")
    object_file_path = _resolve_artifact_path(work_dir, object_filename)
    if not object_file_path.exists():
        raise FileNotFoundError(
            f"Object artifact '{object_filename}' missing from archive."
        )
    # SECURITY: joblib.load and pickle.load can both execute arbitrary code
    # embedded in the file. Only load archives that come from a trusted source.
    if backend == "joblib":
        loaded_object = joblib.load(object_file_path)
    elif backend == "pickle":
        with open(object_file_path, "rb") as f:
            loaded_object = pickle.load(f)
    else:
        raise ValueError(
            f"Unsupported serialization backend specified in metadata: {backend}"
        )
    if not _validate_scikit_object(loaded_object):
        type_name = type(loaded_object).__name__
        raise ValueError(
            f"The loaded object '{type_name}' does not appear to be a valid Scikit-learn estimator or transformer."
        )
    verbose and logger.info(
        f"Successfully loaded and validated Scikit-learn object: {type(loaded_object).__name__}"
    )
    return loaded_object


def load_scikit_object(
    path: Union[Str, FilePath, DirectoryPath], options: Dict = None
) -> Any:
    """
    Load a Scikit-learn object previously written by the Save Scikit Object brick.

    ``path`` may point to a ``.zip`` archive (the suffix is matched
    case-insensitively), which is extracted into a temporary directory that is
    removed afterwards. It may instead point to a directory that already holds
    the archive contents; that directory is read in place. Either way the
    directory must contain ``metadata.json``. Its ``serialization_backend``
    ("joblib" by default, or "pickle") selects the loader, and its
    ``artifacts.object`` names the serialized file relative to that directory.

    Args:
        path: ZIP archive or directory to load from.
        options: Optional settings. ``verbose`` (default True) enables
            info/warning/error logging; when False nothing is logged.

    Returns:
        The deserialized object, after checking that it looks like a
        Scikit-learn object.

    Raises:
        ValueError: if ``path`` is empty, or if it is neither a directory nor a
            ``.zip`` file. Also raised for malformed metadata, an unsafe or
            missing artifact name, an unknown backend, or a loaded object that
            does not look like a Scikit-learn object.
        FileNotFoundError: if ``path``, ``metadata.json`` or the artifact file
            does not exist.
        Other exceptions from extraction or deserialization (e.g.
        ``zipfile.BadZipFile``) propagate unchanged.

    Security:
        joblib/pickle deserialization can execute arbitrary code. Only load
        archives from trusted sources.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    if not path:
        verbose and logger.error("No path provided.")
        raise ValueError("Path to object (ZIP or Directory) is required.")
    input_path = pathlib.Path(path)
    if not input_path.exists():
        verbose and logger.error(f"Path does not exist: {input_path}")
        raise FileNotFoundError(f"Object path not found: {input_path}")
    verbose and logger.info(f"Attempting to load object from: {input_path}")
    try:
        if input_path.is_file() and input_path.suffix.lower() == ".zip":
            verbose and logger.info("Detected ZIP archive. Extracting...")
            # A temporary directory is only needed when there is something to
            # extract. It is cleaned up as soon as the object is in memory.
            with tempfile.TemporaryDirectory() as temp_dir_str:
                work_dir = pathlib.Path(temp_dir_str)
                with zipfile.ZipFile(input_path, "r") as zf:
                    zf.extractall(work_dir)
                loaded_object = _load_from_work_dir(work_dir, verbose)
        elif input_path.is_dir():
            loaded_object = _load_from_work_dir(input_path, verbose)
        else:
            raise ValueError(
                f"Unsupported input type: {input_path}. Must be a directory or .zip file."
            )
    except Exception as e:
        verbose and logger.error(f"Failed to load object: {e}")
        # Bare `raise` re-raises the active exception with its original traceback.
        raise
    return loaded_object

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • joblib
  • numba>=0.56.0
  • scikit-learn