diff --git a/README.md b/README.md index 7b452e9..f508f26 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,10 @@ This library requires you to set a flag to consider two NaN values to be equal. assert_df_equality(df1, df2, allow_nan_equality=True) ``` +### Compare columns on error + +If this argument is set to an integer n, the first n columns will be compared seperately in case the Dataframes are not equal. The diff of the n columns will each be printed out, before the original error is raised. This improves readability for Dataframes with many columns, especially if important information like ID or timestamp are among the first columns. + ## Approximate column equality We can check if columns are approximately equal, which is especially useful for floating number comparisons. diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 9023f49..ea4c983 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -12,7 +12,7 @@ class DataFramesNotEqualError(Exception): def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, ignore_schema=False): + ignore_column_order=False, ignore_row_order=False, ignore_schema=False, compare_columns_on_error=None): if transforms is None: transforms = [] if ignore_column_order: @@ -23,6 +23,8 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) if not ignore_schema: assert_schema_equality(df1.schema, df2.schema, ignore_nullable) + if compare_columns_on_error: + assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True], compare_columns_on_error) if allow_nan_equality: assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True]) else: @@ -42,7 +44,7 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False): assert_generic_rows_equality(df1, df2, are_rows_approx_equal, [precision]) -def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args): +def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args, compare_columns_on_error): df1_rows = df1.collect() df2_rows = df2.collect() zipped = list(six.moves.zip_longest(df1_rows, df2_rows)) @@ -62,7 +64,15 @@ def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_ar else: allRowsEqual = False t.add_row([r1, r2]) - if allRowsEqual == False: + if allRowsEqual == False and not compare_columns_on_error: + raise DataFramesNotEqualError("\n" + t.get_string()) + if allRowsEqual == False and compare_columns_on_error: + for name in df1.schema.names[0:compare_columns_on_error]: + try: + assert_df_equality(df1.select(name), df2.select(name), ignore_row_order=True) + except DataFramesNotEqualError as e: + print(e) + continue raise DataFramesNotEqualError("\n" + t.get_string())