39
39
import snowflake .snowpark as sp
40
40
from snowflake .connector .errors import DatabaseError , ProgrammingError
41
41
from snowflake .snowpark import Window
42
+ from snowflake .snowpark .exceptions import SnowparkSQLException
42
43
from snowflake .snowpark .functions import (
43
44
abs ,
44
45
col ,
@@ -425,32 +426,38 @@ def _intersect_compare(self, ignore_spaces: bool) -> None:
425
426
self .abs_tol ,
426
427
ignore_spaces ,
427
428
)
428
- row_cnt = self .intersect_rows .count ()
429
429
430
430
with ThreadPoolExecutor () as executor :
431
431
futures = []
432
432
for column in self .intersect_columns ():
433
- future = executor .submit (
434
- self ._calculate_column_compare_stats , column , row_cnt
435
- )
433
+ future = executor .submit (self ._calculate_column_compare_stats , column )
436
434
futures .append (future )
437
435
for future in as_completed (futures ):
438
436
if future .exception ():
439
437
raise future .exception ()
440
438
441
- def _calculate_column_compare_stats (self , column : str , row_cnt : int ) -> None :
439
+ def _calculate_column_compare_stats (self , column : str ) -> None :
442
440
"""Populate the column stats for all intersecting column pairs.
443
441
444
442
Calculates compare stats by intersecting column pairs. For the non-trivial case
445
443
where intersecting columns are not join columns, a match count, max difference,
446
444
and null difference must be calculated.
447
445
"""
448
446
if column in self .join_columns :
449
- match_cnt = row_cnt
450
- col_match = ""
447
+ col_match = column + "_MATCH"
448
+ match_cnt = self .intersect_rows .count ()
449
+ if not self .only_join_columns ():
450
+ row_cnt = self .intersect_rows .count ()
451
+ else :
452
+ row_cnt = (
453
+ self .intersect_rows .count ()
454
+ + self .df1_unq_rows .count ()
455
+ + self .df2_unq_rows .count ()
456
+ )
451
457
max_diff = 0
452
458
null_diff = 0
453
459
else :
460
+ row_cnt = self .intersect_rows .count ()
454
461
col_1 = column + "_" + self .df1_name
455
462
col_2 = column + "_" + self .df2_name
456
463
col_match = column + "_MATCH"
@@ -551,6 +558,8 @@ def count_matching_rows(self) -> int:
551
558
552
559
def intersect_rows_match (self ) -> bool :
553
560
"""Check whether the intersect rows all match."""
561
+ if self .intersect_rows .count () == 0 :
562
+ return False
554
563
actual_length = self .intersect_rows .count ()
555
564
return self .count_matching_rows () == actual_length
556
565
@@ -616,37 +625,62 @@ def sample_mismatch(
616
625
"pertinent" columns, for rows that don't match on the provided
617
626
column.
618
627
"""
619
- row_cnt = self .intersect_rows .count ()
620
- col_match = self .intersect_rows .select (column + "_MATCH" )
621
- match_cnt = col_match .where (
622
- col (column + "_MATCH" ) == True # noqa: E712
623
- ).count ()
624
- sample_count = min (sample_count , row_cnt - match_cnt )
625
- sample = (
626
- self .intersect_rows .where (col (column + "_MATCH" ) == False ) # noqa: E712
627
- .drop (column + "_MATCH" )
628
- .limit (sample_count )
629
- )
628
+ column = column .upper ()
629
+ if not self .only_join_columns () and column not in self .join_columns :
630
+ row_cnt = self .intersect_rows .count ()
631
+ col_match = self .intersect_rows .select (column + "_MATCH" )
632
+ try :
633
+ col_match .collect ()
634
+ except SnowparkSQLException :
635
+ LOG .error (
636
+ f"Column: { column } is not an intersecting column. No mismatches can be generated."
637
+ )
638
+ return None
639
+ match_cnt = col_match .where (
640
+ col (column + "_MATCH" ) == True # noqa: E712
641
+ ).count ()
642
+ sample_count = min (sample_count , row_cnt - match_cnt )
643
+ sample = (
644
+ self .intersect_rows .where (col (column + "_MATCH" ) == False ) # noqa: E712
645
+ .drop (column + "_MATCH" )
646
+ .limit (sample_count )
647
+ )
630
648
631
- for c in self .join_columns :
632
- sample = sample .withColumnRenamed (c + "_" + self .df1_name , c )
633
-
634
- return_cols = [
635
- * self .join_columns ,
636
- column + "_" + self .df1_name ,
637
- column + "_" + self .df2_name ,
638
- ]
639
- to_return = sample .select (return_cols )
640
-
641
- if for_display :
642
- return to_return .toDF (
643
- * [
644
- * self .join_columns ,
645
- column + " (" + self .df1_name + ")" ,
646
- column + " (" + self .df2_name + ")" ,
647
- ]
649
+ for c in self .join_columns :
650
+ sample = sample .withColumnRenamed (c + "_" + self .df1_name , c )
651
+
652
+ return_cols = [
653
+ * self .join_columns ,
654
+ column + "_" + self .df1_name ,
655
+ column + "_" + self .df2_name ,
656
+ ]
657
+ to_return = sample .select (return_cols )
658
+
659
+ if for_display :
660
+ return to_return .toDF (
661
+ * [
662
+ * self .join_columns ,
663
+ column + " (" + self .df1_name + ")" ,
664
+ column + " (" + self .df2_name + ")" ,
665
+ ]
666
+ )
667
+ return to_return
668
+ else :
669
+ row_cnt = (
670
+ self .intersect_rows .count ()
671
+ + self .df1_unq_rows .count ()
672
+ + self .df2_unq_rows .count ()
648
673
)
649
- return to_return
674
+ match_cnt = self .intersect_rows .count ()
675
+ sample_count = min (sample_count , row_cnt - match_cnt )
676
+ df1_col = column + "_" + self .df1_name
677
+ df2_col = column + "_" + self .df2_name
678
+ sample = (
679
+ self .df1_unq_rows [[df1_col ]]
680
+ .union_all (self .df2_unq_rows [[df2_col ]])
681
+ .limit (sample_count )
682
+ )
683
+ return sample .toDF (column )
650
684
651
685
def all_mismatch (self , ignore_matching_cols : bool = False ) -> "sp.DataFrame" :
652
686
"""Get all rows with any columns that have a mismatch.
@@ -666,6 +700,16 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
666
700
"""
667
701
match_list = []
668
702
return_list = []
703
+ if self .only_join_columns ():
704
+ LOG .info ("Only join keys in data, returning mismatches based on unq_rows" )
705
+ df1_cols = [f"{ cols } _{ self .df1_name } " for cols in self .join_columns ]
706
+ df2_cols = [f"{ cols } _{ self .df2_name } " for cols in self .join_columns ]
707
+ to_return = self .df1_unq_rows [df1_cols ].union_all (
708
+ self .df2_unq_rows [df2_cols ]
709
+ )
710
+ for c in self .join_columns :
711
+ to_return = to_return .withColumnRenamed (c + "_" + self .df1_name , c )
712
+ return to_return
669
713
for c in self .intersect_rows .columns :
670
714
if c .endswith ("_MATCH" ):
671
715
orig_col_name = c [:- 6 ]
@@ -699,7 +743,16 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
699
743
LOG .debug (
700
744
f"Column { orig_col_name } is equal in df1 and df2. It will not be added to the result."
701
745
)
702
-
746
+ if len (match_list ) == 0 :
747
+ LOG .info ("No match columns found, returning mismatches based on unq_rows" )
748
+ df1_cols = [f"{ cols } _{ self .df1_name } " for cols in self .join_columns ]
749
+ df2_cols = [f"{ cols } _{ self .df2_name } " for cols in self .join_columns ]
750
+ to_return = self .df1_unq_rows [df1_cols ].union_all (
751
+ self .df2_unq_rows [df2_cols ]
752
+ )
753
+ for c in self .join_columns :
754
+ to_return = to_return .withColumnRenamed (c + "_" + self .df1_name , c )
755
+ return to_return
703
756
mm_rows = self .intersect_rows .withColumn (
704
757
"match_array" , concat (* match_list )
705
758
).where (contains (col ("match_array" ), lit ("false" )))
0 commit comments