Group Data

Group data by specified columns and apply aggregation functions (distinct, max, min, avg, median, sum, std, count, first, last, concat).

Processing

This brick groups the input data (DataFrame or Arrow Table) by specified columns and applies various aggregation functions (distinct, max, min, avg, median, sum, std, count, first, last, concat). It validates column types against requested aggregations and supports customizable output formats (Pandas, Polars, or Arrow).
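
For example, grouping a small sales table by "city" with a "sum" aggregation on "sales" produces one row per city, with the aggregated column named after the source column and the function. A minimal sketch with hypothetical data:

import pandas as pd

# Hypothetical input: three sales records across two cities.
df = pd.DataFrame({
    "city": ["Paris", "Paris", "Lyon"],
    "sales": [10, 20, 5],
})

# Grouping by "city" and summing "sales" yields one row per city,
# with the aggregated column named "sales_sum":
#
#     city  sales_sum
#    Paris         30
#     Lyon          5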

Inputs

data
The input dataset (Pandas DataFrame, Polars DataFrame, or Arrow Table) to be grouped and aggregated.

Input Types

Input    Types
data     DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

data
The resulting dataset containing the grouped columns and the aggregated values, returned in the format specified by the output options.

Output Types

Output    Types
data     DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Options

The Group Data brick provides the following configurable options; an example options dictionary follows the list.

Group By Columns
A list of column names used to define the groups. The aggregation functions will be applied within these groups.
Column Aggregations
A list of key-value pairs where the key is the column name to aggregate and the value is the aggregation function to apply (one of distinct, max, min, avg, median, sum, std, count, first, last, concat).
Concatenation Separator
The string separator used when applying the concat aggregation function to string columns.
Output Format
Specifies the desired format for the output data structure (pandas, polars, or arrow). Defaults to pandas.
Safe Mode
If enabled, invalid operations (e.g., column not found, aggregation incompatible with column type) will be logged and skipped instead of raising an immediate error. Defaults to False.
Verbose
If enabled, displays detailed logging information during the execution of the grouping operation. Defaults to True.
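
A minimal sketch of an options dictionary, using hypothetical column names:

options = {
    "group_by_columns": ["city"],
    "aggregations": [
        {"key": "sales", "value": "sum"},
        {"key": "product", "value": "concat"},
    ],
    "concat_separator": " | ",
    "output_format": "polars",
    "safe_mode": True,
    "verbose": False,
}
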
import logging
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
from coded_flows.types import Union, List, DataFrame, ArrowTable, Str

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _sanitize_identifier(identifier):
    """
    Sanitize SQL identifier by escaping special characters.
    """
    return identifier.replace('"', '""')
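
# For example, a column literally named my"col is rewritten to my""col, so the
# quoted identifier "my""col" can be embedded safely in the generated SQL.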


def _is_numeric_type(duckdb_type):
    """
    Check if a DuckDB type is numeric.
    """
    type_lower = duckdb_type.lower()
    return any(
        (
            t in type_lower
            for t in [
                "tinyint",
                "smallint",
                "integer",
                "bigint",
                "int",
                "float",
                "double",
                "real",
                "decimal",
                "numeric",
                "hugeint",
                "ubigint",
                "uinteger",
                "usmallint",
                "utinyint",
            ]
        )
    )


def _is_string_type(duckdb_type):
    """
    Check if a DuckDB type is string/text.
    """
    type_lower = duckdb_type.lower()
    return any(
        (t in type_lower for t in ["varchar", "text", "string", "char", "bpchar"])
    )


def _validate_aggregation_type(
    column_type, aggregation, column_name, safe_mode, verbose, brick_display_name
):
    """
    Validate if an aggregation is compatible with a column type.
    Returns True if valid, False if invalid in safe mode, and raises TypeError otherwise.
    """
    numeric_only = ["avg", "median", "sum", "std"]
    string_only = ["concat"]
    is_numeric = _is_numeric_type(column_type)
    is_string = _is_string_type(column_type)
    if aggregation in numeric_only and (not is_numeric):
        msg = f"Aggregation '{aggregation}' requires numeric column, but '{column_name}' is type '{column_type}'"
        if safe_mode:
            verbose and logger.warning(
                f"[{brick_display_name}] {msg}. Skipping in safe mode."
            )
            return False
        else:
            verbose and logger.error(f"[{brick_display_name}] {msg}.")
            raise TypeError(msg)
    if aggregation in string_only and (not is_string):
        msg = f"Aggregation '{aggregation}' requires string column, but '{column_name}' is type '{column_type}'"
        if safe_mode:
            verbose and logger.warning(
                f"[{brick_display_name}] {msg}. Skipping in safe mode."
            )
            return False
        else:
            verbose and logger.error(f"[{brick_display_name}] {msg}.")
            raise TypeError(msg)
    return True
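
# For example, requesting "avg" on a VARCHAR column returns False after a
# warning in safe mode, and raises TypeError otherwise; "concat" on a numeric
# column is handled the same way.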


def _build_aggregation_expression(column, aggregation, column_type, separator):
    """
    Build SQL aggregation expression for a column.
    """
    sanitized_col = _sanitize_identifier(column)
    # Escape single quotes in the separator so it is safe inside a SQL string literal.
    sanitized_sep = separator.replace("'", "''")
    agg_map = {
        "distinct": f'COUNT(DISTINCT "{sanitized_col}")',
        "max": f'MAX("{sanitized_col}")',
        "min": f'MIN("{sanitized_col}")',
        "avg": f'AVG("{sanitized_col}")',
        "median": f'MEDIAN("{sanitized_col}")',
        "sum": f'SUM("{sanitized_col}")',
        "std": f'STDDEV("{sanitized_col}")',
        "count": f'COUNT("{sanitized_col}")',
        "first": f'FIRST("{sanitized_col}")',
        "last": f'LAST("{sanitized_col}")',
        "concat": f"""STRING_AGG("{sanitized_col}"::VARCHAR, '{sanitized_sep}')""",
    }
    # Unknown aggregation names fall back to a plain COUNT.
    return agg_map.get(aggregation, f'COUNT("{sanitized_col}")')
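
# For example, an aggregation spec of {"key": "sales", "value": "sum"} yields
# the fragment SUM("sales"), which group_data() aliases as "sales_sum" in the
# final SELECT list.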


def group_data(
    data: Union[DataFrame, ArrowTable], options=None
) -> Union[DataFrame, ArrowTable]:
    brick_display_name = "Group Data"
    options = options or {}
    verbose = options.get("verbose", True)
    group_by_columns = options.get("group_by_columns", [])
    aggregations = options.get("aggregations", [])
    concat_separator = options.get("concat_separator", ", ")
    output_format = options.get("output_format", "pandas")
    safe_mode = options.get("safe_mode", False)
    result = None
    conn = None
    try:
        if not isinstance(group_by_columns, list):
            verbose and logger.error(
                f"[{brick_display_name}] Invalid group_by_columns format! Expected a list."
            )
            raise ValueError("Group by columns must be provided as a list!")
        if not group_by_columns:
            verbose and logger.error(
                f"[{brick_display_name}] No grouping columns specified!"
            )
            raise ValueError("At least one grouping column must be specified!")
        if not isinstance(aggregations, list):
            verbose and logger.error(
                f"[{brick_display_name}] Invalid aggregations format! Expected a list."
            )
            raise ValueError("Aggregations must be provided as a list!")
        if not aggregations:
            verbose and logger.error(
                f"[{brick_display_name}] No aggregations specified!"
            )
            raise ValueError("At least one aggregation must be specified!")
        verbose and logger.info(
            f"[{brick_display_name}] Starting grouping operation with {len(group_by_columns)} grouping column(s) and {len(aggregations)} aggregation(s)."
        )
        data_type = None
        if isinstance(data, pd.DataFrame):
            data_type = "pandas"
        elif isinstance(data, pl.DataFrame):
            data_type = "polars"
        elif isinstance(data, pa.Table):
            data_type = "arrow"
        if data_type is None:
            verbose and logger.error(
                f"[{brick_display_name}] Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table"
            )
            raise ValueError(
                "Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table"
            )
        verbose and logger.info(
            f"[{brick_display_name}] Detected input format: {data_type}."
        )
        conn = duckdb.connect(":memory:")
        conn.register("input_table", data)
        column_info = conn.execute("DESCRIBE input_table").fetchall()
        all_columns = {col[0]: col[1] for col in column_info}
        verbose and logger.info(
            f"[{brick_display_name}] Total columns in data: {len(all_columns)}."
        )
        missing_group_cols = [col for col in group_by_columns if col not in all_columns]
        if missing_group_cols:
            if safe_mode:
                group_by_columns = [
                    col for col in group_by_columns if col in all_columns
                ]
                verbose and logger.warning(
                    f"[{brick_display_name}] Safe mode: Skipped missing group by columns: {missing_group_cols}"
                )
                if not group_by_columns:
                    verbose and logger.error(
                        f"[{brick_display_name}] No valid group by columns remaining after filtering!"
                    )
                    raise ValueError("No valid group by columns found!")
            else:
                verbose and logger.error(
                    f"[{brick_display_name}] Group by columns not found in data: {missing_group_cols}"
                )
                raise ValueError(
                    f"Group by columns not found in data: {missing_group_cols}"
                )
        group_by_parts = []
        for col in group_by_columns:
            sanitized_col = _sanitize_identifier(col)
            group_by_parts.append(f'"{sanitized_col}"')
        agg_expressions = []
        skipped_aggs = []
        for agg_spec in aggregations:
            if (
                not isinstance(agg_spec, dict)
                or "key" not in agg_spec
                or "value" not in agg_spec
            ):
                verbose and logger.warning(
                    f"[{brick_display_name}] Invalid aggregation specification: {agg_spec}. Skipping."
                )
                continue
            column_name = agg_spec["key"]
            agg_function = agg_spec["value"]
            if column_name not in all_columns:
                if safe_mode:
                    verbose and logger.warning(
                        f"[{brick_display_name}] Safe mode: Column '{column_name}' not found. Skipping aggregation."
                    )
                    skipped_aggs.append(f"{column_name}_{agg_function}")
                    continue
                else:
                    verbose and logger.error(
                        f"[{brick_display_name}] Column '{column_name}' not found in data!"
                    )
                    raise ValueError(f"Column '{column_name}' not found in data!")
            column_type = all_columns[column_name]
            if not _validate_aggregation_type(
                column_type,
                agg_function,
                column_name,
                safe_mode,
                verbose,
                brick_display_name,
            ):
                skipped_aggs.append(f"{column_name}_{agg_function}")
                continue
            agg_expr = _build_aggregation_expression(
                column_name, agg_function, column_type, concat_separator
            )
            output_col_name = f"{column_name}_{agg_function}"
            sanitized_output = _sanitize_identifier(output_col_name)
            agg_expressions.append(f'{agg_expr} AS "{sanitized_output}"')
            verbose and logger.info(
                f"[{brick_display_name}] Added aggregation: {agg_function} on column '{column_name}'."
            )
        if not agg_expressions:
            verbose and logger.error(
                f"[{brick_display_name}] No valid aggregations to perform after validation!"
            )
            raise ValueError("No valid aggregations to perform!")
        if skipped_aggs:
            verbose and logger.warning(
                f"[{brick_display_name}] Skipped {len(skipped_aggs)} incompatible aggregations: {skipped_aggs}"
            )
        select_parts = group_by_parts + agg_expressions
        select_clause = ", ".join(select_parts)
        group_by_clause = ", ".join(group_by_parts)
        query = f"SELECT {select_clause} FROM input_table GROUP BY {group_by_clause}"
        verbose and logger.info(
            f"[{brick_display_name}] Executing grouping query with {len(group_by_columns)} group column(s) and {len(agg_expressions)} aggregation(s)."
        )
        if output_format == "pandas":
            result = conn.execute(query).df()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to pandas DataFrame."
            )
        elif output_format == "polars":
            result = conn.execute(query).pl()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Polars DataFrame."
            )
        elif output_format == "arrow":
            result = conn.execute(query).fetch_arrow_table()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Arrow Table."
            )
        else:
            verbose and logger.error(
                f"[{brick_display_name}] Unsupported output format: {output_format}"
            )
            raise ValueError(f"Unsupported output format: {output_format}")
        verbose and logger.info(
            f"[{brick_display_name}] Grouping operation completed successfully. Result has {len(result)} row(s)."
        )
    except Exception as e:
        verbose and logger.error(
            f"[{brick_display_name}] Error during grouping operation: {str(e)}"
        )
        raise
    finally:
        if conn is not None:
            conn.close()
    return result
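
A minimal end-to-end sketch of calling the brick, with hypothetical data:

import pandas as pd

df = pd.DataFrame({"city": ["Paris", "Paris", "Lyon"], "sales": [10, 20, 5]})
result = group_data(
    df,
    options={
        "group_by_columns": ["city"],
        "aggregations": [{"key": "sales", "value": "sum"}],
    },
)
# With the default "pandas" output format, result is a DataFrame with
# columns "city" and "sales_sum".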

Brick Info

version v0.1.3
python 3.10, 3.11, 3.12, 3.13
requirements
  • pandas
  • polars[pyarrow]
  • duckdb
  • pyarrow