@@ -83,8 +83,8 @@ class PolarsCompare(BaseCompare):
83
83
84
84
def __init__ (
85
85
self ,
86
- df1 : " pl.DataFrame" ,
87
- df2 : " pl.DataFrame" ,
86
+ df1 : pl .DataFrame ,
87
+ df2 : pl .DataFrame ,
88
88
join_columns : List [str ] | str ,
89
89
abs_tol : float = 0 ,
90
90
rel_tol : float = 0 ,
@@ -126,25 +126,25 @@ def __init__(
126
126
self ._compare (ignore_spaces = ignore_spaces , ignore_case = ignore_case )
127
127
128
128
@property
129
- def df1 (self ) -> " pl.DataFrame" :
129
+ def df1 (self ) -> pl .DataFrame :
130
130
"""Get the first dataframe."""
131
131
return self ._df1
132
132
133
133
@df1 .setter
134
- def df1 (self , df1 : " pl.DataFrame" ) -> None :
134
+ def df1 (self , df1 : pl .DataFrame ) -> None :
135
135
"""Check that it is a dataframe and has the join columns."""
136
136
self ._df1 = df1
137
137
self ._validate_dataframe (
138
138
"df1" , cast_column_names_lower = self .cast_column_names_lower
139
139
)
140
140
141
141
@property
142
- def df2 (self ) -> " pl.DataFrame" :
142
+ def df2 (self ) -> pl .DataFrame :
143
143
"""Get the second dataframe."""
144
144
return self ._df2
145
145
146
146
@df2 .setter
147
- def df2 (self , df2 : " pl.DataFrame" ) -> None :
147
+ def df2 (self , df2 : pl .DataFrame ) -> None :
148
148
"""Check that it is a dataframe and has the join columns."""
149
149
self ._df2 = df2
150
150
self ._validate_dataframe (
@@ -331,14 +331,22 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
331
331
null_diff : int | float
332
332
333
333
LOG .debug ("Comparing intersection" )
334
- row_cnt = len (self .intersect_rows )
335
334
for column in self .intersect_columns ():
336
335
if column in self .join_columns :
337
- match_cnt = row_cnt
338
- col_match = ""
336
+ col_match = column + "_match"
337
+ match_cnt = len (self .intersect_rows )
338
+ if not self .only_join_columns ():
339
+ row_cnt = len (self .intersect_rows )
340
+ else :
341
+ row_cnt = (
342
+ len (self .intersect_rows )
343
+ + len (self .df1_unq_rows )
344
+ + len (self .df2_unq_rows )
345
+ )
339
346
max_diff = 0.0
340
347
null_diff = 0
341
348
else :
349
+ row_cnt = len (self .intersect_rows )
342
350
col_1 = column + "_" + self .df1_name
343
351
col_2 = column + "_" + self .df2_name
344
352
col_match = column + "_match"
@@ -429,6 +437,8 @@ def count_matching_rows(self) -> int:
429
437
430
438
def intersect_rows_match (self ) -> bool :
431
439
"""Check whether the intersect rows all match."""
440
+ if self .intersect_rows .is_empty ():
441
+ return False
432
442
actual_length = self .intersect_rows .shape [0 ]
433
443
return self .count_matching_rows () == actual_length
434
444
@@ -471,7 +481,7 @@ def subset(self) -> bool:
471
481
472
482
def sample_mismatch (
473
483
self , column : str , sample_count : int = 10 , for_display : bool = False
474
- ) -> " pl.DataFrame" :
484
+ ) -> pl .DataFrame | None :
475
485
"""Return sample mismatches.
476
486
477
487
Get a sub-dataframe which contains the identifying
@@ -493,29 +503,46 @@ def sample_mismatch(
493
503
A sample of the intersection dataframe, containing only the
494
504
"pertinent" columns, for rows that don't match on the provided
495
505
column.
506
+
507
+ None
508
+ When the column being requested is not an intersecting column between dataframes.
496
509
"""
497
- row_cnt = self .intersect_rows .shape [0 ]
498
- col_match = self .intersect_rows [column + "_match" ]
499
- match_cnt = col_match .sum ()
500
- sample_count = min (sample_count , row_cnt - match_cnt ) # type: ignore
501
- sample = self .intersect_rows .filter (
502
- pl .col (column + "_match" ) != True # noqa: E712
503
- ).sample (sample_count )
504
- return_cols = [
505
- * self .join_columns ,
506
- column + "_" + self .df1_name ,
507
- column + "_" + self .df2_name ,
508
- ]
509
- to_return = sample [return_cols ]
510
- if for_display :
511
- to_return .columns = [
510
+ if not self .only_join_columns () and column not in self .join_columns :
511
+ row_cnt = self .intersect_rows .shape [0 ]
512
+ col_match = self .intersect_rows [column + "_match" ]
513
+ match_cnt = col_match .sum ()
514
+ sample_count = min (sample_count , row_cnt - match_cnt ) # type: ignore
515
+ sample = self .intersect_rows .filter (
516
+ pl .col (column + "_match" ) != True # noqa: E712
517
+ ).sample (sample_count )
518
+ return_cols = [
512
519
* self .join_columns ,
513
- column + " ( " + self .df1_name + ")" ,
514
- column + " ( " + self .df2_name + ")" ,
520
+ column + "_ " + self .df1_name ,
521
+ column + "_ " + self .df2_name ,
515
522
]
516
- return to_return
517
-
518
- def all_mismatch (self , ignore_matching_cols : bool = False ) -> "pl.DataFrame" :
523
+ to_return = sample [return_cols ]
524
+ if for_display :
525
+ to_return .columns = [
526
+ * self .join_columns ,
527
+ column + " (" + self .df1_name + ")" ,
528
+ column + " (" + self .df2_name + ")" ,
529
+ ]
530
+ return to_return
531
+ else :
532
+ row_cnt = (
533
+ len (self .intersect_rows )
534
+ + len (self .df1_unq_rows )
535
+ + len (self .df2_unq_rows )
536
+ )
537
+ col_match = self .intersect_rows [column ]
538
+ match_cnt = col_match .count ()
539
+ sample_count = min (sample_count , row_cnt - match_cnt )
540
+ sample = pl .concat (
541
+ [self .df1_unq_rows [[column ]], self .df2_unq_rows [[column ]]]
542
+ ).sample (sample_count )
543
+ return sample
544
+
545
+ def all_mismatch (self , ignore_matching_cols : bool = False ) -> pl .DataFrame :
519
546
"""Get all rows with any columns that have a mismatch.
520
547
521
548
Returns all df1 and df2 versions of the columns and join
@@ -533,6 +560,10 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
533
560
"""
534
561
match_list = []
535
562
return_list = []
563
+ if self .only_join_columns ():
564
+ LOG .info ("Only join keys in data, returning mismatches based on unq_rows" )
565
+ return pl .concat ([self .df1_unq_rows , self .df2_unq_rows ])
566
+
536
567
for col in self .intersect_rows .columns :
537
568
if col .endswith ("_match" ):
538
569
orig_col_name = col [:- 6 ]
@@ -561,6 +592,15 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
561
592
LOG .debug (
562
593
f"Column { orig_col_name } is equal in df1 and df2. It will not be added to the result."
563
594
)
595
+ if len (match_list ) == 0 :
596
+ LOG .info ("No match columns found, returning mismatches based on unq_rows" )
597
+ return pl .concat (
598
+ [
599
+ self .df1_unq_rows .select (self .join_columns ),
600
+ self .df2_unq_rows .select (self .join_columns ),
601
+ ]
602
+ )
603
+
564
604
return (
565
605
self .intersect_rows .with_columns (__all = pl .all_horizontal (match_list ))
566
606
.filter (pl .col ("__all" ) != True ) # noqa: E712
@@ -595,7 +635,7 @@ def report(
595
635
The report, formatted kinda nicely.
596
636
"""
597
637
598
- def df_to_str (pdf : " pl.DataFrame" ) -> str :
638
+ def df_to_str (pdf : pl .DataFrame ) -> str :
599
639
return pdf .to_pandas ().to_string ()
600
640
601
641
# Header
@@ -887,7 +927,7 @@ def compare_string_and_date_columns(
887
927
888
928
889
929
def get_merged_columns (
890
- original_df : " pl.DataFrame" , merged_df : " pl.DataFrame" , suffix : str
930
+ original_df : pl .DataFrame , merged_df : pl .DataFrame , suffix : str
891
931
) -> List [str ]:
892
932
"""Get the columns from an original dataframe, in the new merged dataframe.
893
933
@@ -936,7 +976,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
936
976
937
977
938
978
def generate_id_within_group (
939
- dataframe : " pl.DataFrame" , join_columns : List [str ]
979
+ dataframe : pl .DataFrame , join_columns : List [str ]
940
980
) -> "pl.Series" :
941
981
"""Generate an ID column that can be used to deduplicate identical rows.
942
982
0 commit comments