diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index c341988..ef98cbf 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -1,6 +1,7 @@ from chispa.prettytable import PrettyTable from chispa.bcolors import * import chispa.six as six +from pyspark.sql.types import StructField class SchemasNotEqualError(Exception): @@ -8,6 +9,19 @@ class SchemasNotEqualError(Exception): pass +class StructFieldPrettyPrint(StructField): + def __init__(self, structfield: StructField) -> None: + self.structfield = structfield + + def __repr__(self): + return "StructField(%s, %s, %s, %s)" % ( + self.structfield.name, + self.structfield.dataType, + str(self.structfield.nullable).lower(), + str(self.structfield.metadata) + ) + + def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False): if not ignore_nullable and not ignore_metadata: assert_basic_schema_equality(s1, s2) @@ -30,9 +44,9 @@ def inner(s1, s2, ignore_nullable, ignore_metadata): zipped = list(six.moves.zip_longest(s1, s2)) for sf1, sf2 in zipped: if are_structfields_equal(sf1, sf2, True): - t.add_row([blue(sf1), blue(sf2)]) + t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))]) else: - t.add_row([sf1, sf2]) + t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)]) raise SchemasNotEqualError("\n" + t.get_string()) @@ -45,9 +59,9 @@ def assert_basic_schema_equality(s1, s2): zipped = list(six.moves.zip_longest(s1, s2)) for sf1, sf2 in zipped: if sf1 == sf2: - t.add_row([blue(sf1), blue(sf2)]) + t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))]) else: - t.add_row([sf1, sf2]) + t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)]) raise SchemasNotEqualError("\n" + t.get_string()) @@ -59,9 +73,9 @@ def assert_schema_equality_ignore_nullable(s1, s2): zipped = list(six.moves.zip_longest(s1, s2)) for sf1, sf2 in zipped: if are_structfields_equal(sf1, sf2, True): - t.add_row([blue(sf1), blue(sf2)]) + t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))]) else: - t.add_row([sf1, sf2]) + t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)]) raise SchemasNotEqualError("\n" + t.get_string())