Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
assert_approx_df_equality,
assert_df_equality,
)
from .flatten import flatten_dataframe
from .rows_comparer import assert_basic_rows_equality
from .schema_comparer import SchemasNotEqualError

Expand Down Expand Up @@ -73,4 +74,5 @@ def assert_df_equality(
"assert_basic_rows_equality",
"assert_column_equality",
"assert_df_equality",
"flatten_dataframe",
)
44 changes: 44 additions & 0 deletions chispa/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from pyspark.sql.functions import col, explode_outer, map_keys
from pyspark.sql.types import ArrayType, MapType, StructType

if TYPE_CHECKING:
from pyspark.sql import DataFrame
from pyspark.sql.types import DataType


def _complex_fields(schema: StructType) -> dict[str, DataType]:
return {
field.name: field.dataType
for field in schema.fields
if isinstance(field.dataType, (StructType, ArrayType, MapType))
}


def flatten_dataframe(df: DataFrame, sep: str = "_") -> DataFrame:
if sep == ".":
raise ValueError("`sep` must not be '.', it conflicts with Spark's struct field accessor")

remaining = _complex_fields(df.schema)
while remaining:
col_name, dtype = next(iter(remaining.items()))

if isinstance(dtype, StructType):
expanded = [col(f"`{col_name}`.`{f.name}`").alias(f"{col_name}{sep}{f.name}") for f in dtype.fields]
df = df.select("*", *expanded).drop(col_name)

elif isinstance(dtype, ArrayType):
df = df.withColumn(col_name, explode_outer(col_name))

elif isinstance(dtype, MapType):
keys_rows = df.select(explode_outer(map_keys(col(col_name))).alias("k")).distinct().collect()
keys = [row["k"] for row in keys_rows if row["k"] is not None]
key_cols = [col(col_name).getItem(k).alias(f"{col_name}{sep}{k}") for k in keys]
df = df.select(*[c for c in df.columns if c != col_name], *key_cols)

remaining = _complex_fields(df.schema)

return df