Scikit Object Save
Saves Scikit-learn objects (encoders, scalers, pipelines) with metadata into a portable archive.
Processing
This brick saves your trained Scikit-learn models, preprocessors, or pipelines into a portable file or folder. It answers the question "how do I keep this model to use later?" by packaging the serialized object together with a `metadata.json` file.
The brick accepts only Scikit-learn objects (estimators, transformers, and pipelines) and raises an error for anything else. With Auto-Increment Version enabled (the default), it appends a version suffix (`_v1`, `_v2`, ...) so earlier saves are not overwritten. The result is a standardized archive containing the serialized object and a `metadata.json` file that records the environment in which it was created: the scikit-learn and Python versions, the object's class and module, the serialization backend, and a timestamp.
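Because `metadata.json` records the scikit-learn and Python versions at save time, you can compare them with the current environment before deserializing anything. The sketch below assumes a `.zip` produced with Output Format set to zip and relies only on the metadata keys the brick writes (`library_version`, `python_version`); `read_archive_metadata` and `environment_mismatches` are hypothetical helper names, not part of the brick.

```python
import json
import sys
import zipfile
import importlib.metadata


def read_archive_metadata(zip_path) -> dict:
    """Read metadata.json from a saved .zip archive without deserializing the object."""
    with zipfile.ZipFile(zip_path) as zf:
        return json.loads(zf.read("metadata.json").decode("utf-8"))


def environment_mismatches(metadata: dict) -> dict:
    """Return {field: (saved_value, current_value)} for each version that differs."""
    try:
        current_sklearn = importlib.metadata.version("scikit-learn")
    except importlib.metadata.PackageNotFoundError:
        current_sklearn = "unknown"  # same fallback value the brick uses
    current = {
        "library_version": current_sklearn,
        # Same format the brick writes, e.g. "3.11.4"
        "python_version": sys.version.split()[0],
    }
    return {
        key: (metadata.get(key), value)
        for key, value in current.items()
        if metadata.get(key) != value
    }
```

A non-empty result is a signal to load with care: scikit-learn does not guarantee that an object serialized under one version loads correctly under another.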
Inputs
- element to save
- The Scikit-learn object you want to store. This is usually a trained model (like a Random Forest) or a fitted data transformer (like a Standard Scaler). Before saving, the brick checks that the object comes from a `sklearn` module or subclasses `BaseEstimator`, and raises an error otherwise (including when the input is `None`).
- directory
- The destination folder where the file or archive will be created. If it does not exist, it is created along with any missing parent folders.
Inputs Types
| Input | Types |
|---|---|
| element to save | Any |
| directory | Str, DirectoryPath |
You can check the list of supported types here: Available Type Hints.
Outputs
- file path
- The full path to the saved result. Depending on the Output Format option, this points to either a `.zip` file or a folder containing the model artifacts.
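Whichever format you choose, the returned path can be loaded back by reading `metadata.json` and then the object file it names. The sketch below assumes the layout the brick writes (`metadata.json` plus the file listed under `metadata["artifacts"]["object"]`, i.e. `object.joblib` or `object.pkl`); `load_saved_object` is a hypothetical helper, not part of the brick.

```python
import io
import json
import pickle
import zipfile
from pathlib import Path
from typing import Any, Tuple


def load_saved_object(path) -> Tuple[Any, dict]:
    """Load an object from a brick-style .zip archive or output folder.

    Only load files you trust: unpickling can execute arbitrary code.
    """
    path = Path(path)
    if path.is_dir():
        # Output Format = directory: loose files inside a folder
        metadata = json.loads((path / "metadata.json").read_text(encoding="utf-8"))
        raw = (path / metadata["artifacts"]["object"]).read_bytes()
    else:
        # Output Format = zip: read both members from the archive
        with zipfile.ZipFile(path) as zf:
            metadata = json.loads(zf.read("metadata.json").decode("utf-8"))
            raw = zf.read(metadata["artifacts"]["object"])

    backend = metadata["serialization_backend"]
    if backend == "pickle":
        obj = pickle.loads(raw)
    elif backend == "joblib":
        import joblib  # only required for joblib-serialized objects

        obj = joblib.load(io.BytesIO(raw))
    else:
        raise ValueError(f"Unsupported serialization backend: {backend}")
    return obj, metadata
```

Reading the backend from the metadata, rather than guessing from the file extension, keeps the loader working if the artifact name ever changes.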
Outputs Types
| Output | Types |
|---|---|
| file path | Str, FilePath |
You can check the list of supported types here: Available Type Hints.
Options
The Scikit Object Save brick contains some changeable options:
- Filename Prefix
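- A custom text label added to the beginning of the filename (default is "sk_object"). This helps you identify what the model is (e.g., "customer_churn_model"). The object's class name in snake_case follows the prefix, so the default prefix with a StandardScaler gives "sk_object_standard_scaler".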
- Serialization Backend
- The technical method used to convert the object into a file.
- joblib: (Default) Efficient for objects containing large NumPy arrays. Best for most Scikit-learn models.
- pickle: The standard Python object serialization method. The saved file can be loaded with the standard-library `pickle` module alone, without joblib.
- Include Date (YYYYMMDD)
- If enabled, adds the current date to the filename (e.g., `_20231025`). Useful for tracking when a model was saved.
- Include Time (HHMMSS)
- If enabled, adds the current time to the filename (e.g., `_143005`), placed after the date when both are enabled.
- Auto-Increment Version
- Prevents overwriting files. If enabled, the brick scans the directory for existing saves with the same base name and appends the next version number (`_v1`, `_v2`, etc.), one higher than the highest it finds. If disabled, an existing archive or folder with the same name is replaced.
- Output Format
- Determines how the data is packaged.
- zip: Compresses the model and metadata into a single `.zip` file. Best for sharing or downloading.
- directory: Saves the model and metadata as loose files inside a new folder. Best if you need immediate access to the internal files.
- Return as Path Object
- If enabled, returns the output as a Python `pathlib.Path` object instead of a standard string. Keep this disabled unless your workflow specifically requires Path objects.
- Verbose
- If enabled, logs about the saving process (detected object type, directory creation, output path, and errors) are printed to the console.
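The naming options combine in a fixed order: prefix, the object's class name in snake_case, then the optional date and time, then the version suffix. The sketch below reproduces that order; `build_name` is a hypothetical helper that takes the current time and the names already present in the directory as arguments so the result is deterministic.

```python
import datetime
import re
from typing import Iterable, Optional


def to_snake_case(name: str) -> str:
    """CamelCase -> snake_case, using the same two regexes as the brick."""
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name).lower()


def build_name(
    class_name: str,
    existing: Iterable[str],
    prefix: str = "sk_object",
    include_date: bool = False,
    include_time: bool = False,
    use_versioning: bool = True,
    archive_format: str = "zip",
    now: Optional[datetime.datetime] = None,
) -> str:
    now = now or datetime.datetime.now()
    parts = [prefix, to_snake_case(class_name)]
    if include_date:
        parts.append(now.strftime("%Y%m%d"))
    if include_time:
        parts.append(now.strftime("%H%M%S"))
    base = "_".join(filter(None, parts))
    ext = ".zip" if archive_format == "zip" else ""
    if use_versioning:
        # Next version = highest existing _vN for this exact base name, plus one.
        pattern = re.compile(rf"^{re.escape(base)}_v(\d+){re.escape(ext)}$")
        versions = [int(m.group(1)) for n in existing if (m := pattern.match(n))]
        base = f"{base}_v{max(versions, default=0) + 1}"
    return base + ext
```

Two consequences follow from this order: gaps are not filled (after `_v1` and `_v3`, the next save is `_v4`), and enabling the date or time usually restarts numbering at `_v1`, because the base name changes.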
```python
import re
import sys
import json
import shutil
import joblib
import pickle
import zipfile
import logging
import pathlib
import tempfile
import datetime
import importlib.metadata
from coded_flows.types import Union, Str, Any, FilePath, DirectoryPath
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Scikit Object Save", level=logging.INFO)


def _to_snake_case(name: str) -> str:
    """Converts CamelCase to snake_case for standardized naming."""
    name = re.sub("(.)([A-Z][a-z]+)", "\\1_\\2", name)
    return re.sub("([a-z0-9])([A-Z])", "\\1_\\2", name).lower()


def _get_library_version(library_name: str) -> str:
    """Retrieves the installed version of the library safely."""
    try:
        return importlib.metadata.version(library_name)
    except importlib.metadata.PackageNotFoundError:
        return "unknown"


def _detect_object_context(obj: Any) -> dict:
    """
    Analyzes the object to extract context for metadata.
    Strictly enforces Scikit-learn objects only.
    """
    obj_type = type(obj)
    module_path = obj_type.__module__
    class_name = obj_type.__name__
    snake_name = _to_snake_case(class_name)

    # Accept objects defined in an sklearn module, or any BaseEstimator subclass.
    is_sklearn = "sklearn" in module_path
    if not is_sklearn:
        try:
            from sklearn.base import BaseEstimator

            if isinstance(obj, BaseEstimator):
                is_sklearn = True
        except ImportError:
            pass
    if not is_sklearn:
        raise ValueError(
            f"The object '{class_name}' (from module '{module_path}') is not recognized as a valid Scikit-learn object. This function only supports Scikit-learn estimators, transformers, and pipelines."
        )

    version = _get_library_version("scikit-learn")
    return {
        "class_name": class_name,
        "module_path": module_path,
        "base_name": snake_name,
        "library": "scikit-learn",
        "library_version": version,
    }


def _get_next_version_index(
    directory: pathlib.Path, base_name: str, is_archive: bool
) -> int:
    """Finds the next integer version based on existing files in the directory."""
    ext_pattern = "\\.zip" if is_archive else ""
    pattern = re.compile(f"^{re.escape(base_name)}_v(\\d+){ext_pattern}$")
    max_v = 0
    if directory.exists():
        for item in directory.iterdir():
            match = pattern.match(item.name)
            if match:
                current_v = int(match.group(1))
                if current_v > max_v:
                    max_v = current_v
    return max_v + 1


def save_scikit_object(
    element_to_save: Any, directory: Union[Str, DirectoryPath], options: dict = None
) -> Union[Str, FilePath]:
    options = options or {}
    verbose = options.get("verbose", True)
    custom_prefix = options.get("custom_prefix", "sk_object")
    backend = options.get("backend", "joblib")
    include_date = options.get("include_date", False)
    include_time = options.get("include_time", False)
    use_versioning = options.get("use_versioning", True)
    archive_format = options.get("archive_format", "zip")
    return_as_pathlib = options.get("return_as_pathlib", False)
    save_dir = pathlib.Path(directory)
    result_path = None

    if element_to_save is None:
        verbose and logger.error("Input element is None.")
        raise ValueError("The input element to save is None.")

    # Artifacts are staged in a temporary directory, then zipped or copied out.
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = pathlib.Path(temp_dir_str)
        try:
            if not save_dir.exists():
                verbose and logger.info(f"Creating output directory: {save_dir}")
                save_dir.mkdir(parents=True, exist_ok=True)

            context = _detect_object_context(element_to_save)
            verbose and logger.info(
                f"Detected {context['library']} object: {context['class_name']}"
            )

            # Base name: prefix + snake_case class name [+ date] [+ time]
            base_name_parts = [custom_prefix, context["base_name"]]
            now = datetime.datetime.now()
            if include_date:
                base_name_parts.append(now.strftime("%Y%m%d"))
            if include_time:
                base_name_parts.append(now.strftime("%H%M%S"))
            clean_base_name = "_".join(filter(None, base_name_parts))

            version_str = ""
            if use_versioning:
                next_v = _get_next_version_index(
                    save_dir, clean_base_name, archive_format == "zip"
                )
                version_str = f"_v{next_v}"
            final_name = f"{clean_base_name}{version_str}"

            # Serialize the object with the selected backend.
            object_ext = ".joblib" if backend == "joblib" else ".pkl"
            object_filename = f"object{object_ext}"
            object_temp_path = temp_dir / object_filename
            if backend == "joblib":
                joblib.dump(element_to_save, object_temp_path)
            elif backend == "pickle":
                with open(object_temp_path, "wb") as f:
                    pickle.dump(element_to_save, f)
            else:
                raise ValueError(f"Unsupported backend: {backend}")

            metadata = {
                "id": final_name,
                "timestamp": now.isoformat(),
                "library": context["library"],
                "library_version": context["library_version"],
                "class_name": context["class_name"],
                "module_path": context["module_path"],
                "serialization_backend": backend,
                "python_version": sys.version.split()[0],
                "artifacts": {"object": object_filename},
            }
            metadata_path = temp_dir / "metadata.json"
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=4)
            verbose and logger.info("Object artifacts and metadata prepared.")

            # Zip mode overwrites an archive of the same name;
            # directory mode replaces an existing folder of the same name.
            if archive_format == "zip":
                output_zip_path = save_dir / f"{final_name}.zip"
                with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                    zf.write(object_temp_path, arcname=object_filename)
                    zf.write(metadata_path, arcname="metadata.json")
                result_path = output_zip_path
                verbose and logger.info(f"Saved portable archive to: {result_path}")
            else:
                output_folder_path = save_dir / final_name
                if output_folder_path.exists():
                    shutil.rmtree(output_folder_path)
                shutil.copytree(temp_dir, output_folder_path)
                result_path = output_folder_path
                verbose and logger.info(f"Saved object directory to: {result_path}")
        except Exception as e:
            verbose and logger.error(f"Failed to save object: {e}")
            raise

    file_path = result_path if return_as_pathlib else str(result_path)
    return file_path
```
Brick Info
- shap>=0.47.0
- joblib
- numba>=0.56.0
- scikit-learn