Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ This library requires you to set a flag to consider two NaN values to be equal.
assert_df_equality(df1, df2, allow_nan_equality=True)
```

### Compare columns on error

If this argument is set to an integer n, the first n columns will be compared seperately in case the Dataframes are not equal. The diff of the n columns will each be printed out, before the original error is raised. This improves readability for Dataframes with many columns, especially if important information like ID or timestamp are among the first columns.

## Approximate column equality

We can check if columns are approximately equal, which is especially useful for floating number comparisons.
Expand Down
16 changes: 13 additions & 3 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DataFramesNotEqualError(Exception):


def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
ignore_column_order=False, ignore_row_order=False, ignore_schema=False):
ignore_column_order=False, ignore_row_order=False, ignore_schema=False, compare_columns_on_error=None):
if transforms is None:
transforms = []
if ignore_column_order:
Expand All @@ -23,6 +23,8 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n
df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
if not ignore_schema:
assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
if compare_columns_on_error:
assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True], compare_columns_on_error)
if allow_nan_equality:
assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True])
else:
Expand All @@ -42,7 +44,7 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False):
assert_generic_rows_equality(df1, df2, are_rows_approx_equal, [precision])


def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args):
def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args, compare_columns_on_error):
df1_rows = df1.collect()
df2_rows = df2.collect()
zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
Expand All @@ -62,7 +64,15 @@ def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_ar
else:
allRowsEqual = False
t.add_row([r1, r2])
if allRowsEqual == False:
if allRowsEqual == False and not compare_columns_on_error:
raise DataFramesNotEqualError("\n" + t.get_string())
if allRowsEqual == False and compare_columns_on_error:
for name in df1.schema.names[0:compare_columns_on_error]:
try:
assert_df_equality(df1.select(name), df2.select(name), ignore_row_order=True)
except DataFramesNotEqualError as e:
print(e)
continue
raise DataFramesNotEqualError("\n" + t.get_string())


Expand Down