Commit 3a7a9fa

intersection of rows as the datasets have no mutual key/connection (#385)
1 parent fa8e539 commit 3a7a9fa
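
This commit handles comparisons where the two datasets share nothing beyond their join key(s), so there are no non-key columns to match on. As a rough, illustrative sketch of that scenario with the pandas Compare class (the column name and values below are invented for demonstration, not taken from the commit or its tests):

    import pandas as pd
    import datacompy

    # Both frames consist solely of the join column; there are no value columns to compare.
    df1 = pd.DataFrame({"id": [1, 2, 3]})
    df2 = pd.DataFrame({"id": [2, 3, 4]})

    compare = datacompy.Compare(df1, df2, join_columns="id")

    # With this change, mismatch reporting falls back to the rows unique to each side
    # instead of relying on per-column "_match" columns that do not exist here.
    print(compare.all_mismatch())
    print(compare.report())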

7 files changed: +366 −60 lines changed

datacompy/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 Then extended to carry that functionality over to Spark Dataframes.
 """
 
-__version__ = "0.16.3"
+__version__ = "0.16.4"
 
 import platform
 from warnings import warn

datacompy/base.py
Lines changed: 4 additions & 0 deletions

@@ -158,6 +158,10 @@ def report(
         """Return a string representation of a report."""
         pass
 
+    def only_join_columns(self) -> bool:
+        """Boolean on if the only columns are the join columns."""
+        return set(self.join_columns) == set(self.df1.columns) == set(self.df2.columns)
+
 
 def temp_column_name(*dataframes) -> str:
     """Get a temp column name that isn't included in columns of any dataframes.

datacompy/core.py
Lines changed: 72 additions & 23 deletions

@@ -340,14 +340,23 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
         otherwise.
         """
         LOG.debug("Comparing intersection")
-        row_cnt = len(self.intersect_rows)
         for column in self.intersect_columns():
             if column in self.join_columns:
-                match_cnt = row_cnt
-                col_match = ""
+                col_match = column + "_match"
+                if not self.only_join_columns():
+                    row_cnt = len(self.intersect_rows)
+                    match_cnt = len(self.intersect_rows[column])
+                else:
+                    row_cnt = (
+                        len(self.intersect_rows)
+                        + len(self.df1_unq_rows)
+                        + len(self.df2_unq_rows)
+                    )
+                    match_cnt = len(self.intersect_rows[column])
                 max_diff = 0.0
                 null_diff = 0
             else:
+                row_cnt = len(self.intersect_rows)
                 col_1 = column + "_" + self.df1_name
                 col_2 = column + "_" + self.df2_name
                 col_match = column + "_match"

@@ -428,6 +437,8 @@ def count_matching_rows(self) -> int:
 
     def intersect_rows_match(self) -> bool:
         """Check whether the intersect rows all match."""
+        if self.intersect_rows.empty:
+            return False
         actual_length = self.intersect_rows.shape[0]
         return self.count_matching_rows() == actual_length
 

@@ -470,7 +481,7 @@ def subset(self) -> bool:
 
     def sample_mismatch(
         self, column: str, sample_count: int = 10, for_display: bool = False
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | None:
         """Return sample mismatches.
 
         Gets a sub-dataframe which contains the identifying

@@ -492,27 +503,53 @@ def sample_mismatch(
             A sample of the intersection dataframe, containing only the
             "pertinent" columns, for rows that don't match on the provided
             column.
+
+        None
+            When the column being requested is not an intersecting column between dataframes.
         """
-        row_cnt = self.intersect_rows.shape[0]
-        col_match = self.intersect_rows[column + "_match"]
-        match_cnt = col_match.sum()
-        sample_count = min(sample_count, row_cnt - match_cnt)
-        sample = self.intersect_rows[~col_match].sample(sample_count)
-        return_cols = [
-            *self.join_columns,
-            column + "_" + self.df1_name,
-            column + "_" + self.df2_name,
-        ]
-        to_return = sample[return_cols]
-        if for_display:
-            to_return.columns = pd.Index(
-                [
-                    *self.join_columns,
-                    column + " (" + self.df1_name + ")",
-                    column + " (" + self.df2_name + ")",
-                ]
+        if not self.only_join_columns() and column not in self.join_columns:
+            row_cnt = self.intersect_rows.shape[0]
+            try:
+                col_match = self.intersect_rows[column + "_match"]
+            except KeyError:
+                LOG.error(
+                    f"Column: {column} is not an intersecting column. No mismatches can be generated."
+                )
+                return None
+            match_cnt = col_match.sum()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            sample = self.intersect_rows[~col_match].sample(sample_count)
+            return_cols = [
+                *self.join_columns,
+                column + "_" + self.df1_name,
+                column + "_" + self.df2_name,
+            ]
+            to_return = sample[return_cols]
+            if for_display:
+                to_return.columns = pd.Index(
+                    [
+                        *self.join_columns,
+                        column + " (" + self.df1_name + ")",
+                        column + " (" + self.df2_name + ")",
+                    ]
+                )
+            return to_return
+        else:
+            row_cnt = (
+                len(self.intersect_rows)
+                + len(self.df1_unq_rows)
+                + len(self.df2_unq_rows)
             )
-        return to_return
+            col_match = self.intersect_rows[column]
+            match_cnt = col_match.count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            sample = pd.concat(
+                [self.df1_unq_rows[[column]], self.df2_unq_rows[[column]]]
+            ).sample(sample_count)
+            to_return = sample
+            if for_display:
+                to_return.columns = pd.Index([column])
+            return to_return
 
     def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
         """Get all rows with any columns that have a mismatch.

@@ -532,6 +569,10 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
         """
         match_list = []
         return_list = []
+        if self.only_join_columns():
+            LOG.info("Only join keys in data, returning mismatches based on unq_rows")
+            return pd.concat([self.df1_unq_rows, self.df2_unq_rows])
+
         for col in self.intersect_rows.columns:
             if col.endswith("_match"):
                 orig_col_name = col[:-6]

@@ -560,6 +601,14 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
                     LOG.debug(
                         f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
                     )
+        if len(match_list) == 0:
+            LOG.info("No match columns found, returning mismatches based on unq_rows")
+            return pd.concat(
+                [
+                    self.df1_unq_rows[self.join_columns],
+                    self.df2_unq_rows[self.join_columns],
+                ]
+            )
 
         mm_bool = self.intersect_rows[match_list].all(axis="columns")
         return self.intersect_rows[~mm_bool][self.join_columns + return_list]
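
A rough sketch of how the updated pandas sample_mismatch behaves after this change (the dataframes and the non-existent column name are illustrative assumptions, not from the commit):

    import pandas as pd
    import datacompy

    df1 = pd.DataFrame({"id": [1, 2, 3], "amount": [10.0, 20.0, 30.0]})
    df2 = pd.DataFrame({"id": [1, 2, 4], "amount": [10.0, 25.0, 40.0]})
    compare = datacompy.Compare(df1, df2, join_columns="id")

    # An intersecting value column still returns a small sample of mismatching rows.
    print(compare.sample_mismatch("amount"))

    # A column that is not in the intersection now logs an error and returns None
    # instead of raising a KeyError.
    print(compare.sample_mismatch("not_a_column"))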

datacompy/polars.py
Lines changed: 74 additions & 33 deletions

@@ -83,8 +83,8 @@ class PolarsCompare(BaseCompare):
 
     def __init__(
         self,
-        df1: "pl.DataFrame",
-        df2: "pl.DataFrame",
+        df1: pl.DataFrame,
+        df2: pl.DataFrame,
         join_columns: List[str] | str,
         abs_tol: float = 0,
         rel_tol: float = 0,

@@ -126,25 +126,25 @@ def __init__(
         self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case)
 
     @property
-    def df1(self) -> "pl.DataFrame":
+    def df1(self) -> pl.DataFrame:
         """Get the first dataframe."""
         return self._df1
 
     @df1.setter
-    def df1(self, df1: "pl.DataFrame") -> None:
+    def df1(self, df1: pl.DataFrame) -> None:
         """Check that it is a dataframe and has the join columns."""
         self._df1 = df1
         self._validate_dataframe(
             "df1", cast_column_names_lower=self.cast_column_names_lower
         )
 
     @property
-    def df2(self) -> "pl.DataFrame":
+    def df2(self) -> pl.DataFrame:
         """Get the second dataframe."""
         return self._df2
 
     @df2.setter
-    def df2(self, df2: "pl.DataFrame") -> None:
+    def df2(self, df2: pl.DataFrame) -> None:
         """Check that it is a dataframe and has the join columns."""
         self._df2 = df2
         self._validate_dataframe(

@@ -331,14 +331,23 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
         null_diff: int | float
 
         LOG.debug("Comparing intersection")
-        row_cnt = len(self.intersect_rows)
         for column in self.intersect_columns():
             if column in self.join_columns:
-                match_cnt = row_cnt
-                col_match = ""
+                col_match = column + "_match"
+                if not self.only_join_columns():
+                    row_cnt = len(self.intersect_rows)
+                    match_cnt = len(self.intersect_rows[column])
+                else:
+                    row_cnt = (
+                        len(self.intersect_rows)
+                        + len(self.df1_unq_rows)
+                        + len(self.df2_unq_rows)
+                    )
+                    match_cnt = len(self.intersect_rows[column])
                 max_diff = 0.0
                 null_diff = 0
             else:
+                row_cnt = len(self.intersect_rows)
                 col_1 = column + "_" + self.df1_name
                 col_2 = column + "_" + self.df2_name
                 col_match = column + "_match"

@@ -429,6 +438,8 @@ def count_matching_rows(self) -> int:
 
     def intersect_rows_match(self) -> bool:
         """Check whether the intersect rows all match."""
+        if self.intersect_rows.is_empty():
+            return False
         actual_length = self.intersect_rows.shape[0]
         return self.count_matching_rows() == actual_length
 

@@ -471,7 +482,7 @@ def subset(self) -> bool:
 
     def sample_mismatch(
         self, column: str, sample_count: int = 10, for_display: bool = False
-    ) -> "pl.DataFrame":
+    ) -> pl.DataFrame | None:
         """Return sample mismatches.
 
         Get a sub-dataframe which contains the identifying

@@ -493,29 +504,46 @@ def sample_mismatch(
             A sample of the intersection dataframe, containing only the
             "pertinent" columns, for rows that don't match on the provided
             column.
+
+        None
+            When the column being requested is not an intersecting column between dataframes.
         """
-        row_cnt = self.intersect_rows.shape[0]
-        col_match = self.intersect_rows[column + "_match"]
-        match_cnt = col_match.sum()
-        sample_count = min(sample_count, row_cnt - match_cnt)  # type: ignore
-        sample = self.intersect_rows.filter(
-            pl.col(column + "_match") != True  # noqa: E712
-        ).sample(sample_count)
-        return_cols = [
-            *self.join_columns,
-            column + "_" + self.df1_name,
-            column + "_" + self.df2_name,
-        ]
-        to_return = sample[return_cols]
-        if for_display:
-            to_return.columns = [
+        if not self.only_join_columns() and column not in self.join_columns:
+            row_cnt = self.intersect_rows.shape[0]
+            col_match = self.intersect_rows[column + "_match"]
+            match_cnt = col_match.sum()
+            sample_count = min(sample_count, row_cnt - match_cnt)  # type: ignore
+            sample = self.intersect_rows.filter(
+                pl.col(column + "_match") != True  # noqa: E712
+            ).sample(sample_count)
+            return_cols = [
                 *self.join_columns,
-                column + " (" + self.df1_name + ")",
-                column + " (" + self.df2_name + ")",
+                column + "_" + self.df1_name,
+                column + "_" + self.df2_name,
             ]
-        return to_return
-
-    def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
+            to_return = sample[return_cols]
+            if for_display:
+                to_return.columns = [
+                    *self.join_columns,
+                    column + " (" + self.df1_name + ")",
+                    column + " (" + self.df2_name + ")",
+                ]
+            return to_return
+        else:
+            row_cnt = (
+                len(self.intersect_rows)
+                + len(self.df1_unq_rows)
+                + len(self.df2_unq_rows)
+            )
+            col_match = self.intersect_rows[column]
+            match_cnt = col_match.count()
+            sample_count = min(sample_count, row_cnt - match_cnt)
+            sample = pl.concat(
+                [self.df1_unq_rows[[column]], self.df2_unq_rows[[column]]]
+            ).sample(sample_count)
+            return sample
+
+    def all_mismatch(self, ignore_matching_cols: bool = False) -> pl.DataFrame:
         """Get all rows with any columns that have a mismatch.
 
         Returns all df1 and df2 versions of the columns and join

@@ -533,6 +561,10 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
         """
         match_list = []
         return_list = []
+        if self.only_join_columns():
+            LOG.info("Only join keys in data, returning mismatches based on unq_rows")
+            return pl.concat([self.df1_unq_rows, self.df2_unq_rows])
+
         for col in self.intersect_rows.columns:
             if col.endswith("_match"):
                 orig_col_name = col[:-6]

@@ -561,6 +593,15 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
                     LOG.debug(
                         f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
                     )
+        if len(match_list) == 0:
+            LOG.info("No match columns found, returning mismatches based on unq_rows")
+            return pl.concat(
+                [
+                    self.df1_unq_rows.select(self.join_columns),
+                    self.df2_unq_rows.select(self.join_columns),
+                ]
+            )
+
         return (
             self.intersect_rows.with_columns(__all=pl.all_horizontal(match_list))
             .filter(pl.col("__all") != True)  # noqa: E712

@@ -595,7 +636,7 @@ def report(
         The report, formatted kinda nicely.
         """
 
-        def df_to_str(pdf: "pl.DataFrame") -> str:
+        def df_to_str(pdf: pl.DataFrame) -> str:
             return pdf.to_pandas().to_string()
 
         # Header

@@ -887,7 +928,7 @@ def compare_string_and_date_columns(
 
 
 def get_merged_columns(
-    original_df: "pl.DataFrame", merged_df: "pl.DataFrame", suffix: str
+    original_df: pl.DataFrame, merged_df: pl.DataFrame, suffix: str
 ) -> List[str]:
     """Get the columns from an original dataframe, in the new merged dataframe.
 

@@ -936,7 +977,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
 
 
 def generate_id_within_group(
-    dataframe: "pl.DataFrame", join_columns: List[str]
+    dataframe: pl.DataFrame, join_columns: List[str]
 ) -> "pl.Series":
     """Generate an ID column that can be used to deduplicate identical rows.
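The polars changes mirror the pandas ones. A minimal, illustrative sketch of the only-join-columns case with PolarsCompare (keys chosen so the frames share no rows; the data is invented, not from the commit):

    import polars as pl
    from datacompy import PolarsCompare

    # Frames that contain only the join key and have no keys in common.
    df1 = pl.DataFrame({"id": [1, 2, 3]})
    df2 = pl.DataFrame({"id": [4, 5, 6]})

    compare = PolarsCompare(df1, df2, join_columns="id")

    # all_mismatch() now returns the concatenation of each side's unique rows.
    print(compare.all_mismatch())

    # With the new guard, intersect_rows_match() returns False for an empty intersection.
    print(compare.intersect_rows_match())
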
pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ maintainers = [
     { name="Raymond Haffar", email="[email protected]" },
 ]
 license = {text = "Apache Software License"}
-dependencies = ["pandas<=2.2.3,>=0.25.0", "numpy<=2.2.3,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "polars[pandas]<=1.22.0,>=0.20.4"]
+dependencies = ["pandas<=2.2.3,>=0.25.0", "numpy<=2.2.3,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "polars[pandas]<=1.23.0,>=0.20.4"]
 requires-python = ">=3.10.0"
 classifiers = [
     "Intended Audience :: Developers",
