Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pyspark.sql import DataFrame

from chispa.dataframe_transformer import flatten_dataframe
from chispa.default_formats import DefaultFormats
from chispa.formatting import Color, Format, FormattingConfig, Style

Expand Down Expand Up @@ -73,4 +74,5 @@ def assert_df_equality(
"assert_basic_rows_equality",
"assert_column_equality",
"assert_df_equality",
"flatten_dataframe",
)
119 changes: 119 additions & 0 deletions chispa/dataframe_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import annotations

from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, MapType, StructType


def flatten_dataframe(df: DataFrame, sep: str = "_") -> DataFrame:
"""Flatten a nested DataFrame by expanding StructType, ArrayType, and MapType columns.

This function recursively flattens nested structures in a DataFrame, converting
complex types into flat columns with names separated by the specified delimiter.

Parameters
----------
df : DataFrame
The input DataFrame to flatten.
sep : str, optional
Delimiter for flattened column names. Default is "_".
Note: Do not use "." as the separator, as it won't work correctly with
nested DataFrames with more than one level.

Returns
-------
DataFrame
A flattened DataFrame with all nested structures expanded.

Notes
-----
- StructType fields are expanded into individual columns
- ArrayType fields are exploded to add array elements as rows
- MapType fields are expanded by extracting all keys as columns
- Flattening MapType requires finding every key in the column, which can be slow
- The function processes fields iteratively until no complex types remain

Examples
--------
Flatten a DataFrame with nested struct fields:

>>> data = [
... {"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}},
... {"id": 2, "name": "Faye", "fitness": {"height": 130, "weight": 60}},
... ]
>>> df = spark.createDataFrame(data)
>>> flat_df = flatten_dataframe(df, sep=":")
>>> flat_df.columns
['id', 'name', 'fitness:height', 'fitness:weight']

Flatten a DataFrame with map fields:

>>> data = [
... {"state": "Florida", "info": {"governor": "Rick Scott"}},
... {"state": "Ohio", "info": {"governor": "John Kasich"}},
... ]
>>> df = spark.createDataFrame(data)
>>> flat_df = flatten_dataframe(df, sep=":")
>>> flat_df.columns
['state', 'info:governor']

Flatten a DataFrame with array fields:

>>> data = [
... {"name": "John", "scores": [85, 90, 95]},
... {"name": "Jane", "scores": [88, 92, 94]},
... ]
>>> df = spark.createDataFrame(data)
>>> flat_df = flatten_dataframe(df)
>>> "scores" in flat_df.columns
False
"""
# Compute Complex Fields (Arrays, Structs and MapTypes) in Schema
complex_fields: dict[str, StructType | ArrayType | MapType] = dict(
[
(field.name, field.dataType)
for field in df.schema.fields
if isinstance(field.dataType, ArrayType | StructType | MapType)
]
)

while len(complex_fields) != 0:
col_name = list(complex_fields.keys())[0]

# If StructType then convert all sub-element to columns.
# i.e. flatten structs
if isinstance(complex_fields[col_name], StructType):
expanded = [
F.col(col_name + "." + k).alias(col_name + sep + k)
for k in [n.name for n in complex_fields[col_name]]
]
df = df.select("*", *expanded).drop(col_name)

# If ArrayType then add the Array Elements as Rows using the explode function
# i.e. explode Arrays
elif isinstance(complex_fields[col_name], ArrayType):
df = df.withColumn(col_name, F.explode_outer(col_name))

# If MapType then convert all sub-element to columns.
# i.e. flatten maps
elif isinstance(complex_fields[col_name], MapType):
keys_df = df.select(F.explode_outer(F.map_keys(F.col(col_name)))).distinct()
keys = [row[0] for row in keys_df.collect()]
key_cols = [
F.col(col_name).getItem(f).alias(str(col_name + sep + f)) for f in keys
]
drop_column_list = [col_name]
df = df.select(
[c for c in df.columns if c not in drop_column_list] + key_cols
)

# Recompute remaining Complex Fields in Schema
complex_fields = dict(
[
(field.name, field.dataType)
for field in df.schema.fields
if isinstance(field.dataType, ArrayType | StructType | MapType)
]
)

return df
195 changes: 195 additions & 0 deletions tests/test_dataframe_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, IntegerType, LongType, MapType, StringType, StructField, StructType

from chispa.dataframe_transformer import flatten_dataframe


def describe_flatten_dataframe():
def it_flattens_struct_fields(spark: SparkSession):
data = [
{"id": 1, "name": "Cole Volk", "fitness": {"height": 130, "weight": 60}},
{"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
{"id": 2, "name": "Faye Raker", "fitness": {"height": 130, "weight": 60}},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df, sep=":")

assert "fitness:height" in flat_df.columns
assert "fitness:weight" in flat_df.columns
assert "fitness" not in flat_df.columns

def it_flattens_map_fields(spark: SparkSession):
data = [
{"state": "Florida", "shortname": "FL", "info": {"governor": "Rick Scott"}},
{"state": "Ohio", "shortname": "OH", "info": {"governor": "John Kasich"}},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df, sep=":")

assert "info:governor" in flat_df.columns
assert "info" not in flat_df.columns
assert "state" in flat_df.columns
assert "shortname" in flat_df.columns

def it_flattens_array_fields(spark: SparkSession):
data = [
{"name": "John", "scores": [85, 90, 95]},
{"name": "Jane", "scores": [88, 92, 94]},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df)

# Arrays are exploded, so the original column is removed
assert "scores" not in flat_df.columns

def it_flattens_mixed_complex_types(spark: SparkSession):
data_mixed = [
{
"state": "Florida",
"shortname": "FL",
"info": {"governor": "Rick Scott"},
"counties": [
{"name": "Dade", "population": 12345},
{"name": "Broward", "population": 40000},
{"name": "Palm Beach", "population": 60000},
],
},
{
"state": "Ohio",
"shortname": "OH",
"info": {"governor": "John Kasich"},
"counties": [
{"name": "Summit", "population": 1234},
{"name": "Cuyahoga", "population": 1337},
],
},
]
df = spark.createDataFrame(data_mixed)

flat_df = flatten_dataframe(df, sep=":")

# Check that complex fields are flattened
assert "info:governor" in flat_df.columns
assert "info" not in flat_df.columns
# Arrays are exploded
assert "counties" not in flat_df.columns
# Simple fields remain
assert "state" in flat_df.columns
assert "shortname" in flat_df.columns

def it_uses_default_separator(spark: SparkSession):
data = [{"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}}]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df)

assert "fitness_height" in flat_df.columns
assert "fitness_weight" in flat_df.columns

def it_uses_custom_separator(spark: SparkSession):
data = [{"id": 1, "name": "Cole", "fitness": {"height": 130, "weight": 60}}]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df, sep=":")

assert "fitness:height" in flat_df.columns
assert "fitness:weight" in flat_df.columns

def it_preserves_simple_fields(spark: SparkSession):
data = [
{"id": 1, "name": "Cole", "age": 25},
{"id": 2, "name": "Jane", "age": 30},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df)

assert set(flat_df.columns) == {"id", "name", "age"}

def it_handles_nested_structs(spark: SparkSession):
data = [
((("James", None, "Smith"),), "OH", "M"),
(("Anna", "Rose", ""), "NY", "F"),
]
schema = StructType(
[
StructField(
"name",
StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
),
),
StructField("state", StringType(), True),
StructField("gender", StringType(), True),
]
)
df = spark.createDataFrame(data, schema)

flat_df = flatten_dataframe(df, sep=":")

assert "name:firstname" in flat_df.columns
assert "name:middlename" in flat_df.columns
assert "name:lastname" in flat_df.columns
assert "name" not in flat_df.columns
assert "state" in flat_df.columns
assert "gender" in flat_df.columns

def it_handles_empty_dataframe(spark: SparkSession):
data = []
df = spark.createDataFrame(data, "id INT, name STRING")

flat_df = flatten_dataframe(df)

assert set(flat_df.columns) == {"id", "name"}

def it_handles_dataframe_with_only_simple_fields(spark: SparkSession):
data = [{"id": 1, "name": "Cole"}, {"id": 2, "name": "Jane"}]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df)

assert set(flat_df.columns) == {"id", "name"}

def it_handles_map_with_multiple_keys(spark: SparkSession):
data = [
{"state": "FL", "info": {"governor": "Rick Scott", "capital": "Tallahassee"}},
{"state": "OH", "info": {"governor": "John Kasich", "capital": "Columbus"}},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df, sep=":")

assert "info:governor" in flat_df.columns
assert "info:capital" in flat_df.columns
assert "info" not in flat_df.columns

def it_handles_struct_with_multiple_fields(spark: SparkSession):
data = [
{
"person": {"first": "John", "middle": "Q", "last": "Doe"},
"age": 30,
},
{
"person": {"first": "Jane", "middle": "A", "last": "Smith"},
"age": 25,
},
]
df = spark.createDataFrame(data)

flat_df = flatten_dataframe(df, sep="_")

assert "person_first" in flat_df.columns
assert "person_middle" in flat_df.columns
assert "person_last" in flat_df.columns
assert "person" not in flat_df.columns
assert "age" in flat_df.columns