
Commit 5efe92a

intersection of rows as the datasets have no mutual key/connection - Spark/SF (#388)
* spark full join
* snowflake full join
* order import to fix spark actions tests
* use pandas to compare Spark and SQL test outputs
* replace isEmpty for version compatibility
* sample mismatch column check
1 parent 3a7a9fa commit 5efe92a

File tree: 6 files changed, +489 −74 lines changed

datacompy/core.py

Lines changed: 1 addition & 2 deletions
@@ -343,16 +343,15 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
         for column in self.intersect_columns():
             if column in self.join_columns:
                 col_match = column + "_match"
+                match_cnt = len(self.intersect_rows)
                 if not self.only_join_columns():
                     row_cnt = len(self.intersect_rows)
-                    match_cnt = len(self.intersect_rows[column])
                 else:
                     row_cnt = (
                         len(self.intersect_rows)
                         + len(self.df1_unq_rows)
                         + len(self.df2_unq_rows)
                     )
-                    match_cnt = len(self.intersect_rows[column])
                 max_diff = 0.0
                 null_diff = 0
             else:
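In plain terms: for a join column every intersecting row is a match by construction, so match_cnt is now taken once from len(self.intersect_rows) before the branch, while row_cnt still grows to include the rows unique to either side when the frames contain nothing but the join key. A minimal sketch of that scenario through the public pandas Compare API (frame contents invented for illustration; the expected counts follow from the join itself):

    import pandas as pd
    import datacompy

    # Two frames that consist only of the join key.
    df1 = pd.DataFrame({"id": [1, 2, 3]})
    df2 = pd.DataFrame({"id": [2, 3, 4]})

    compare = datacompy.Compare(df1, df2, join_columns=["id"])

    print(len(compare.intersect_rows))  # expected: 2 (ids 2 and 3)
    print(len(compare.df1_unq_rows))    # expected: 1 (id 1)
    print(len(compare.df2_unq_rows))    # expected: 1 (id 4)

    # For the "id" column the report now counts 2 matches out of
    # 2 + 1 + 1 = 4 rows rather than 2 out of 2.
    print(compare.report())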

datacompy/polars.py

Lines changed: 1 addition & 2 deletions
@@ -334,16 +334,15 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
         for column in self.intersect_columns():
             if column in self.join_columns:
                 col_match = column + "_match"
+                match_cnt = len(self.intersect_rows)
                 if not self.only_join_columns():
                     row_cnt = len(self.intersect_rows)
-                    match_cnt = len(self.intersect_rows[column])
                 else:
                     row_cnt = (
                         len(self.intersect_rows)
                         + len(self.df1_unq_rows)
                         + len(self.df2_unq_rows)
                     )
-                    match_cnt = len(self.intersect_rows[column])
                 max_diff = 0.0
                 null_diff = 0
             else:
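The Polars comparator gets the identical treatment, so the match rate for join columns is computed the same way regardless of backend. A sketch of the same keys-only comparison against the Polars class (assuming it is exposed as PolarsCompare in datacompy.polars; adjust the import to your installed version):

    import polars as pl
    from datacompy.polars import PolarsCompare  # class name assumed

    df1 = pl.DataFrame({"id": [1, 2, 3]})
    df2 = pl.DataFrame({"id": [2, 3, 4]})

    compare = PolarsCompare(df1, df2, join_columns=["id"])

    # Same arithmetic as the pandas sketch above: for "id",
    # match_cnt = len(intersect_rows) = 2 and
    # row_cnt = 2 + len(df1_unq_rows) + len(df2_unq_rows) = 4.
    print(compare.report())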

datacompy/snowflake.py

Lines changed: 90 additions & 37 deletions
@@ -39,6 +39,7 @@
 import snowflake.snowpark as sp
 from snowflake.connector.errors import DatabaseError, ProgrammingError
 from snowflake.snowpark import Window
+from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import (
     abs,
     col,

@@ -425,32 +426,38 @@ def _intersect_compare(self, ignore_spaces: bool) -> None:
            self.abs_tol,
            ignore_spaces,
        )
-        row_cnt = self.intersect_rows.count()

         with ThreadPoolExecutor() as executor:
             futures = []
             for column in self.intersect_columns():
-                future = executor.submit(
-                    self._calculate_column_compare_stats, column, row_cnt
-                )
+                future = executor.submit(self._calculate_column_compare_stats, column)
                 futures.append(future)
             for future in as_completed(futures):
                 if future.exception():
                     raise future.exception()

-    def _calculate_column_compare_stats(self, column: str, row_cnt: int) -> None:
+    def _calculate_column_compare_stats(self, column: str) -> None:
         """Populate the column stats for all intersecting column pairs.

         Calculates compare stats by intersecting column pairs. For the non-trivial case
         where intersecting columns are not join columns, a match count, max difference,
         and null difference must be calculated.
         """
         if column in self.join_columns:
-            match_cnt = row_cnt
-            col_match = ""
+            col_match = column + "_MATCH"
+            match_cnt = self.intersect_rows.count()
+            if not self.only_join_columns():
+                row_cnt = self.intersect_rows.count()
+            else:
+                row_cnt = (
+                    self.intersect_rows.count()
+                    + self.df1_unq_rows.count()
+                    + self.df2_unq_rows.count()
+                )
             max_diff = 0
             null_diff = 0
         else:
+            row_cnt = self.intersect_rows.count()
             col_1 = column + "_" + self.df1_name
             col_2 = column + "_" + self.df2_name
             col_match = column + "_MATCH"

@@ -551,6 +558,8 @@ def count_matching_rows(self) -> int:

     def intersect_rows_match(self) -> bool:
         """Check whether the intersect rows all match."""
+        if self.intersect_rows.count() == 0:
+            return False
         actual_length = self.intersect_rows.count()
         return self.count_matching_rows() == actual_length

@@ -616,37 +625,62 @@ def sample_mismatch(
         "pertinent" columns, for rows that don't match on the provided
         column.
         """
-        row_cnt = self.intersect_rows.count()
-        col_match = self.intersect_rows.select(column + "_MATCH")
-        match_cnt = col_match.where(
-            col(column + "_MATCH") == True  # noqa: E712
-        ).count()
-        sample_count = min(sample_count, row_cnt - match_cnt)
-        sample = (
-            self.intersect_rows.where(col(column + "_MATCH") == False)  # noqa: E712
-            .drop(column + "_MATCH")
-            .limit(sample_count)
-        )
-
-        for c in self.join_columns:
-            sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
-
-        return_cols = [
-            *self.join_columns,
-            column + "_" + self.df1_name,
-            column + "_" + self.df2_name,
-        ]
-        to_return = sample.select(return_cols)
-
-        if for_display:
-            return to_return.toDF(
-                *[
-                    *self.join_columns,
-                    column + " (" + self.df1_name + ")",
-                    column + " (" + self.df2_name + ")",
-                ]
-            )
-        return to_return
+        column = column.upper()
+        if not self.only_join_columns() and column not in self.join_columns:
+            row_cnt = self.intersect_rows.count()
+            col_match = self.intersect_rows.select(column + "_MATCH")
+            try:
+                col_match.collect()
+            except SnowparkSQLException:
+                LOG.error(
+                    f"Column: {column} is not an intersecting column. No mismatches can be generated."
+                )
+                return None
+            match_cnt = col_match.where(
+                col(column + "_MATCH") == True  # noqa: E712
+            ).count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            sample = (
+                self.intersect_rows.where(col(column + "_MATCH") == False)  # noqa: E712
+                .drop(column + "_MATCH")
+                .limit(sample_count)
+            )
+
+            for c in self.join_columns:
+                sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
+
+            return_cols = [
+                *self.join_columns,
+                column + "_" + self.df1_name,
+                column + "_" + self.df2_name,
+            ]
+            to_return = sample.select(return_cols)
+
+            if for_display:
+                return to_return.toDF(
+                    *[
+                        *self.join_columns,
+                        column + " (" + self.df1_name + ")",
+                        column + " (" + self.df2_name + ")",
+                    ]
+                )
+            return to_return
+        else:
+            row_cnt = (
+                self.intersect_rows.count()
+                + self.df1_unq_rows.count()
+                + self.df2_unq_rows.count()
+            )
+            match_cnt = self.intersect_rows.count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            df1_col = column + "_" + self.df1_name
+            df2_col = column + "_" + self.df2_name
+            sample = (
+                self.df1_unq_rows[[df1_col]]
+                .union_all(self.df2_unq_rows[[df2_col]])
+                .limit(sample_count)
+            )
+            return sample.toDF(column)

     def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
         """Get all rows with any columns that have a mismatch.

@@ -666,6 +700,16 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
         """
         match_list = []
         return_list = []
+        if self.only_join_columns():
+            LOG.info("Only join keys in data, returning mismatches based on unq_rows")
+            df1_cols = [f"{cols}_{self.df1_name}" for cols in self.join_columns]
+            df2_cols = [f"{cols}_{self.df2_name}" for cols in self.join_columns]
+            to_return = self.df1_unq_rows[df1_cols].union_all(
+                self.df2_unq_rows[df2_cols]
+            )
+            for c in self.join_columns:
+                to_return = to_return.withColumnRenamed(c + "_" + self.df1_name, c)
+            return to_return
         for c in self.intersect_rows.columns:
             if c.endswith("_MATCH"):
                 orig_col_name = c[:-6]

@@ -699,7 +743,16 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
             LOG.debug(
                 f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
             )
-
+        if len(match_list) == 0:
+            LOG.info("No match columns found, returning mismatches based on unq_rows")
+            df1_cols = [f"{cols}_{self.df1_name}" for cols in self.join_columns]
+            df2_cols = [f"{cols}_{self.df2_name}" for cols in self.join_columns]
+            to_return = self.df1_unq_rows[df1_cols].union_all(
+                self.df2_unq_rows[df2_cols]
+            )
+            for c in self.join_columns:
+                to_return = to_return.withColumnRenamed(c + "_" + self.df1_name, c)
+            return to_return
         mm_rows = self.intersect_rows.withColumn(
             "match_array", concat(*match_list)
         ).where(contains(col("match_array"), lit("false")))
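Beyond mirroring the row and match counting, the Snowflake changes add two guards: intersect_rows_match() now returns False outright when the intersection is empty, and sample_mismatch() logs an error and returns None when the requested column has no match column to collect (the SnowparkSQLException path). For the keys-only case, both sample_mismatch() and all_mismatch() fall back to unioning the keys unique to each side and renaming them back to the bare join-column name. A backend-free sketch of that fallback shape using pandas stand-ins (frame contents and dataset names invented for illustration; in Snowpark the equivalent is union_all plus withColumnRenamed, as in the hunks above):

    import pandas as pd

    # Stand-ins for df1_unq_rows / df2_unq_rows after the outer join,
    # where join columns carry a _<dataframe name> suffix.
    join_columns, df1_name, df2_name = ["ID"], "ORIGINAL", "NEW"
    df1_unq_rows = pd.DataFrame({"ID_ORIGINAL": [1]})
    df2_unq_rows = pd.DataFrame({"ID_NEW": [4]})

    # Mirror of the unq_rows fallback: take the suffixed key columns from
    # each side, stack them, and restore the plain join-column name.
    df1_cols = [f"{c}_{df1_name}" for c in join_columns]
    df2_cols = [f"{c}_{df2_name}" for c in join_columns]
    mismatches = pd.concat(
        [
            df1_unq_rows[df1_cols].set_axis(join_columns, axis=1),
            df2_unq_rows[df2_cols].set_axis(join_columns, axis=1),
        ],
        ignore_index=True,
    )
    print(mismatches)  # one row per key present on only one side: 1 and 4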

datacompy/spark/sql.py

Lines changed: 74 additions & 31 deletions
@@ -459,14 +459,22 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
         LOG.debug("Comparing intersection")
         max_diff: float
         null_diff: int
-        row_cnt = self.intersect_rows.count()
         for column in self.intersect_columns():
             if column in self.join_columns:
-                match_cnt = row_cnt
-                col_match = ""
+                col_match = column + "_match"
+                match_cnt = self.intersect_rows.count()
+                if not self.only_join_columns():
+                    row_cnt = self.intersect_rows.count()
+                else:
+                    row_cnt = (
+                        self.intersect_rows.count()
+                        + self.df1_unq_rows.count()
+                        + self.df2_unq_rows.count()
+                    )
                 max_diff = 0
                 null_diff = 0
             else:
+                row_cnt = self.intersect_rows.count()
                 col_1 = column + "_" + self.df1_name
                 col_2 = column + "_" + self.df2_name
                 col_match = column + "_match"

@@ -561,6 +569,8 @@ def count_matching_rows(self) -> int:

     def intersect_rows_match(self) -> bool:
         """Check whether the intersect rows all match."""
+        if self.intersect_rows.count() == 0:
+            return False
         actual_length = self.intersect_rows.count()
         return self.count_matching_rows() == actual_length

@@ -621,37 +631,54 @@ def sample_mismatch(
         "pertinent" columns, for rows that don't match on the provided
         column.
         """
-        row_cnt = self.intersect_rows.count()
-        col_match = self.intersect_rows.select(column + "_match")
-        match_cnt = col_match.where(
-            col(column + "_match") == True  # noqa: E712
-        ).count()
-        sample_count = min(sample_count, row_cnt - match_cnt)
-        sample = (
-            self.intersect_rows.where(col(column + "_match") == False)  # noqa: E712
-            .drop(column + "_match")
-            .limit(sample_count)
-        )
-
-        for c in self.join_columns:
-            sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
-
-        return_cols = [
-            *self.join_columns,
-            column + "_" + self.df1_name,
-            column + "_" + self.df2_name,
-        ]
-        to_return = sample.select(return_cols)
-
-        if for_display:
-            return to_return.toDF(
-                *[
-                    *self.join_columns,
-                    column + " (" + self.df1_name + ")",
-                    column + " (" + self.df2_name + ")",
-                ]
-            )
-        return to_return
+        if not self.only_join_columns() and column not in self.join_columns:
+            row_cnt = self.intersect_rows.count()
+            col_match = self.intersect_rows.select(column + "_match")
+            match_cnt = col_match.where(
+                col(column + "_match") == True  # noqa: E712
+            ).count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            sample = (
+                self.intersect_rows.where(col(column + "_match") == False)  # noqa: E712
+                .drop(column + "_match")
+                .limit(sample_count)
+            )

+            for c in self.join_columns:
+                sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
+
+            return_cols = [
+                *self.join_columns,
+                column + "_" + self.df1_name,
+                column + "_" + self.df2_name,
+            ]
+            to_return = sample.select(return_cols)
+
+            if for_display:
+                return to_return.toDF(
+                    *[
+                        *self.join_columns,
+                        column + " (" + self.df1_name + ")",
+                        column + " (" + self.df2_name + ")",
+                    ]
+                )
+            return to_return
+        else:
+            row_cnt = (
+                self.intersect_rows.count()
+                + self.df1_unq_rows.count()
+                + self.df2_unq_rows.count()
+            )
+            match_cnt = self.intersect_rows.count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            df1_col = column + "_" + self.df1_name
+            df2_col = column + "_" + self.df2_name
+            sample = (
+                self.df1_unq_rows[[df1_col]]
+                .union(self.df2_unq_rows[[df2_col]])
+                .limit(sample_count)
+            )
+            return sample.toDF(column)

     def all_mismatch(
         self, ignore_matching_cols: bool = False

@@ -673,6 +700,14 @@ def all_mismatch(
         """
         match_list = []
         return_list = []
+        if self.only_join_columns():
+            LOG.info("Only join keys in data, returning mismatches based on unq_rows")
+            df1_cols = [f"{cols}_{self.df1_name}" for cols in self.join_columns]
+            df2_cols = [f"{cols}_{self.df2_name}" for cols in self.join_columns]
+            to_return = self.df1_unq_rows[df1_cols].union(self.df2_unq_rows[df2_cols])
+            for c in self.join_columns:
+                to_return = to_return.withColumnRenamed(c + "_" + self.df1_name, c)
+            return to_return
         for c in self.intersect_rows.columns:
             if c.endswith("_match"):
                 orig_col_name = c[:-6]

@@ -707,6 +742,14 @@ def all_mismatch(
             LOG.debug(
                 f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
             )
+        if len(match_list) == 0:
+            LOG.info("No match columns found, returning mismatches based on unq_rows")
+            df1_cols = [f"{cols}_{self.df1_name}" for cols in self.join_columns]
+            df2_cols = [f"{cols}_{self.df2_name}" for cols in self.join_columns]
+            to_return = self.df1_unq_rows[df1_cols].union(self.df2_unq_rows[df2_cols])
+            for c in self.join_columns:
+                to_return = to_return.withColumnRenamed(c + "_" + self.df1_name, c)
+            return to_return

         mm_rows = self.intersect_rows.withColumn(
             "match_array", array(match_list)
