diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3bdf6e2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased + +### Changed +- `DataFramesNotEqualError` changed to `RowsNotEqualError` to reflect it being raised when testing for row equality. +- The assertion functions `assert_df_equality` and `assert_column_equality` now have an optional `precision` parameter to test for approximate equality. + +### Removed +- Removed `are_dfs_equal` because it has been superseded by other parts of the API. +- Removed `assert_approx_df_equality` as it has been replaced by adding the optional `precision` parameter to `assert_df_equality`. +- Removed `assert_approx_column_equality` as it has been replaced by adding the optional `precision` parameter to `assert_column_equality`. diff --git a/README.md b/README.md index e4e3960..ff35e93 100644 --- a/README.md +++ b/README.md @@ -307,7 +307,7 @@ def test_approx_col_equality_same(): (None, None) ] df = spark.createDataFrame(data, ["num1", "num2"]) - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) ``` Here's an example of a test with columns that are not approximately equal. @@ -321,7 +321,7 @@ def test_approx_col_equality_different(): (None, None) ] df = spark.createDataFrame(data, ["num1", "num2"]) - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) ``` This failing test will output a readable error message so the issue is easy to debug. 
@@ -350,10 +350,10 @@ def test_approx_df_equality_same(): ] df2 = spark.createDataFrame(data2, ["num", "letter"]) - assert_approx_df_equality(df1, df2, 0.1) + assert_df_equality(df1, df2, precision=0.1) ``` -The `assert_approx_df_equality` method is smart and will only perform approximate equality operations for floating point numbers in DataFrames. It'll perform regular equality for strings and other types. +The `assert_df_equality` method has a `precision` parameter that lets the user control the absolute tolerance of any floating point errors that are accepted by the assertion method. It is smart and will only perform approximate equality operations for floating point numbers in DataFrames. It'll perform regular equality for strings and other types. Let's perform an approximate equality comparison for two DataFrames that are not equal. @@ -375,7 +375,7 @@ def test_approx_df_equality_different(): ] df2 = spark.createDataFrame(data2, ["num", "letter"]) - assert_approx_df_equality(df1, df2, 0.1) + assert_df_equality(df1, df2, precision=0.1) ``` Here's the pretty error message that's outputted: @@ -384,7 +384,7 @@ Here's the pretty error message that's outputted: ## Schema mismatch messages -DataFrame equality messages peform schema comparisons before analyzing the actual content of the DataFrames. DataFrames that don't have the same schemas should error out as fast as possible. +DataFrame equality messages perform schema comparisons before analyzing the actual content of the DataFrames. DataFrames that don't have the same schemas should error out as fast as possible. Let's compare a DataFrame that has a string column an integer column with a DataFrame that has two integer columns to observe the schema mismatch message. diff --git a/chispa/__init__.py b/chispa/__init__.py index 948aea5..a8ceee7 100644 --- a/chispa/__init__.py +++ b/chispa/__init__.py @@ -25,5 +25,6 @@ print("Can't find Apache Spark. 
Please set environment variable SPARK_HOME to root of installation!") exit(-1) -from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality -from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality +from .dataframe_comparer import assert_df_equality +from .column_comparer import assert_column_equality, ColumnsNotEqualError +from .row_comparer import RowsNotEqualError diff --git a/chispa/column_comparer.py b/chispa/column_comparer.py index b55f3ef..7806de8 100644 --- a/chispa/column_comparer.py +++ b/chispa/column_comparer.py @@ -1,5 +1,11 @@ -from chispa.bcolors import * +from typing import Optional, Any + +from pyspark.sql import DataFrame +from pyspark.sql.types import DataType + +from chispa.bcolors import blue from chispa.prettytable import PrettyTable +from chispa.number_helpers import check_equal class ColumnsNotEqualError(Exception): @@ -7,47 +13,62 @@ class ColumnsNotEqualError(Exception): pass -def assert_column_equality(df, col_name1, col_name2): - elements = df.select(col_name1, col_name2).collect() - colName1Elements = list(map(lambda x: x[0], elements)) - colName2Elements = list(map(lambda x: x[1], elements)) - if colName1Elements != colName2Elements: - zipped = list(zip(colName1Elements, colName2Elements)) - t = PrettyTable([col_name1, col_name2]) - for elements in zipped: - if elements[0] == elements[1]: - first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed - second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed - t.add_row([first, second]) - else: - t.add_row([str(elements[0]), str(elements[1])]) - raise ColumnsNotEqualError("\n" + t.get_string()) +def assert_column_equality( + df: DataFrame, + col_name1: str, + col_name2: str, + precision: Optional[float] = None, + allow_nan_equality: bool = False, +) -> None: + """Assert that two columns in a PySpark DataFrame are equal. 
+ Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. + allow_nan_equality : bool, default False + When True, treats two NaN values as equal. -def assert_approx_column_equality(df, col_name1, col_name2, precision): - elements = df.select(col_name1, col_name2).collect() - colName1Elements = list(map(lambda x: x[0], elements)) - colName2Elements = list(map(lambda x: x[1], elements)) + """ all_rows_equal = True - zipped = list(zip(colName1Elements, colName2Elements)) t = PrettyTable([col_name1, col_name2]) + + # Zip both columns together for iterating through elements. + columns = df.select(col_name1, col_name2).collect() + zipped = zip(*[list(map(lambda x: x[i], columns)) for i in [0, 1]]) + for elements in zipped: - first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed - second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed - # when one is None and the other isn't, they're not equal - if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None): - all_rows_equal = False - t.add_row([str(elements[0]), str(elements[1])]) - # when both are None, they're equal - elif elements[0] == None and elements[1] == None: - t.add_row([first, second]) - # when the diff is less than the threshhold, they're approximately equal - elif abs(elements[0] - elements[1]) < precision: - t.add_row([first, second]) - # otherwise, they're not equal + if are_elements_equal(*elements, precision, allow_nan_equality): + t.add_row([blue(e) for e in elements]) else: all_rows_equal = False - t.add_row([str(elements[0]), str(elements[1])]) + t.add_row([str(e) for e in elements]) + if all_rows_equal == False: raise ColumnsNotEqualError("\n" + t.get_string()) + +def are_elements_equal( + e1: DataType, + e2: DataType, + precision: Optional[float] = None, + allow_nan_equality: bool = False, +) -> bool: + """ + Return True if both elements are equal. 
+ + Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. + allow_nan_equality: bool, default False + When True, treats two NaN values as equal. + + """ + # If both elements are None they are considered equal. + if e1 is None and e2 is None: + return True + if (e1 is None and e2 is not None) or (e2 is None and e1 is not None): + return False + + return check_equal(e1, e2, precision, allow_nan_equality) diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 63496fe..21fdddf 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -1,78 +1,51 @@ -from chispa.prettytable import PrettyTable -from chispa.bcolors import * -from chispa.schema_comparer import assert_schema_equality -from chispa.row_comparer import * -import chispa.six as six from functools import reduce +from typing import Callable, Optional +from pyspark.sql import DataFrame -class DataFramesNotEqualError(Exception): - """The DataFrames are not equal""" - pass - - -def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, ignore_column_order=False, ignore_row_order=False): +from chispa.schema_comparer import assert_schema_equality +from chispa.row_comparer import assert_rows_equality + + +def assert_df_equality( + df1: DataFrame, + df2: DataFrame, + precision: Optional[float] = None, + ignore_nullable: bool = False, + allow_nan_equality: bool = False, + ignore_column_order: bool = False, + ignore_row_order: bool = False, + transforms: Callable[[DataFrame], DataFrame] = None, +) -> None: + """Assert that two PySpark DataFrames are equal. + + Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. + ignore_nullable : bool, default False + Ignore nullable option when comparing schemas. + allow_nan_equality : bool, default False + When True, treats two NaN values as equal. 
+ ignore_column_order : bool, default False + When True, sorts columns before comparing. + ignore_row_order : bool, default False + When True, sorts all rows before comparing. + transforms : callable + Additional transforms to make to DataFrame before comparison. + + """ + # Apply row and column order transforms + custom transforms. if transforms is None: transforms = [] if ignore_column_order: transforms.append(lambda df: df.select(sorted(df.columns))) if ignore_row_order: transforms.append(lambda df: df.sort(df.columns)) + df1 = reduce(lambda acc, fn: fn(acc), transforms, df1) df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) - assert_schema_equality(df1.schema, df2.schema, ignore_nullable) - if allow_nan_equality: - assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True]) - else: - assert_basic_rows_equality(df1, df2) - -def are_dfs_equal(df1, df2): - if df1.schema != df2.schema: - return False - if df1.collect() != df2.collect(): - return False - return True - - -def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False): + # Check schema and row equality. 
assert_schema_equality(df1.schema, df2.schema, ignore_nullable) - assert_generic_rows_equality(df1, df2, are_rows_approx_equal, [precision]) - - -def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args): - df1_rows = df1.collect() - df2_rows = df2.collect() - zipped = list(six.moves.zip_longest(df1_rows, df2_rows)) - t = PrettyTable(["df1", "df2"]) - allRowsEqual = True - for r1, r2 in zipped: - # rows are not equal when one is None and the other isn't - if (r1 is not None and r2 is None) or (r2 is not None and r1 is None): - allRowsEqual = False - t.add_row([r1, r2]) - # rows are equal - elif row_equality_fun(r1, r2, *row_equality_fun_args): - first = bcolors.LightBlue + str(r1) + bcolors.LightRed - second = bcolors.LightBlue + str(r2) + bcolors.LightRed - t.add_row([first, second]) - # otherwise, rows aren't equal - else: - allRowsEqual = False - t.add_row([r1, r2]) - if allRowsEqual == False: - raise DataFramesNotEqualError("\n" + t.get_string()) - - -def assert_basic_rows_equality(df1, df2): - rows1 = df1.collect() - rows2 = df2.collect() - if rows1 != rows2: - t = PrettyTable(["df1", "df2"]) - zipped = list(six.moves.zip_longest(rows1, rows2)) - for r1, r2 in zipped: - if r1 == r2: - t.add_row([blue(r1), blue(r2)]) - else: - t.add_row([r1, r2]) - raise DataFramesNotEqualError("\n" + t.get_string()) + assert_rows_equality(df1, df2, precision, allow_nan_equality) diff --git a/chispa/number_helpers.py b/chispa/number_helpers.py index d9654da..e5c041b 100644 --- a/chispa/number_helpers.py +++ b/chispa/number_helpers.py @@ -1,4 +1,5 @@ import math +from typing import Optional def isnan(x): @@ -8,5 +9,27 @@ def isnan(x): return False -def nan_safe_equality(x, y) -> bool: - return (x == y) or (isnan(x) and isnan(y)) \ No newline at end of file +def check_equal( + x, y, + precision: Optional[float] = None, + allow_nan_equality: bool = False, +) -> bool: + """Return True if x and y are equal. 
+ + Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. + allow_nan_equality: bool, defaults to False + When True, treats two NaN values as equal. + + """ + both_floats = (isinstance(x, float) & isinstance(y, float)) + if (precision is not None) & both_floats: + both_equal = abs(x - y) < precision + else: + both_equal = (x == y) + + both_nan = (isnan(x) and isnan(y)) if allow_nan_equality else False + + return both_equal or both_nan diff --git a/chispa/row_comparer.py b/chispa/row_comparer.py index ed504c9..9924968 100644 --- a/chispa/row_comparer.py +++ b/chispa/row_comparer.py @@ -1,40 +1,76 @@ -from pyspark.sql import Row -from chispa.number_helpers import nan_safe_equality +from typing import Optional +from pyspark.sql import Row, DataFrame -def are_rows_equal(r1: Row, r2: Row) -> bool: - return r1 == r2 +from chispa.bcolors import blue +import chispa.six as six +from chispa.number_helpers import check_equal +from chispa.prettytable import PrettyTable -def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool: - if r1 is None and r2 is None: - return True - if (r1 is None and r2 is not None) or (r2 is None and r1 is not None): - return False - d1 = r1.asDict() - d2 = r2.asDict() - if allow_nan_equality: - for key in d1.keys() & d2.keys(): - if not(nan_safe_equality(d1[key], d2[key])): - return False - return True - else: - return r1 == r2 +class RowsNotEqualError(Exception): + """The DataFrame Rows are not all equal.""" + pass + + +def assert_rows_equality( + df1: DataFrame, + df2: DataFrame, + precision: Optional[float] = None, + allow_nan_equality: bool = False, +) -> None: + """Asserts that all row values in the two DataFrames are equal. + + Raises an error with a PrettyTable row comparison if not. + + Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. 
+ allow_nan_equality: bool, default False + When True, treats two NaN values as equal. + + """ + t = PrettyTable(["df1", "df2"]) + allRowsEqual = True + + zipped = six.moves.zip_longest(df1.collect(), df2.collect()) + for r1, r2 in zipped: + if are_rows_equal(r1, r2, precision, allow_nan_equality): + t.add_row([blue(r1), blue(r2)]) + else: + allRowsEqual = False + t.add_row([r1, r2]) + + if allRowsEqual == False: + raise RowsNotEqualError("\n" + t.get_string()) + + +def are_rows_equal( + r1: Row, + r2: Row, + precision: Optional[float] = None, + allow_nan_equality: bool = False, +) -> bool: + """ + Return True if both rows are equal. + Parameters + ---------- + precision : float, optional + Absolute tolerance when checking for equality. + allow_nan_equality: bool, default False + When True, treats two NaN values as equal. -def are_rows_approx_equal(r1: Row, r2: Row, precision: float) -> bool: + """ + # If both rows are None they are considered equal. if r1 is None and r2 is None: return True if (r1 is None and r2 is not None) or (r2 is None and r1 is not None): return False - d1 = r1.asDict() - d2 = r2.asDict() - allEqual = True - for key in d1.keys() & d2.keys(): - if isinstance(d1[key], float) and isinstance(d2[key], float): - if abs(d1[key] - d2[key]) > precision: - allEqual = False - elif d1[key] != d2[key]: - allEqual = False - return allEqual + # Compare the values for each row. Order matters. 
+ for v1, v2 in zip(r1.asDict().values(), r2.asDict().values()): + if not check_equal(v1, v2, precision, allow_nan_equality): + return False + return True diff --git a/tests/test_column_comparer.py b/tests/test_column_comparer.py index da8af9a..f875d0e 100644 --- a/tests/test_column_comparer.py +++ b/tests/test_column_comparer.py @@ -23,30 +23,35 @@ def it_works_with_integer_values(): assert_column_equality(df, "num1", "num2") + def it_equates_nans_when_allow_nan_equality_is_True(): + data = [(1.0, 1.0), (10.3, 10.3), (float('nan'), float('nan')), (None, None)] + df = spark.createDataFrame(data, ["num1", "num2"]) + assert_column_equality(df, "num1", "num2", allow_nan_equality=True) + + def describe_assert_approx_column_equality(): def it_works_with_no_mismatches(): data = [(1.1, 1.1), (1.0004, 1.0005), (.4, .45), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) def it_throws_when_difference_is_bigger_than_precision(): data = [(1.5, 1.1), (1.0004, 1.0005), (.4, .45)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) def it_throws_when_comparing_floats_with_none(): data = [(1.1, 1.1), (2.2, 2.2), (3.3, None)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) def it_throws_when_comparing_none_with_floats(): data = [(1.1, 1.1), (2.2, 2.2), (None, 3.3)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: - assert_approx_column_equality(df, "num1", "num2", 0.1) - + assert_column_equality(df, "num1", "num2", precision=0.1) diff --git 
a/tests/test_dataframe_comparer.py b/tests/test_dataframe_comparer.py index 2f2c2e9..4d14b01 100644 --- a/tests/test_dataframe_comparer.py +++ b/tests/test_dataframe_comparer.py @@ -2,16 +2,18 @@ from spark import * from chispa import * -from chispa.dataframe_comparer import are_dfs_equal from chispa.schema_comparer import SchemasNotEqualError +from chispa.row_comparer import RowsNotEqualError def describe_assert_df_equality(): def it_throws_with_schema_mismatches(): data1 = [(1, "jose"), (2, "li"), (3, "laura")] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) + data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) + with pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2) @@ -19,8 +21,10 @@ def it_throws_with_schema_mismatches(): def it_can_work_with_different_row_orders(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) + data2 = [(2, "li"), (1, "jose")] df2 = spark.createDataFrame(data2, ["num", "name"]) + assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) @@ -45,7 +49,7 @@ def it_raises_for_row_insensitive_with_diff_content(): df1 = spark.createDataFrame(data1, ["num", "name"]) data2 = [(2, "li"), (1, "jose")] df2 = spark.createDataFrame(data2, ["num", "name"]) - with pytest.raises(DataFramesNotEqualError): + with pytest.raises(RowsNotEqualError): assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) @@ -82,7 +86,7 @@ def it_throws_with_content_mismatches(): df1 = spark.createDataFrame(data1, ["name", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(RowsNotEqualError) as e_info: assert_df_equality(df1, df2) @@ -91,7 +95,7 @@ def it_throws_with_length_mismatches(): df1 = spark.createDataFrame(data1, ["name", 
"expected_name"]) data2 = [("jose", "jose"), ("li", "li")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(RowsNotEqualError) as e_info: assert_df_equality(df1, df2) @@ -108,43 +112,18 @@ def it_does_not_consider_nan_values_equal_by_default(): df1 = spark.createDataFrame(data1, ["num", "name"]) data2 = [(float('nan'), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(RowsNotEqualError) as e_info: assert_df_equality(df1, df2, allow_nan_equality=False) -def describe_are_dfs_equal(): - def it_returns_false_with_schema_mismatches(): - data1 = [(1, "jose"), (2, "li"), (3, "laura")] - df1 = spark.createDataFrame(data1, ["num", "expected_name"]) - data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] - df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == False - - - def it_returns_false_with_content_mismatches(): - data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] - df1 = spark.createDataFrame(data1, ["name", "expected_name"]) - data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] - df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == False - - - def it_returns_true_when_dfs_are_same(): - data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] - df1 = spark.createDataFrame(data1, ["name", "expected_name"]) - data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] - df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == True - - def describe_assert_approx_df_equality(): def it_throws_with_content_mismatch(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (1.0, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1.0, "jose"), (1.05, "li"), (1.0, "laura"), (None, "hi")] df2 = 
spark.createDataFrame(data2, ["num", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: - assert_approx_df_equality(df1, df2, 0.1) + with pytest.raises(RowsNotEqualError) as e_info: + assert_df_equality(df1, df2, precision=0.1) def it_throws_with_with_length_mismatch(): @@ -152,8 +131,8 @@ def it_throws_with_with_length_mismatch(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1.0, "jose"), (1.05, "li")] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: - assert_approx_df_equality(df1, df2, 0.1) + with pytest.raises(RowsNotEqualError) as e_info: + assert_df_equality(df1, df2, precision=0.1) def it_does_not_throw_with_no_mismatch(): @@ -161,5 +140,4 @@ def it_does_not_throw_with_no_mismatch(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1.0, "jose"), (1.05, "li"), (1.2, "laura"), (None, None)] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) - assert_approx_df_equality(df1, df2, 0.1) - + assert_df_equality(df1, df2, precision=0.1) diff --git a/tests/test_readme_examples.py b/tests/test_readme_examples.py index cb1717b..f82053c 100644 --- a/tests/test_readme_examples.py +++ b/tests/test_readme_examples.py @@ -86,7 +86,7 @@ def test_remove_non_word_characters_long_error(): (None, None) ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(RowsNotEqualError) as e_info: assert_df_equality(actual_df, expected_df) @@ -131,7 +131,7 @@ def test_approx_col_equality_same(): (None, None) ] df = spark.createDataFrame(data, ["num1", "num2"]) - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) def test_approx_col_equality_different(): @@ -143,7 +143,7 @@ def test_approx_col_equality_different(): ] df = spark.createDataFrame(data, ["num1", "num2"]) with 
pytest.raises(ColumnsNotEqualError) as e_info: - assert_approx_column_equality(df, "num1", "num2", 0.1) + assert_column_equality(df, "num1", "num2", precision=0.1) def test_approx_df_equality_same(): @@ -161,7 +161,7 @@ def test_approx_df_equality_same(): (None, None) ] df2 = spark.createDataFrame(data2, ["num", "letter"]) - assert_approx_df_equality(df1, df2, 0.1) + assert_df_equality(df1, df2, precision=0.1) def test_approx_df_equality_different(): @@ -179,8 +179,8 @@ def test_approx_df_equality_different(): (None, None) ] df2 = spark.createDataFrame(data2, ["num", "letter"]) - with pytest.raises(DataFramesNotEqualError) as e_info: - assert_approx_df_equality(df1, df2, 0.1) + with pytest.raises(RowsNotEqualError) as e_info: + assert_df_equality(df1, df2, precision=0.1) def describe_schema_mismatch_messages(): diff --git a/tests/test_row_comparer.py b/tests/test_row_comparer.py index e8b633a..ceff054 100644 --- a/tests/test_row_comparer.py +++ b/tests/test_row_comparer.py @@ -1,28 +1,32 @@ -import pytest - -from spark import * -from chispa.row_comparer import * +from chispa.row_comparer import are_rows_equal from pyspark.sql import Row -def test_are_rows_equal(): - assert are_rows_equal(Row("bob", "jose"), Row("li", "li")) == False - assert are_rows_equal(Row("luisa", "laura"), Row("luisa", "laura")) == True - assert are_rows_equal(Row(None, None), Row(None, None)) == True - -def test_are_rows_equal_enhanced(): - assert are_rows_equal_enhanced(Row(n1 = "bob", n2 = "jose"), Row(n1 = "li", n2 = "li"), False) == False - assert are_rows_equal_enhanced(Row(n1 = "luisa", n2 = "laura"), Row(n1 = "luisa", n2 = "laura"), False) == True - assert are_rows_equal_enhanced(Row(n1 = None, n2 = None), Row(n1 = None, n2 = None), False) == True - - assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), True) == False - assert are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1=float('nan'), n2="jose"), True) == True - assert 
are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1="hi", n2="jose"), True) == False - - -def test_are_rows_approx_equal(): - assert are_rows_approx_equal(Row(num = 1.1, first_name = "li"), Row(num = 1.05, first_name = "li"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.0, first_name = "laura"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.9, first_name = "laura"), 0.1) == False - assert are_rows_approx_equal(Row(num = None, first_name = None), Row(num = None, first_name = None), 0.1) == True +def describe_are_rows_equal(): + def returns_False_when_string_values_are_not_equal(): + assert are_rows_equal(Row(n1="bob", n2="jose"), Row(n1="li", n2="li")) == False + def returns_True_when_string_values_are_equal(): + assert are_rows_equal(Row(n1="luisa", n2="laura"), Row(n1="luisa", n2="laura")) == True + def returns_True_when_both_rows_are_None(): + assert are_rows_equal(Row(n1=None, n2=None), Row(n1=None, n2=None)) == True + + +def describe_are_rows_equal_when_allowing_nan_equality(): + def returns_False_when_no_NaN_values_to_compare_and_other_values_are_not_equal(): + assert are_rows_equal(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), allow_nan_equality=True) == False + def returns_True_when_either_the_values_are_equal_or_both_nan(): + assert are_rows_equal(Row(n1=float('nan'), n2="jose"), Row(n1=float('nan'), n2="jose"), allow_nan_equality=True) == True + def returns_False_when_comparing_nan_to_string(): + assert are_rows_equal(Row(n1=float('nan'), n2="jose"), Row(n1="hi", n2="jose"), allow_nan_equality=True) == False + + +def describe_are_rows_equal_when_given_precision(): + def returns_True_when_float_value_difference_is_less_than_precision(): + assert are_rows_equal(Row(num = 1.1, first_name = "li"), Row(num = 1.05, first_name = "li"), precision=0.1) == True + def returns_True_when_float_values_are_exactly_equal(): + assert are_rows_equal(Row(num = 5.0, 
first_name = "laura"), Row(num = 5.0, first_name = "laura"), precision=0.1) == True + def returns_False_when_float_value_difference_is_more_than_precision(): + assert are_rows_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.9, first_name = "laura"), precision=0.1) == False + def returns_True_when_all_values_are_Nones(): + assert are_rows_equal(Row(num = None, first_name = None), Row(num = None, first_name = None), precision=0.1) == True