Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions chispa/schema_comparer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from chispa.prettytable import PrettyTable
from chispa.bcolors import *
import chispa.six as six
from pyspark.sql.types import StructField


class SchemasNotEqualError(Exception):
    """Raised by the schema assertions in this module when two schemas differ.

    The exception message carries a rendered side-by-side table of the
    fields of both schemas.
    """


class StructFieldPrettyPrint(StructField):
    """Display wrapper that renders a ``StructField`` as a single-line string.

    The schema assertions in this module wrap each field in this class
    before adding it to the mismatch table, so the table shows the name,
    data type, nullability and metadata of every field.

    Args:
        structfield: The field to render. ``None`` is tolerated because the
            callers pair fields with ``zip_longest``, which pads the shorter
            schema with ``None``.
    """

    def __init__(self, structfield: StructField) -> None:
        # Only the wrapped field is stored; the StructField base initializer
        # is not invoked here, so all data is read through ``self.structfield``.
        self.structfield = structfield

    def __repr__(self):
        """Return ``StructField(name, dataType, nullable, metadata)`` as text.

        A missing field (``None``) is rendered as ``"None"`` instead of
        raising ``AttributeError``, so schemas of different lengths still
        produce a readable comparison table.
        """
        field = self.structfield
        if field is None:
            return "None"
        return "StructField(%s, %s, %s, %s)" % (
            field.name,
            field.dataType,
            # Lower-cased so the flag prints as ``true``/``false``.
            str(field.nullable).lower(),
            str(field.metadata),
        )


def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False):
if not ignore_nullable and not ignore_metadata:
assert_basic_schema_equality(s1, s2)
Expand All @@ -30,9 +44,9 @@ def inner(s1, s2, ignore_nullable, ignore_metadata):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if are_structfields_equal(sf1, sf2, True):
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand All @@ -45,9 +59,9 @@ def assert_basic_schema_equality(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if sf1 == sf2:
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand All @@ -59,9 +73,9 @@ def assert_schema_equality_ignore_nullable(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if are_structfields_equal(sf1, sf2, True):
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand Down