diff --git a/README.md b/README.md index d902cde..0b4a493 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,21 @@ assert_df_equality(df1, df2, underline_cells=True) ![DfsNotEqualUnderlined](https://github.com/MrPowers/chispa/blob/main/images/df_not_equal_underlined.png) +### Changing the default color scheme + +You can choose to change the default color scheme of the output font when comparing dataframes. + +```python +colour_scheme = { + "default":"light_red", + "matched":"light_blue", + "underlined":"purple" +} +assert_df_equality(df1, df2, allow_nan_equality=False, underline_cells=True, color_scheme=colour_scheme) +``` + +A list of available colors are available in the [`bcolors` class](https://github.com/MrPowers/chispa/blob/main/chispa/bcolors.py#L1). + ## Approximate column equality We can check if columns are approximately equal, which is especially useful for floating number comparisons. diff --git a/chispa/bcolors.py b/chispa/bcolors.py index bbbb930..71a9c78 100644 --- a/chispa/bcolors.py +++ b/chispa/bcolors.py @@ -1,40 +1,50 @@ class bcolors: - NC = '\033[0m' # No Color, reset all - - Bold = '\033[1m' - Underlined = '\033[4m' - Blink = '\033[5m' - Inverted = '\033[7m' - Hidden = '\033[8m' - - Black = '\033[30m' - Red = '\033[31m' - Green = '\033[32m' - Yellow = '\033[33m' - Blue = '\033[34m' - Purple = '\033[35m' - Cyan = '\033[36m' - LightGray = '\033[37m' - DarkGray = '\033[30m' - LightRed = '\033[31m' - LightGreen = '\033[32m' - LightYellow = '\033[93m' - LightBlue = '\033[34m' - LightPurple = '\033[35m' - LightCyan = '\033[36m' - White = '\033[97m' + nc = '\033[0m' # No Color, reset all # Style - Bold = '\033[1m' - Underline = '\033[4m' + bold = '\033[1m' + underlined = '\033[4m' + blink = '\033[5m' + inverted = '\033[7m' + hidden = '\033[8m' + # Colors + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + yellow = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + light_gray = '\033[37m' + dark_gray = '\033[30m' + light_red = '\033[31m' + light_green = '\033[32m' + light_yellow = '\033[93m' + light_blue = '\033[34m' + light_purple = '\033[35m' + light_cyan = '\033[36m' + white = '\033[97m' -def blue(s: str) -> str: - return bcolors.LightBlue + str(s) + bcolors.LightRed +def normal_text(input_text: str, color_scheme: dict) -> str: + return get_color(color_scheme["matched"]) + input_text + get_color(color_scheme["default"]) -def underline_text(input_text: str) -> str: + +def underline_text(input_text: str, color_scheme: dict) -> str: """ Takes an input string and returns a white, underlined string (based on PrettyTable formatting) """ - return bcolors.White + bcolors.Underline + input_text + bcolors.NC + bcolors.LightRed + return get_color(color_scheme["underlined"]) + bcolors.underlined + input_text + bcolors.nc + get_color(color_scheme["default"]) + +def get_color(color_string: str) -> str: + """ + Takes a color string, e.g. "Red" and returns Pretty Tables color code string if it exists in the bcolors class, otherwise raise an Exception + """ + color_string_cleaned = color_string.lower() # Clean string + if hasattr(bcolors(), color_string_cleaned): + return getattr(bcolors(), color_string_cleaned) + else: + raise Exception(f"Unable to find color '{color_string}' in bcolors.") + + diff --git a/chispa/column_comparer.py b/chispa/column_comparer.py index b55f3ef..02604a5 100644 --- a/chispa/column_comparer.py +++ b/chispa/column_comparer.py @@ -16,8 +16,8 @@ def assert_column_equality(df, col_name1, col_name2): t = PrettyTable([col_name1, col_name2]) for elements in zipped: if elements[0] == elements[1]: - first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed - second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed + first = bcolors.light_blue + str(elements[0]) + bcolors.light_red + second = bcolors.light_blue + str(elements[1]) + bcolors.light_red t.add_row([first, second]) else: t.add_row([str(elements[0]), str(elements[1])]) @@ -32,8 +32,8 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision): zipped = list(zip(colName1Elements, colName2Elements)) t = PrettyTable([col_name1, col_name2]) for elements in zipped: - first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed - second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed + first = bcolors.light_blue + str(elements[0]) + bcolors.light_red + second = bcolors.light_blue + str(elements[1]) + bcolors.light_red # when one is None and the other isn't, they're not equal if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None): all_rows_equal = False diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 96db21d..846fe85 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -9,8 +9,19 @@ class DataFramesNotEqualError(Exception): pass +default_colour_scheme = { + "default":"light_red", + "matched":"light_blue", + "underlined":"green" +} + def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, underline_cells=False): + ignore_column_order=False, ignore_row_order=False, underline_cells=False, color_scheme=None): + if color_scheme is None: + color_scheme = default_colour_scheme + else: + if ("default" not in color_scheme.keys()) or ("matched" not in color_scheme.keys()) or ("underlined" not in color_scheme.keys()): + raise Exception("Color scheme requires keys:'default', 'matched' and 'underlined'.") if transforms is None: transforms = [] if ignore_column_order: @@ -22,10 +33,10 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n assert_schema_equality(df1.schema, df2.schema, ignore_nullable) if allow_nan_equality: assert_generic_rows_equality( - df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells) + df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], color_scheme=color_scheme, underline_cells=underline_cells) else: assert_basic_rows_equality( - df1.collect(), df2.collect(), underline_cells=underline_cells) + df1.collect(), df2.collect(), color_scheme=color_scheme, underline_cells=underline_cells) def are_dfs_equal(df1, df2): @@ -48,8 +59,8 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) assert_schema_equality(df1.schema, df2.schema, ignore_nullable) if precision != 0: - assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality]) + assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], color_scheme=default_colour_scheme) elif allow_nan_equality: - assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True]) + assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], color_scheme=default_colour_scheme) else: assert_basic_rows_equality(df1.collect(), df2.collect()) diff --git a/chispa/rows_comparer.py b/chispa/rows_comparer.py index e9114f8..6dc9a9e 100644 --- a/chispa/rows_comparer.py +++ b/chispa/rows_comparer.py @@ -6,7 +6,7 @@ from typing import List -def assert_basic_rows_equality(rows1, rows2, underline_cells=False): +def assert_basic_rows_equality(rows1, rows2, color_scheme, underline_cells=False): if underline_cells: row_column_names = rows1[0].__fields__ num_columns = len(row_column_names) @@ -15,17 +15,17 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False): zipped = list(six.moves.zip_longest(rows1, rows2)) for r1, r2 in zipped: if r1 == r2: - t.add_row([blue(r1), blue(r2)]) + t.add_row([normal_text(input_text=str(r1), color_scheme=color_scheme), normal_text(input_text=str(r2), color_scheme=color_scheme)]) else: if underline_cells: t.add_row(__underline_cells_in_row( - r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns)) + r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns, color_scheme=color_scheme)) else: t.add_row([r1, r2]) raise chispa.DataFramesNotEqualError("\n" + t.get_string()) -def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False): +def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, color_scheme, underline_cells=False): df1_rows = rows1 df2_rows = rows2 zipped = list(six.moves.zip_longest(df1_rows, df2_rows)) @@ -41,8 +41,8 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu t.add_row([r1, r2]) # rows are equal elif row_equality_fun(r1, r2, *row_equality_fun_args): - first = bcolors.LightBlue + str(r1) + bcolors.LightRed - second = bcolors.LightBlue + str(r2) + bcolors.LightRed + first = get_color(color_scheme["matched"]) + str(r1) + get_color(color_scheme["default"]) + second = get_color(color_scheme["matched"]) + str(r2) + get_color(color_scheme["default"]) t.add_row([first, second]) # otherwise, rows aren't equal else: @@ -50,14 +50,14 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu # Underline cells if requested if underline_cells: t.add_row(__underline_cells_in_row( - r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns)) + r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns, color_scheme=color_scheme)) else: t.add_row([r1, r2]) if allRowsEqual == False: raise chispa.DataFramesNotEqualError("\n" + t.get_string()) -def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_columns=int) -> List[str]: +def __underline_cells_in_row(r1:Row, r2:Row, row_column_names:List[str], num_columns:int, color_scheme:dict) -> List[str]: """ Takes two Row types, a list of column names for the Rows and the length of columns Returns list of two strings, with underlined columns within rows that are different for PrettyTable @@ -72,9 +72,9 @@ def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_col if r1[column] != r2[column]: r1_string += underline_text( - f"{column}='{r1[column]}'") + f"{append_str}" + f"{column}='{r1[column]}'", color_scheme=color_scheme) + f"{append_str}" r2_string += underline_text( - f"{column}='{r2[column]}'") + f"{append_str}" + f"{column}='{r2[column]}'", color_scheme=color_scheme) + f"{append_str}" else: r1_string += f"{column}='{r1[column]}'{append_str}" r2_string += f"{column}='{r2[column]}'{append_str}" @@ -82,4 +82,4 @@ def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_col r1_string += ")" r2_string += ")" - return [bcolors.LightRed + r1_string, r2_string] + return [r1_string, r2_string] diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index e5d43ba..6b1642e 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -8,6 +8,12 @@ class SchemasNotEqualError(Exception): pass +default_colour_scheme = { + "default":"light_red", + "matched":"light_blue", + "underlined":"green" +} + def assert_schema_equality(s1, s2, ignore_nullable=False): if ignore_nullable: assert_schema_equality_ignore_nullable(s1, s2) @@ -21,7 +27,7 @@ def assert_basic_schema_equality(s1, s2): zipped = list(six.moves.zip_longest(s1, s2)) for sf1, sf2 in zipped: if sf1 == sf2: - t.add_row([blue(sf1), blue(sf2)]) + t.add_row([normal_text(str(sf1), color_scheme=default_colour_scheme), normal_text(str(sf2), color_scheme=default_colour_scheme)]) else: t.add_row([sf1, sf2]) raise SchemasNotEqualError("\n" + t.get_string()) @@ -33,7 +39,7 @@ def assert_schema_equality_ignore_nullable(s1, s2): zipped = list(six.moves.zip_longest(s1, s2)) for sf1, sf2 in zipped: if are_structfields_equal(sf1, sf2, True): - t.add_row([blue(sf1), blue(sf2)]) + t.add_row([normal_text(str(sf1), color_scheme=default_colour_scheme), normal_text(str(sf2), color_scheme=default_colour_scheme)]) else: t.add_row([sf1, sf2]) raise SchemasNotEqualError("\n" + t.get_string()) diff --git a/tests/test_rows_comparer.py b/tests/test_rows_comparer.py index 533daf1..8b5f60c 100644 --- a/tests/test_rows_comparer.py +++ b/tests/test_rows_comparer.py @@ -7,6 +7,12 @@ import math +default_colour_scheme = { + "default":"light_red", + "matched":"light_blue", + "underlined":"green" +} + def describe_assert_basic_rows_equality(): def it_throws_with_row_mismatches(): data1 = [(1, "jose"), (2, "li"), (3, "laura")] @@ -14,12 +20,12 @@ def it_throws_with_row_mismatches(): data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) with pytest.raises(DataFramesNotEqualError) as e_info: - assert_basic_rows_equality(df1.collect(), df2.collect()) + assert_basic_rows_equality(df1.collect(), df2.collect(), color_scheme=default_colour_scheme) def it_works_when_rows_are_the_same(): data1 = [(1, "jose"), (2, "li"), (3, "laura")] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1, "jose"), (2, "li"), (3, "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert_basic_rows_equality(df1.collect(), df2.collect()) + assert_basic_rows_equality(df1.collect(), df2.collect(), color_scheme=default_colour_scheme)