Skip to content

Commit ebca3c5

Browse files
TST: update expected dtype for sum of decimals with pyarrow 21+ (#61799)
1 parent cf1a11c commit ebca3c5

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

pandas/compat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
pa_version_under18p0,
3636
pa_version_under19p0,
3737
pa_version_under20p0,
38+
pa_version_under21p0,
3839
)
3940

4041
if TYPE_CHECKING:
@@ -168,4 +169,5 @@ def is_ci_environment() -> bool:
168169
"pa_version_under18p0",
169170
"pa_version_under19p0",
170171
"pa_version_under20p0",
172+
"pa_version_under21p0",
171173
]

pandas/compat/pyarrow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
pa_version_under18p0 = _palv < Version("18.0.0")
1919
pa_version_under19p0 = _palv < Version("19.0.0")
2020
pa_version_under20p0 = _palv < Version("20.0.0")
21+
pa_version_under21p0 = _palv < Version("21.0.0")
2122
HAS_PYARROW = _palv >= Version("12.0.1")
2223
except ImportError:
2324
pa_version_under12p1 = True
@@ -30,4 +31,5 @@
3031
pa_version_under18p0 = True
3132
pa_version_under19p0 = True
3233
pa_version_under20p0 = True
34+
pa_version_under21p0 = True
3335
HAS_PYARROW = False

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
pa_version_under14p0,
4444
pa_version_under19p0,
4545
pa_version_under20p0,
46+
pa_version_under21p0,
4647
)
4748

4849
from pandas.core.dtypes.dtypes import (
@@ -542,7 +543,10 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
542543
else:
543544
cmp_dtype = arr.dtype
544545
elif arr.dtype.name == "decimal128(7, 3)[pyarrow]":
545-
if op_name not in ["median", "var", "std", "sem", "skew"]:
546+
if op_name == "sum" and not pa_version_under21p0:
547+
# https://github.com/apache/arrow/pull/44184
548+
cmp_dtype = ArrowDtype(pa.decimal128(38, 3))
549+
elif op_name not in ["median", "var", "std", "sem", "skew"]:
546550
cmp_dtype = arr.dtype
547551
else:
548552
cmp_dtype = "float64[pyarrow]"

0 commit comments

Comments
 (0)