From 1b35fe438429a6a8ab3f21e0dbe61e69f81ca6f4 Mon Sep 17 00:00:00 2001 From: Mister Lobster Date: Sun, 3 May 2026 08:34:50 -0400 Subject: [PATCH] feat: add flatten_dataframe function for nested DataFrames Add flatten_dataframe function to recursively flatten nested structures in DataFrames, including StructType, ArrayType, and MapType columns. - Added flatten_dataframe function in dataframe_transformer.py - Function supports custom separator for flattened column names - Handles StructType by expanding sub-elements to columns - Handles ArrayType by exploding arrays to rows - Handles MapType by extracting all keys as columns - Added comprehensive test suite with 12 test cases - Added function to public API in __init__.py Closes #47 --- chispa/__init__.py | 2 + chispa/dataframe_transformer.py | 119 +++++++++++++++++ tests/test_dataframe_transformer.py | 195 ++++++++++++++++++++++++++++ 3 files changed, 316 insertions(+) create mode 100644 chispa/dataframe_transformer.py create mode 100644 tests/test_dataframe_transformer.py diff --git a/chispa/__init__.py b/chispa/__init__.py index 57b35c6..b3be133 100644 --- a/chispa/__init__.py +++ b/chispa/__init__.py @@ -4,6 +4,7 @@ from pyspark.sql import DataFrame +from chispa.dataframe_transformer import flatten_dataframe from chispa.default_formats import DefaultFormats from chispa.formatting import Color, Format, FormattingConfig, Style @@ -73,4 +74,5 @@ def assert_df_equality( "assert_basic_rows_equality", "assert_column_equality", "assert_df_equality", + "flatten_dataframe", ) diff --git a/chispa/dataframe_transformer.py b/chispa/dataframe_transformer.py new file mode 100644 index 0000000..5a02772 --- /dev/null +++ b/chispa/dataframe_transformer.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from pyspark.sql import Column, DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import ArrayType, MapType, StructType + + +def flatten_dataframe(df: DataFrame, sep: str = "_") -> DataFrame: + """Flatten a nested DataFrame by expanding StructType, ArrayType, and MapType columns. + + This function recursively flattens nested structures in a DataFrame, converting + complex types into flat columns with names separated by the specified delimiter. + + Parameters + ---------- + df : DataFrame + The input DataFrame to flatten. + sep : str, optional + Delimiter for flattened column names. Default is "_". + Note: Do not use "." as the separator, as it won't work correctly with + nested DataFrames with more than one level. + + Returns + ------- + DataFrame + A flattened DataFrame with all nested structures expanded. + + Notes + ----- + - StructType fields are expanded into individual columns + - ArrayType fields are exploded to add array elements as rows + - MapType fields are expanded by extracting all keys as columns + - Flattening MapType requires finding every key in the column, which can be slow + - The function processes fields iteratively until no complex types remain + + Examples + -------- + Flatten a DataFrame with nested struct fields: + + >>> data = [ + ... {"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}}, + ... {"id": 2, "name": "Faye", "fitness": {"height": 130, "weight": 60}}, + ... ] + >>> df = spark.createDataFrame(data) + >>> flat_df = flatten_dataframe(df, sep=":") + >>> flat_df.columns + ['id', 'name', 'fitness:height', 'fitness:weight'] + + Flatten a DataFrame with map fields: + + >>> data = [ + ... {"state": "Florida", "info": {"governor": "Rick Scott"}}, + ... {"state": "Ohio", "info": {"governor": "John Kasich"}}, + ... ] + >>> df = spark.createDataFrame(data) + >>> flat_df = flatten_dataframe(df, sep=":") + >>> flat_df.columns + ['state', 'info:governor'] + + Flatten a DataFrame with array fields: + + >>> data = [ + ... {"name": "John", "scores": [85, 90, 95]}, + ... {"name": "Jane", "scores": [88, 92, 94]}, + ... ] + >>> df = spark.createDataFrame(data) + >>> flat_df = flatten_dataframe(df) + >>> "scores" in flat_df.columns + False + """ + # Compute Complex Fields (Arrays, Structs and MapTypes) in Schema + complex_fields: dict[str, StructType | ArrayType | MapType] = dict( + [ + (field.name, field.dataType) + for field in df.schema.fields + if isinstance(field.dataType, ArrayType | StructType | MapType) + ] + ) + + while len(complex_fields) != 0: + col_name = list(complex_fields.keys())[0] + + # If StructType then convert all sub-element to columns. + # i.e. flatten structs + if isinstance(complex_fields[col_name], StructType): + expanded = [ + F.col(col_name + "." + k).alias(col_name + sep + k) + for k in [n.name for n in complex_fields[col_name]] + ] + df = df.select("*", *expanded).drop(col_name) + + # If ArrayType then add the Array Elements as Rows using the explode function + # i.e. explode Arrays + elif isinstance(complex_fields[col_name], ArrayType): + df = df.withColumn(col_name, F.explode_outer(col_name)) + + # If MapType then convert all sub-element to columns. + # i.e. flatten maps + elif isinstance(complex_fields[col_name], MapType): + keys_df = df.select(F.explode_outer(F.map_keys(F.col(col_name)))).distinct() + keys = [row[0] for row in keys_df.collect()] + key_cols = [ + F.col(col_name).getItem(f).alias(str(col_name + sep + f)) for f in keys + ] + drop_column_list = [col_name] + df = df.select( + [c for c in df.columns if c not in drop_column_list] + key_cols + ) + + # Recompute remaining Complex Fields in Schema + complex_fields = dict( + [ + (field.name, field.dataType) + for field in df.schema.fields + if isinstance(field.dataType, ArrayType | StructType | MapType) + ] + ) + + return df diff --git a/tests/test_dataframe_transformer.py b/tests/test_dataframe_transformer.py new file mode 100644 index 0000000..0636653 --- /dev/null +++ b/tests/test_dataframe_transformer.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import ArrayType, IntegerType, LongType, MapType, StringType, StructField, StructType + +from chispa.dataframe_transformer import flatten_dataframe + + +def describe_flatten_dataframe(): + def it_flattens_struct_fields(spark: SparkSession): + data = [ + {"id": 1, "name": "Cole Volk", "fitness": {"height": 130, "weight": 60}}, + {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}}, + {"id": 2, "name": "Faye Raker", "fitness": {"height": 130, "weight": 60}}, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df, sep=":") + + assert "fitness:height" in flat_df.columns + assert "fitness:weight" in flat_df.columns + assert "fitness" not in flat_df.columns + + def it_flattens_map_fields(spark: SparkSession): + data = [ + {"state": "Florida", "shortname": "FL", "info": {"governor": "Rick Scott"}}, + {"state": "Ohio", "shortname": "OH", "info": {"governor": "John Kasich"}}, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df, sep=":") + + assert "info:governor" in flat_df.columns + assert "info" not in flat_df.columns + assert "state" in flat_df.columns + assert "shortname" in flat_df.columns + + def it_flattens_array_fields(spark: SparkSession): + data = [ + {"name": "John", "scores": [85, 90, 95]}, + {"name": "Jane", "scores": [88, 92, 94]}, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df) + + # Arrays are exploded, so the original column is removed + assert "scores" not in flat_df.columns + + def it_flattens_mixed_complex_types(spark: SparkSession): + data_mixed = [ + { + "state": "Florida", + "shortname": "FL", + "info": {"governor": "Rick Scott"}, + "counties": [ + {"name": "Dade", "population": 12345}, + {"name": "Broward", "population": 40000}, + {"name": "Palm Beach", "population": 60000}, + ], + }, + { + "state": "Ohio", + "shortname": "OH", + "info": {"governor": "John Kasich"}, + "counties": [ + {"name": "Summit", "population": 1234}, + {"name": "Cuyahoga", "population": 1337}, + ], + }, + ] + df = spark.createDataFrame(data_mixed) + + flat_df = flatten_dataframe(df, sep=":") + + # Check that complex fields are flattened + assert "info:governor" in flat_df.columns + assert "info" not in flat_df.columns + # Arrays are exploded + assert "counties" not in flat_df.columns + # Simple fields remain + assert "state" in flat_df.columns + assert "shortname" in flat_df.columns + + def it_uses_default_separator(spark: SparkSession): + data = [{"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}}] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df) + + assert "fitness_height" in flat_df.columns + assert "fitness_weight" in flat_df.columns + + def it_uses_custom_separator(spark: SparkSession): + data = [{"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}}] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df, sep=":") + + assert "fitness:height" in flat_df.columns + assert "fitness:weight" in flat_df.columns + + def it_preserves_simple_fields(spark: SparkSession): + data = [ + {"id": 1, "name": "Cole", "age": 25}, + {"id": 2, "name": "Jane", "age": 30}, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df) + + assert set(flat_df.columns) == {"id", "name", "age"} + + def it_handles_nested_structs(spark: SparkSession): + data = [ + ((("James", None, "Smith"),), "OH", "M"), + (("Anna", "Rose", ""), "NY", "F"), + ] + schema = StructType( + [ + StructField( + "name", + StructType( + [ + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), + ] + ), + ), + StructField("state", StringType(), True), + StructField("gender", StringType(), True), + ] + ) + df = spark.createDataFrame(data, schema) + + flat_df = flatten_dataframe(df, sep=":") + + assert "name:firstname" in flat_df.columns + assert "name:middlename" in flat_df.columns + assert "name:lastname" in flat_df.columns + assert "name" not in flat_df.columns + assert "state" in flat_df.columns + assert "gender" in flat_df.columns + + def it_handles_empty_dataframe(spark: SparkSession): + data = [] + df = spark.createDataFrame(data, "id INT, name STRING") + + flat_df = flatten_dataframe(df) + + assert set(flat_df.columns) == {"id", "name"} + + def it_handles_dataframe_with_only_simple_fields(spark: SparkSession): + data = [{"id": 1, "name": "Cole"}, {"id": 2, "name": "Jane"}] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df) + + assert set(flat_df.columns) == {"id", "name"} + + def it_handles_map_with_multiple_keys(spark: SparkSession): + data = [ + {"state": "FL", "info": {"governor": "Rick Scott", "capital": "Tallahassee"}}, + {"state": "OH", "info": {"governor": "John Kasich", "capital": "Columbus"}}, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df, sep=":") + + assert "info:governor" in flat_df.columns + assert "info:capital" in flat_df.columns + assert "info" not in flat_df.columns + + def it_handles_struct_with_multiple_fields(spark: SparkSession): + data = [ + { + "person": {"first": "John", "middle": "Q", "last": "Doe"}, + "age": 30, + }, + { + "person": {"first": "Jane", "middle": "A", "last": "Smith"}, + "age": 25, + }, + ] + df = spark.createDataFrame(data) + + flat_df = flatten_dataframe(df, sep="_") + + assert "person_first" in flat_df.columns + assert "person_middle" in flat_df.columns + assert "person_last" in flat_df.columns + assert "person" not in flat_df.columns + assert "age" in flat_df.columns