Commit fba6a24

columns_equal clean up (#396)
* refactor columns_equal, fixes #121
* refactor polars columns_equal, support for temporal
* remove DATE_TYPE const
1 parent eeffb0e commit fba6a24
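
The fix for #121 is easiest to see with the data from the new test_string_as_numeric test: numeric-looking strings that exceed float precision used to be compared through a float cast and came back equal. A minimal illustrative sketch, assuming datacompy with this commit installed (not part of the diff itself):

import datacompy
import pandas as pd

# Two 34-digit identifiers that differ only in their last digit (taken from
# the new test_string_as_numeric). Cast to float they round to the same value,
# which is why the old float-based fallback reported them as equal.
s1 = pd.Series(["9998700990704001708177961516923014"])
s2 = pd.Series(["9998700990704001708177961516923015"])
print(float(s1[0]) == float(s2[0]))  # True: float precision drops the trailing digits

# With this change, string columns are compared as strings, so they differ.
print(datacompy.columns_equal(s1, s2).tolist())  # expected: [False]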

4 files changed, +165 -85 lines changed


datacompy/core.py

Lines changed: 32 additions & 35 deletions
@@ -859,53 +859,50 @@ def columns_equal(
     """
     default_value = "DATACOMPY_NULL"
     compare: pd.Series[bool]
+    if ignore_spaces:
+        if col_1.dtype.kind == "O" and pd.api.types.is_string_dtype(col_1):
+            col_1 = col_1.str.strip()
+        if col_2.dtype.kind == "O" and pd.api.types.is_string_dtype(col_2):
+            col_2 = col_2.str.strip()
+    if ignore_case:
+        if col_1.dtype.kind == "O" and pd.api.types.is_string_dtype(col_1):
+            col_1 = col_1.str.upper()
+        if col_2.dtype.kind == "O" and pd.api.types.is_string_dtype(col_2):
+            col_2 = col_2.str.upper()

     # short circuit if comparing mixed type columns. We don't want to support this moving forward.
     if pd.api.types.infer_dtype(col_1).startswith("mixed") or pd.api.types.infer_dtype(
         col_2
     ).startswith("mixed"):
         compare = pd.Series(False, index=col_1.index)
-        compare.index = col_1.index
-        return compare
-
-    try:
-        compare = pd.Series(
-            np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
-        )
-    except TypeError:
+    elif pd.api.types.is_string_dtype(col_1) and pd.api.types.is_string_dtype(col_2):
         try:
             compare = pd.Series(
-                np.isclose(
-                    col_1.astype(float),
-                    col_2.astype(float),
-                    rtol=rel_tol,
-                    atol=abs_tol,
-                    equal_nan=True,
-                )
+                (col_1.fillna(default_value) == col_2.fillna(default_value))
+                | (col_1.isnull() & col_2.isnull())
+            )
+        except TypeError:
+            compare = pd.Series(col_1.astype(str) == col_2.astype(str))
+    elif {col_1.dtype.kind, col_2.dtype.kind} == {"M", "O"}:
+        compare = compare_string_and_date_columns(col_1, col_2)
+    else:
+        try:
+            compare = pd.Series(
+                np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
             )
-        except (ValueError, TypeError):
+        except TypeError:
             try:
-                if ignore_spaces:
-                    if col_1.dtype.kind == "O" and pd.api.types.is_string_dtype(col_1):
-                        col_1 = col_1.str.strip()
-                    if col_2.dtype.kind == "O" and pd.api.types.is_string_dtype(col_2):
-                        col_2 = col_2.str.strip()
-
-                if ignore_case:
-                    if col_1.dtype.kind == "O" and pd.api.types.is_string_dtype(col_1):
-                        col_1 = col_1.str.upper()
-                    if col_2.dtype.kind == "O" and pd.api.types.is_string_dtype(col_2):
-                        col_2 = col_2.str.upper()
-
-                if {col_1.dtype.kind, col_2.dtype.kind} == {"M", "O"}:
-                    compare = compare_string_and_date_columns(col_1, col_2)
-                else:
-                    compare = pd.Series(
-                        (col_1.fillna(default_value) == col_2.fillna(default_value))
-                        | (col_1.isnull() & col_2.isnull())
+                compare = pd.Series(
+                    np.isclose(
+                        col_1.astype(float),
+                        col_2.astype(float),
+                        rtol=rel_tol,
+                        atol=abs_tol,
+                        equal_nan=True,
                     )
+                )
             except Exception:
-                try:
+                try:  # last check where we just cast to strings
                     compare = pd.Series(col_1.astype(str) == col_2.astype(str))
                 except Exception:  # Blanket exception should just return all False
                     compare = pd.Series(False, index=col_1.index)
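
After this refactor the pandas columns_equal normalizes spaces and case up front, then dispatches on dtype: a string/string pair uses null-aware equality, a datetime/object pair goes through compare_string_and_date_columns, and anything else falls back to np.isclose, then a float cast, then a string cast. An illustrative sketch of the resulting behavior, assuming this commit is installed (the expected values are a reading of the new logic, not captured output):

import datacompy
import pandas as pd

# String vs. string: trimmed and upper-cased first, nulls compare equal to nulls.
strings = pd.Series([" a", "B ", None])
other = pd.Series(["a", "b", None])
result = datacompy.columns_equal(strings, other, ignore_spaces=True, ignore_case=True)
print(result.tolist())  # expected: [True, True, True]

# Datetime vs. string (dtype kinds {"M", "O"}): routed to compare_string_and_date_columns.
dates = pd.to_datetime(pd.Series(["2017-01-01", "2017-01-02"]))
date_strings = pd.Series(["2017-01-01", "2017-01-03"])
print(datacompy.columns_equal(dates, date_strings).tolist())  # expected: [True, False]

# Numeric vs. numeric: falls through to np.isclose with rel_tol/abs_tol.
nums = pd.Series([1.0, 2.0])
near = pd.Series([1.0, 2.05])
print(datacompy.columns_equal(nums, near, abs_tol=0.1).tolist())  # expected: [True, True]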

datacompy/polars.py

Lines changed: 43 additions & 48 deletions
@@ -29,14 +29,12 @@
 import numpy as np
 import polars as pl
 from ordered_set import OrderedSet
-from polars.exceptions import ComputeError, InvalidOperationError

 from datacompy.base import BaseCompare, temp_column_name

 LOG = logging.getLogger(__name__)

 STRING_TYPE = ["String", "Utf8"]
-DATE_TYPE = ["Date", "Datetime"]


 class PolarsCompare(BaseCompare):
@@ -799,13 +797,13 @@ def render(filename: str, *fields: int | float | str) -> str:


 def columns_equal(
-    col_1: "pl.Series",
-    col_2: "pl.Series",
+    col_1: pl.Series,
+    col_2: pl.Series,
     rel_tol: float = 0,
     abs_tol: float = 0,
     ignore_spaces: bool = False,
     ignore_case: bool = False,
-) -> "pl.Series":
+) -> pl.Series:
     """Compare two columns from a dataframe.

     Returns a True/False series,
@@ -841,57 +839,54 @@ def columns_equal(
         values don't match.
     """
     compare: pl.Series
-    try:
+
+    if ignore_spaces:
+        if str(col_1.dtype) in STRING_TYPE:
+            col_1 = col_1.str.strip_chars()
+        if str(col_2.dtype) in STRING_TYPE:
+            col_2 = col_2.str.strip_chars()
+
+    if ignore_case:
+        if str(col_1.dtype) in STRING_TYPE:
+            col_1 = col_1.str.to_uppercase()
+        if str(col_2.dtype) in STRING_TYPE:
+            col_2 = col_2.str.to_uppercase()
+
+    if (str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in STRING_TYPE) or (
+        col_1.dtype.is_temporal() and col_2.dtype.is_temporal()
+    ):
         compare = pl.Series(
-            np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
+            (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())
         )
-    except TypeError:
+    elif (str(col_1.dtype) in STRING_TYPE and str(col_2.dtype).startswith("Date")) or (
+        str(col_1.dtype).startswith("Date") and str(col_2.dtype) in STRING_TYPE
+    ):
+        compare = compare_string_and_date_columns(col_1, col_2)
+    else:
         try:
-            if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:
-                raise TypeError("Found date, moving to alternative logic")
-
             compare = pl.Series(
-                np.isclose(
-                    col_1.cast(pl.Float64, strict=True),
-                    col_2.cast(pl.Float64, strict=True),
-                    rtol=rel_tol,
-                    atol=abs_tol,
-                    equal_nan=True,
-                )
+                np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
             )
-        except (ValueError, TypeError, InvalidOperationError, ComputeError):
+        except TypeError:
             try:
-                if ignore_spaces:
-                    if str(col_1.dtype) in STRING_TYPE:
-                        col_1 = col_1.str.strip_chars()
-                    if str(col_2.dtype) in STRING_TYPE:
-                        col_2 = col_2.str.strip_chars()
-
-                if ignore_case:
-                    if str(col_1.dtype) in STRING_TYPE:
-                        col_1 = col_1.str.to_uppercase()
-                    if str(col_2.dtype) in STRING_TYPE:
-                        col_2 = col_2.str.to_uppercase()
-
-                if (
-                    str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE
-                ) or (
-                    str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE
-                ):
-                    compare = compare_string_and_date_columns(col_1, col_2)
-                else:
-                    compare = pl.Series(
-                        (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())
+                compare = pl.Series(
+                    np.isclose(
+                        col_1.cast(pl.Float64, strict=True),
+                        col_2.cast(pl.Float64, strict=True),
+                        rtol=rel_tol,
+                        atol=abs_tol,
+                        equal_nan=True,
                     )
+                )
             except Exception:
-                # Blanket exception should just return all False
-                compare = pl.Series(False * col_1.shape[0])
+                try:  # last check where we just cast to strings
+                    compare = pl.Series(col_1.cast(pl.String) == col_2.cast(pl.String))
+                except Exception:  # Blanket exception should just return all False
+                    compare = pl.Series(False * col_1.shape[0])
     return compare


-def compare_string_and_date_columns(
-    col_1: "pl.Series", col_2: "pl.Series"
-) -> "pl.Series":
+def compare_string_and_date_columns(col_1: pl.Series, col_2: pl.Series) -> pl.Series:
     """Compare a string column and date column, value-wise.

     This tries to
@@ -919,7 +914,7 @@ def compare_string_and_date_columns(

     try:  # datetime is inferred
         return pl.Series(
-            (str_column.str.to_datetime().eq_missing(date_column))
+            (str_column.str.to_datetime(strict=False).eq_missing(date_column))
             | (str_column.is_null() & date_column.is_null())
         )
     except Exception:
@@ -952,7 +947,7 @@ def get_merged_columns(
     return columns


-def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
+def calculate_max_diff(col_1: pl.Series, col_2: pl.Series) -> float:
     """Get a maximum difference between two columns.

     Parameters
@@ -977,7 +972,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:

 def generate_id_within_group(
     dataframe: pl.DataFrame, join_columns: List[str]
-) -> "pl.Series":
+) -> pl.Series:
     """Generate an ID column that can be used to deduplicate identical rows.

     The series generated
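
With the DATE_TYPE constant gone, two temporal polars Series (Date or Datetime) are now compared with a null-aware eq_missing, and a Date/String pair is routed to compare_string_and_date_columns, as exercised by the updated test_bad_date_columns. An illustrative sketch, assuming this commit and the import path datacompy.polars:

import polars as pl

from datacompy.polars import columns_equal

# Two Date columns: both are temporal, so the eq_missing branch handles nulls directly.
col_a = pl.Series("a", ["2017-01-01", "2017-01-02", None]).str.to_date(strict=False)
col_b = pl.Series("b", ["2017-01-01", "2017-01-03", None]).str.to_date(strict=False)
print(columns_equal(col_a, col_b).to_list())  # expected: [True, False, True]

# Date vs. raw strings: routed through compare_string_and_date_columns.
raw = pl.Series("b", ["2017-01-01", "2017-01-03", None])
print(columns_equal(col_a, raw).to_list())  # expected: [True, False, True]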

tests/test_core.py

Lines changed: 27 additions & 0 deletions
@@ -1497,3 +1497,30 @@ def test_non_full_join_counts_some_matches():
             ]
         ),
     )
+
+
+def test_string_as_numeric():
+    df1 = pd.DataFrame({"ID": [1], "REFER_NR": ["9998700990704001708177961516923014"]})
+    df2 = pd.DataFrame({"ID": [1], "REFER_NR": ["9998700990704001708177961516923015"]})
+    actual_out = datacompy.columns_equal(df1.REFER_NR, df2.REFER_NR)
+    assert not actual_out.all()
+
+
+def test_single_date_columns_equal_to_string():
+    data = """a|b|expected
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||False"""
+    df = pd.read_csv(io.StringIO(data), sep="|", keep_default_na=False)
+
+    try:
+        df["a"] = pd.to_datetime(df["a"], format="mixed")
+    except ValueError:
+        df["a"] = pd.to_datetime(df["a"])
+
+    actual_out = datacompy.columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True)
+    expect_out = df["expected"]
+    assert_series_equal(expect_out, actual_out, check_names=False)

tests/test_polars.py

Lines changed: 63 additions & 2 deletions
@@ -259,8 +259,13 @@ def test_bad_date_columns():
     df = pl.DataFrame(
         [{"a": "2017-01-01", "b": "2017-01-01"}, {"a": "2017-01-01", "b": "2A17-01-01"}]
     )
-    df = df.with_columns(df["a"].str.to_date(exact=True).alias("a_dt"))
-    assert not columns_equal(df["a_dt"], df["b"]).any()
+    col_a = df["a"].str.to_date()
+    col_b = df["b"]
+    assert columns_equal(col_a, col_b).to_list() == [True, False]
+
+    col_a = df["a"]
+    col_b = df["b"].str.to_date(strict=False)
+    assert columns_equal(col_a, col_b).to_list() == [True, False]


 def test_rounded_date_columns():
@@ -1457,3 +1462,59 @@ def test_categorical_column():
     compare = PolarsCompare(df, df, join_columns=["idx"])
     assert compare.intersect_rows["foo_match"].all()
     assert compare.intersect_rows["bar_match"].all()
+
+
+def test_string_as_numeric():
+    df1 = pl.DataFrame({"ID": [1], "REFER_NR": ["9998700990704001708177961516923014"]})
+    df2 = pl.DataFrame({"ID": [1], "REFER_NR": ["9998700990704001708177961516923015"]})
+    actual_out = columns_equal(df1["REFER_NR"], df2["REFER_NR"])
+    assert not actual_out.all()
+
+
+def test_single_date_columns_equal_to_string():
+    data = """a|b|expected
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+    df = pl.read_csv(
+        io.StringIO(data),
+        separator="|",
+        null_values=["NULL"],
+        missing_utf8_is_empty_string=True,
+    )
+    col_a = df["a"].str.strip_chars().str.to_date(strict=False)
+    col_b = df["b"]
+
+    actual_out = columns_equal(col_a, col_b, rel_tol=0.2, ignore_spaces=True)
+    expect_out = df["expected"]
+    assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_temporal_equal():
+    data = """a|b|expected
+2017-01-01|2017-01-01|True
+2017-01-02|2017-01-02|True
+2017-10-01|2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+    df = pl.read_csv(
+        io.StringIO(data),
+        separator="|",
+        null_values=["NULL"],
+        missing_utf8_is_empty_string=True,
+    )
+    expect_out = df["expected"]
+
+    col_a = df["a"].str.to_date(strict=False)
+    col_b = df["b"].str.to_date(strict=False)
+    actual_out = columns_equal(col_a, col_b)
+    assert_series_equal(expect_out, actual_out, check_names=False)
+
+    col_a = df["a"].str.to_datetime(strict=False)
+    col_b = df["b"].str.to_datetime(strict=False)
+    actual_out = columns_equal(col_a, col_b)
+    assert_series_equal(expect_out, actual_out, check_names=False)
