@@ -83,8 +83,8 @@ class PolarsCompare(BaseCompare):
83
83
84
84
def __init__ (
85
85
self ,
86
- df1 : " pl.DataFrame" ,
87
- df2 : " pl.DataFrame" ,
86
+ df1 : pl .DataFrame ,
87
+ df2 : pl .DataFrame ,
88
88
join_columns : List [str ] | str ,
89
89
abs_tol : float = 0 ,
90
90
rel_tol : float = 0 ,
@@ -126,25 +126,25 @@ def __init__(
126
126
self ._compare (ignore_spaces = ignore_spaces , ignore_case = ignore_case )
127
127
128
128
@property
129
- def df1 (self ) -> " pl.DataFrame" :
129
+ def df1 (self ) -> pl .DataFrame :
130
130
"""Get the first dataframe."""
131
131
return self ._df1
132
132
133
133
@df1 .setter
134
- def df1 (self , df1 : " pl.DataFrame" ) -> None :
134
+ def df1 (self , df1 : pl .DataFrame ) -> None :
135
135
"""Check that it is a dataframe and has the join columns."""
136
136
self ._df1 = df1
137
137
self ._validate_dataframe (
138
138
"df1" , cast_column_names_lower = self .cast_column_names_lower
139
139
)
140
140
141
141
@property
142
- def df2 (self ) -> " pl.DataFrame" :
142
+ def df2 (self ) -> pl .DataFrame :
143
143
"""Get the second dataframe."""
144
144
return self ._df2
145
145
146
146
@df2 .setter
147
- def df2 (self , df2 : " pl.DataFrame" ) -> None :
147
+ def df2 (self , df2 : pl .DataFrame ) -> None :
148
148
"""Check that it is a dataframe and has the join columns."""
149
149
self ._df2 = df2
150
150
self ._validate_dataframe (
@@ -331,14 +331,23 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
331
331
null_diff : int | float
332
332
333
333
LOG .debug ("Comparing intersection" )
334
- row_cnt = len (self .intersect_rows )
335
334
for column in self .intersect_columns ():
336
335
if column in self .join_columns :
337
- match_cnt = row_cnt
338
- col_match = ""
336
+ col_match = column + "_match"
337
+ if not self .only_join_columns ():
338
+ row_cnt = len (self .intersect_rows )
339
+ match_cnt = len (self .intersect_rows [column ])
340
+ else :
341
+ row_cnt = (
342
+ len (self .intersect_rows )
343
+ + len (self .df1_unq_rows )
344
+ + len (self .df2_unq_rows )
345
+ )
346
+ match_cnt = len (self .intersect_rows [column ])
339
347
max_diff = 0.0
340
348
null_diff = 0
341
349
else :
350
+ row_cnt = len (self .intersect_rows )
342
351
col_1 = column + "_" + self .df1_name
343
352
col_2 = column + "_" + self .df2_name
344
353
col_match = column + "_match"
@@ -429,6 +438,8 @@ def count_matching_rows(self) -> int:
429
438
430
439
def intersect_rows_match (self ) -> bool :
431
440
"""Check whether the intersect rows all match."""
441
+ if self .intersect_rows .is_empty ():
442
+ return False
432
443
actual_length = self .intersect_rows .shape [0 ]
433
444
return self .count_matching_rows () == actual_length
434
445
@@ -471,7 +482,7 @@ def subset(self) -> bool:
471
482
472
483
def sample_mismatch (
473
484
self , column : str , sample_count : int = 10 , for_display : bool = False
474
- ) -> " pl.DataFrame" :
485
+ ) -> pl .DataFrame | None :
475
486
"""Return sample mismatches.
476
487
477
488
Get a sub-dataframe which contains the identifying
@@ -493,29 +504,46 @@ def sample_mismatch(
493
504
A sample of the intersection dataframe, containing only the
494
505
"pertinent" columns, for rows that don't match on the provided
495
506
column.
507
+
508
+ None
509
+ When the column being requested is not an intersecting column between dataframes.
496
510
"""
497
- row_cnt = self .intersect_rows .shape [0 ]
498
- col_match = self .intersect_rows [column + "_match" ]
499
- match_cnt = col_match .sum ()
500
- sample_count = min (sample_count , row_cnt - match_cnt ) # type: ignore
501
- sample = self .intersect_rows .filter (
502
- pl .col (column + "_match" ) != True # noqa: E712
503
- ).sample (sample_count )
504
- return_cols = [
505
- * self .join_columns ,
506
- column + "_" + self .df1_name ,
507
- column + "_" + self .df2_name ,
508
- ]
509
- to_return = sample [return_cols ]
510
- if for_display :
511
- to_return .columns = [
511
+ if not self .only_join_columns () and column not in self .join_columns :
512
+ row_cnt = self .intersect_rows .shape [0 ]
513
+ col_match = self .intersect_rows [column + "_match" ]
514
+ match_cnt = col_match .sum ()
515
+ sample_count = min (sample_count , row_cnt - match_cnt ) # type: ignore
516
+ sample = self .intersect_rows .filter (
517
+ pl .col (column + "_match" ) != True # noqa: E712
518
+ ).sample (sample_count )
519
+ return_cols = [
512
520
* self .join_columns ,
513
- column + " ( " + self .df1_name + ")" ,
514
- column + " ( " + self .df2_name + ")" ,
521
+ column + "_ " + self .df1_name ,
522
+ column + "_ " + self .df2_name ,
515
523
]
516
- return to_return
517
-
518
- def all_mismatch (self , ignore_matching_cols : bool = False ) -> "pl.DataFrame" :
524
+ to_return = sample [return_cols ]
525
+ if for_display :
526
+ to_return .columns = [
527
+ * self .join_columns ,
528
+ column + " (" + self .df1_name + ")" ,
529
+ column + " (" + self .df2_name + ")" ,
530
+ ]
531
+ return to_return
532
+ else :
533
+ row_cnt = (
534
+ len (self .intersect_rows )
535
+ + len (self .df1_unq_rows )
536
+ + len (self .df2_unq_rows )
537
+ )
538
+ col_match = self .intersect_rows [column ]
539
+ match_cnt = col_match .count ()
540
+ sample_count = min (sample_count , row_cnt - match_cnt )
541
+ sample = pl .concat (
542
+ [self .df1_unq_rows [[column ]], self .df2_unq_rows [[column ]]]
543
+ ).sample (sample_count )
544
+ return sample
545
+
546
+ def all_mismatch (self , ignore_matching_cols : bool = False ) -> pl .DataFrame :
519
547
"""Get all rows with any columns that have a mismatch.
520
548
521
549
Returns all df1 and df2 versions of the columns and join
@@ -533,6 +561,10 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
533
561
"""
534
562
match_list = []
535
563
return_list = []
564
+ if self .only_join_columns ():
565
+ LOG .info ("Only join keys in data, returning mismatches based on unq_rows" )
566
+ return pl .concat ([self .df1_unq_rows , self .df2_unq_rows ])
567
+
536
568
for col in self .intersect_rows .columns :
537
569
if col .endswith ("_match" ):
538
570
orig_col_name = col [:- 6 ]
@@ -561,6 +593,15 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
561
593
LOG .debug (
562
594
f"Column { orig_col_name } is equal in df1 and df2. It will not be added to the result."
563
595
)
596
+ if len (match_list ) == 0 :
597
+ LOG .info ("No match columns found, returning mismatches based on unq_rows" )
598
+ return pl .concat (
599
+ [
600
+ self .df1_unq_rows .select (self .join_columns ),
601
+ self .df2_unq_rows .select (self .join_columns ),
602
+ ]
603
+ )
604
+
564
605
return (
565
606
self .intersect_rows .with_columns (__all = pl .all_horizontal (match_list ))
566
607
.filter (pl .col ("__all" ) != True ) # noqa: E712
@@ -595,7 +636,7 @@ def report(
595
636
The report, formatted kinda nicely.
596
637
"""
597
638
598
- def df_to_str (pdf : " pl.DataFrame" ) -> str :
639
+ def df_to_str (pdf : pl .DataFrame ) -> str :
599
640
return pdf .to_pandas ().to_string ()
600
641
601
642
# Header
@@ -887,7 +928,7 @@ def compare_string_and_date_columns(
887
928
888
929
889
930
def get_merged_columns (
890
- original_df : " pl.DataFrame" , merged_df : " pl.DataFrame" , suffix : str
931
+ original_df : pl .DataFrame , merged_df : pl .DataFrame , suffix : str
891
932
) -> List [str ]:
892
933
"""Get the columns from an original dataframe, in the new merged dataframe.
893
934
@@ -936,7 +977,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
936
977
937
978
938
979
def generate_id_within_group (
939
- dataframe : " pl.DataFrame" , join_columns : List [str ]
980
+ dataframe : pl .DataFrame , join_columns : List [str ]
940
981
) -> "pl.Series" :
941
982
"""Generate an ID column that can be used to deduplicate identical rows.
942
983
0 commit comments