def _quote_identifier(name: str) -> str:
    """Return *name* wrapped in backticks so Spark reads it as one identifier.

    Without quoting, Spark parses a ``.`` inside a column name as nested field
    access. A backtick inside the name is escaped by doubling it.
    """
    return "`" + name.replace("`", "``") + "`"


def _complex_fields(schema: StructType) -> dict[str, DataType]:
    """Map each top-level struct, array or map column in *schema* to its data type."""
    return {
        field.name: field.dataType
        for field in schema.fields
        if isinstance(field.dataType, (StructType, ArrayType, MapType))
    }


def flatten_dataframe(df: DataFrame, sep: str = "_") -> DataFrame:
    """Flatten every struct, array and map column of *df* into plain columns.

    Complex columns are processed one at a time until none remain:

    * struct: each field becomes a column named ``<column><sep><field>``,
      appended after the existing columns, and the struct column is dropped;
    * array: the column is replaced by ``explode_outer`` of itself;
    * map: each distinct non-null key becomes a column named
      ``<column><sep><key>``, appended after the existing columns, and the
      map column is dropped.

    The map case runs ``collect()`` to discover the keys, so it triggers a
    Spark job. Generated names are not checked against existing column names.

    Args:
        df: DataFrame to flatten.
        sep: Separator placed between the parent column name and the
            field/key name.

    Returns:
        A DataFrame whose schema has no struct, array or map columns.

    Raises:
        ValueError: If *sep* is ``"."``.
    """
    if sep == ".":
        raise ValueError("`sep` must not be '.', it conflicts with Spark's struct field accessor")

    remaining = _complex_fields(df.schema)
    while remaining:
        col_name, dtype = next(iter(remaining.items()))
        # Always reference the column through a quoted identifier: names that
        # contain '.' would otherwise be parsed as nested field access.
        quoted_name = _quote_identifier(col_name)
        source = col(quoted_name)

        if isinstance(dtype, StructType):
            expanded = [
                col(f"{quoted_name}.{_quote_identifier(f.name)}").alias(f"{col_name}{sep}{f.name}")
                for f in dtype.fields
            ]
            df = df.select("*", *expanded).drop(col_name)

        elif isinstance(dtype, ArrayType):
            df = df.withColumn(col_name, explode_outer(source))

        elif isinstance(dtype, MapType):
            keys_rows = df.select(explode_outer(map_keys(source)).alias("k")).distinct().collect()
            # The row order of distinct().collect() is not guaranteed; sort by
            # the key's string form so the output column order is deterministic.
            keys = sorted((row["k"] for row in keys_rows if row["k"] is not None), key=str)
            kept_cols = [col(_quote_identifier(c)) for c in df.columns if c != col_name]
            key_cols = [source.getItem(k).alias(f"{col_name}{sep}{k}") for k in keys]
            df = df.select(*kept_cols, *key_cols)

        # Expanding one column may expose new complex columns (nested data).
        remaining = _complex_fields(df.schema)

    return df