Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,21 @@ assert_df_equality(df1, df2, underline_cells=True)

![DfsNotEqualUnderlined](https://github.com/MrPowers/chispa/blob/main/images/df_not_equal_underlined.png)

### Changing the default color scheme

You can choose to change the default color scheme of the output font when comparing dataframes.

```python
colour_scheme = {
"default":"light_red",
"matched":"light_blue",
"underlined":"purple"
}
assert_df_equality(df1, df2, allow_nan_equality=False, underline_cells=True, color_scheme=colour_scheme)
```

A list of available colors are available in the [`bcolors` class](https://github.com/MrPowers/chispa/blob/main/chispa/bcolors.py#L1).

## Approximate column equality

We can check if columns are approximately equal, which is especially useful for floating number comparisons.
Expand Down
70 changes: 40 additions & 30 deletions chispa/bcolors.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,50 @@
class bcolors:
NC = '\033[0m' # No Color, reset all

Bold = '\033[1m'
Underlined = '\033[4m'
Blink = '\033[5m'
Inverted = '\033[7m'
Hidden = '\033[8m'

Black = '\033[30m'
Red = '\033[31m'
Green = '\033[32m'
Yellow = '\033[33m'
Blue = '\033[34m'
Purple = '\033[35m'
Cyan = '\033[36m'
LightGray = '\033[37m'
DarkGray = '\033[30m'
LightRed = '\033[31m'
LightGreen = '\033[32m'
LightYellow = '\033[93m'
LightBlue = '\033[34m'
LightPurple = '\033[35m'
LightCyan = '\033[36m'
White = '\033[97m'
nc = '\033[0m' # No Color, reset all

# Style
Bold = '\033[1m'
Underline = '\033[4m'
bold = '\033[1m'
underlined = '\033[4m'
blink = '\033[5m'
inverted = '\033[7m'
hidden = '\033[8m'

# Colors
black = '\033[30m'
red = '\033[31m'
green = '\033[32m'
yellow = '\033[33m'
blue = '\033[34m'
purple = '\033[35m'
cyan = '\033[36m'
light_gray = '\033[37m'
dark_gray = '\033[30m'
light_red = '\033[31m'
light_green = '\033[32m'
light_yellow = '\033[93m'
light_blue = '\033[34m'
light_purple = '\033[35m'
light_cyan = '\033[36m'
white = '\033[97m'

def blue(s: str) -> str:
return bcolors.LightBlue + str(s) + bcolors.LightRed

def normal_text(input_text: str, color_scheme: dict) -> str:
return get_color(color_scheme["matched"]) + input_text + get_color(color_scheme["default"])

def underline_text(input_text: str) -> str:

def underline_text(input_text: str, color_scheme: dict) -> str:
"""
Takes an input string and returns a white, underlined string (based on PrettyTable formatting)
"""
return bcolors.White + bcolors.Underline + input_text + bcolors.NC + bcolors.LightRed
return get_color(color_scheme["underlined"]) + bcolors.underlined + input_text + bcolors.nc + get_color(color_scheme["default"])

def get_color(color_string: str) -> str:
"""
Takes a color string, e.g. "Red" and returns Pretty Tables color code string if it exists in the bcolors class, otherwise raise an Exception
"""
color_string_cleaned = color_string.lower() # Clean string
if hasattr(bcolors(), color_string_cleaned):
return getattr(bcolors(), color_string_cleaned)
else:
raise Exception(f"Unable to find color '{color_string}' in bcolors.")


8 changes: 4 additions & 4 deletions chispa/column_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def assert_column_equality(df, col_name1, col_name2):
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
if elements[0] == elements[1]:
first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
first = bcolors.light_blue + str(elements[0]) + bcolors.light_red
second = bcolors.light_blue + str(elements[1]) + bcolors.light_red
t.add_row([first, second])
else:
t.add_row([str(elements[0]), str(elements[1])])
Expand All @@ -32,8 +32,8 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision):
zipped = list(zip(colName1Elements, colName2Elements))
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
first = bcolors.light_blue + str(elements[0]) + bcolors.light_red
second = bcolors.light_blue + str(elements[1]) + bcolors.light_red
# when one is None and the other isn't, they're not equal
if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None):
all_rows_equal = False
Expand Down
21 changes: 16 additions & 5 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,19 @@ class DataFramesNotEqualError(Exception):
pass


default_colour_scheme = {
"default":"light_red",
"matched":"light_blue",
"underlined":"green"
}

def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
ignore_column_order=False, ignore_row_order=False, underline_cells=False):
ignore_column_order=False, ignore_row_order=False, underline_cells=False, color_scheme=None):
if color_scheme is None:
color_scheme = default_colour_scheme
else:
if ("default" not in color_scheme.keys()) or ("matched" not in color_scheme.keys()) or ("underlined" not in color_scheme.keys()):
raise Exception("Color scheme requires keys:'default', 'matched' and 'underlined'.")
if transforms is None:
transforms = []
if ignore_column_order:
Expand All @@ -22,10 +33,10 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n
assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
if allow_nan_equality:
assert_generic_rows_equality(
df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells)
df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], color_scheme=color_scheme, underline_cells=underline_cells)
else:
assert_basic_rows_equality(
df1.collect(), df2.collect(), underline_cells=underline_cells)
df1.collect(), df2.collect(), color_scheme=color_scheme, underline_cells=underline_cells)


def are_dfs_equal(df1, df2):
Expand All @@ -48,8 +59,8 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf
df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
if precision != 0:
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality])
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], color_scheme=default_colour_scheme)
elif allow_nan_equality:
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True])
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], color_scheme=default_colour_scheme)
else:
assert_basic_rows_equality(df1.collect(), df2.collect())
22 changes: 11 additions & 11 deletions chispa/rows_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import List


def assert_basic_rows_equality(rows1, rows2, underline_cells=False):
def assert_basic_rows_equality(rows1, rows2, color_scheme, underline_cells=False):
if underline_cells:
row_column_names = rows1[0].__fields__
num_columns = len(row_column_names)
Expand All @@ -15,17 +15,17 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False):
zipped = list(six.moves.zip_longest(rows1, rows2))
for r1, r2 in zipped:
if r1 == r2:
t.add_row([blue(r1), blue(r2)])
t.add_row([normal_text(input_text=str(r1), color_scheme=color_scheme), normal_text(input_text=str(r2), color_scheme=color_scheme)])
else:
if underline_cells:
t.add_row(__underline_cells_in_row(
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns, color_scheme=color_scheme))
else:
t.add_row([r1, r2])
raise chispa.DataFramesNotEqualError("\n" + t.get_string())


def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False):
def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, color_scheme, underline_cells=False):
df1_rows = rows1
df2_rows = rows2
zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
Expand All @@ -41,23 +41,23 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu
t.add_row([r1, r2])
# rows are equal
elif row_equality_fun(r1, r2, *row_equality_fun_args):
first = bcolors.LightBlue + str(r1) + bcolors.LightRed
second = bcolors.LightBlue + str(r2) + bcolors.LightRed
first = get_color(color_scheme["matched"]) + str(r1) + get_color(color_scheme["default"])
second = get_color(color_scheme["matched"]) + str(r2) + get_color(color_scheme["default"])
t.add_row([first, second])
# otherwise, rows aren't equal
else:
allRowsEqual = False
# Underline cells if requested
if underline_cells:
t.add_row(__underline_cells_in_row(
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns, color_scheme=color_scheme))
else:
t.add_row([r1, r2])
if allRowsEqual == False:
raise chispa.DataFramesNotEqualError("\n" + t.get_string())


def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_columns=int) -> List[str]:
def __underline_cells_in_row(r1:Row, r2:Row, row_column_names:List[str], num_columns:int, color_scheme:dict) -> List[str]:
"""
Takes two Row types, a list of column names for the Rows and the length of columns
Returns list of two strings, with underlined columns within rows that are different for PrettyTable
Expand All @@ -72,14 +72,14 @@ def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_col

if r1[column] != r2[column]:
r1_string += underline_text(
f"{column}='{r1[column]}'") + f"{append_str}"
f"{column}='{r1[column]}'", color_scheme=color_scheme) + f"{append_str}"
r2_string += underline_text(
f"{column}='{r2[column]}'") + f"{append_str}"
f"{column}='{r2[column]}'", color_scheme=color_scheme) + f"{append_str}"
else:
r1_string += f"{column}='{r1[column]}'{append_str}"
r2_string += f"{column}='{r2[column]}'{append_str}"

r1_string += ")"
r2_string += ")"

return [bcolors.LightRed + r1_string, r2_string]
return [r1_string, r2_string]
10 changes: 8 additions & 2 deletions chispa/schema_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ class SchemasNotEqualError(Exception):
pass


default_colour_scheme = {
"default":"light_red",
"matched":"light_blue",
"underlined":"green"
}

def assert_schema_equality(s1, s2, ignore_nullable=False):
if ignore_nullable:
assert_schema_equality_ignore_nullable(s1, s2)
Expand All @@ -21,7 +27,7 @@ def assert_basic_schema_equality(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if sf1 == sf2:
t.add_row([blue(sf1), blue(sf2)])
t.add_row([normal_text(str(sf1), color_scheme=default_colour_scheme), normal_text(str(sf2), color_scheme=default_colour_scheme)])
else:
t.add_row([sf1, sf2])
raise SchemasNotEqualError("\n" + t.get_string())
Expand All @@ -33,7 +39,7 @@ def assert_schema_equality_ignore_nullable(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if are_structfields_equal(sf1, sf2, True):
t.add_row([blue(sf1), blue(sf2)])
t.add_row([normal_text(str(sf1), color_scheme=default_colour_scheme), normal_text(str(sf2), color_scheme=default_colour_scheme)])
else:
t.add_row([sf1, sf2])
raise SchemasNotEqualError("\n" + t.get_string())
Expand Down
10 changes: 8 additions & 2 deletions tests/test_rows_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,25 @@
import math


default_colour_scheme = {
"default":"light_red",
"matched":"light_blue",
"underlined":"green"
}

def describe_assert_basic_rows_equality():
def it_throws_with_row_mismatches():
data1 = [(1, "jose"), (2, "li"), (3, "laura")]
df1 = spark.createDataFrame(data1, ["num", "expected_name"])
data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
with pytest.raises(DataFramesNotEqualError) as e_info:
assert_basic_rows_equality(df1.collect(), df2.collect())
assert_basic_rows_equality(df1.collect(), df2.collect(), color_scheme=default_colour_scheme)

def it_works_when_rows_are_the_same():
data1 = [(1, "jose"), (2, "li"), (3, "laura")]
df1 = spark.createDataFrame(data1, ["num", "expected_name"])
data2 = [(1, "jose"), (2, "li"), (3, "laura")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
assert_basic_rows_equality(df1.collect(), df2.collect())
assert_basic_rows_equality(df1.collect(), df2.collect(), color_scheme=default_colour_scheme)