-
-
Notifications
You must be signed in to change notification settings - Fork 18.7k
BUG: groupby.agg with UDF changing pyarrow dtypes #59601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9faa460
969d5b1
66114f3
b0290ed
20c8fa0
97b3d54
932d737
82ddeb5
d510052
62a31d9
a54bf58
64330f0
0647711
affde38
842f561
93b5bf3
6f35c0e
abd0adf
bebc442
bb6343b
3a3f2a2
6dc40f5
4ef96f7
c6a98c0
9181eaf
612d7d0
3b6696b
680e238
3a8597e
6496b15
712c36a
e1ccef6
0ce083d
a1d73f5
57845a8
fa257b0
0a9b83f
139319a
9c2f9f2
fef315d
f758eb1
283eda9
d6edeff
b2e34fb
9cbf339
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,7 @@ | |
) | ||
|
||
from pandas.core.arrays import Categorical | ||
from pandas.core.arrays.arrow.array import ArrowExtensionArray | ||
from pandas.core.frame import DataFrame | ||
from pandas.core.groupby import grouper | ||
from pandas.core.indexes.api import ( | ||
|
@@ -954,18 +955,28 @@ def agg_series( | |
------- | ||
np.ndarray or ExtensionArray | ||
""" | ||
result = self._aggregate_series_pure_python(obj, func) | ||
npvalues = lib.maybe_convert_objects(result, try_float=False) | ||
|
||
if isinstance(obj._values, ArrowExtensionArray): | ||
from pandas.core.dtypes.common import is_string_dtype | ||
|
||
if not isinstance(obj._values, np.ndarray): | ||
# When obj.dtype is a string, any object can be cast. Only do so if the | ||
# UDF returned strings or NA values. | ||
if not is_string_dtype(obj.dtype) or is_string_dtype( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: I suspect what you really want here is `lib.is_string_array`? |
||
npvalues[~isna(npvalues)] | ||
): | ||
out = maybe_cast_pointwise_result( | ||
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype | ||
) | ||
else: | ||
out = npvalues | ||
|
||
elif not isinstance(obj._values, np.ndarray): | ||
# we can preserve a little bit more aggressively with EA dtype | ||
# because maybe_cast_pointwise_result will do a try/except | ||
# with _from_sequence. NB we are assuming here that _from_sequence | ||
# is sufficiently strict that it casts appropriately. | ||
preserve_dtype = True | ||
|
||
result = self._aggregate_series_pure_python(obj, func) | ||
|
||
npvalues = lib.maybe_convert_objects(result, try_float=False) | ||
if preserve_dtype: | ||
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True) | ||
else: | ||
out = npvalues | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import pytest | ||
|
||
from pandas.errors import SpecificationError | ||
import pandas.util._test_decorators as td | ||
|
||
from pandas.core.dtypes.common import is_integer_dtype | ||
|
||
|
@@ -23,6 +24,7 @@ | |
to_datetime, | ||
) | ||
import pandas._testing as tm | ||
from pandas.arrays import ArrowExtensionArray | ||
from pandas.core.groupby.grouper import Grouping | ||
|
||
|
||
|
@@ -1807,3 +1809,99 @@ def test_groupby_aggregation_func_list_multi_index_duplicate_columns(): | |
index=Index(["level1.1", "level1.2"]), | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
@td.skip_if_no("pyarrow") | ||
@pytest.mark.parametrize( | ||
"input_dtype, output_dtype", | ||
[ | ||
# With NumPy arrays, the results from the UDF would be e.g. np.float32 scalars | ||
# which we can therefore preserve. However with PyArrow arrays, the results are | ||
# Python scalars so we have no information about size or uint vs int. | ||
("float[pyarrow]", "double[pyarrow]"), | ||
("int64[pyarrow]", "int64[pyarrow]"), | ||
("uint64[pyarrow]", "int64[pyarrow]"), | ||
("bool[pyarrow]", "bool[pyarrow]"), | ||
], | ||
) | ||
def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype): | ||
# GH#59601 | ||
# Test PyArrow dtype conversion back to PyArrow dtype | ||
df = DataFrame( | ||
{ | ||
"A": ["c1", "c2", "c3", "c1", "c2", "c3"], | ||
"B": pd.array([100, 200, 255, 0, 199, 40392], dtype=input_dtype), | ||
} | ||
) | ||
gb = df.groupby("A") | ||
result = gb.agg(lambda x: x.min()) | ||
|
||
expected = DataFrame( | ||
{"B": pd.array([0, 199, 255], dtype=output_dtype)}, | ||
index=Index(["c1", "c2", "c3"], name="A"), | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
@td.skip_if_no("pyarrow") | ||
def test_agg_lambda_complex128_dtype_conversion(): | ||
# GH#59601 | ||
df = DataFrame( | ||
{"A": ["c1", "c2", "c3"], "B": pd.array([100, 200, 255], "int64[pyarrow]")} | ||
) | ||
gb = df.groupby("A") | ||
result = gb.agg(lambda x: complex(x.sum(), x.count())) | ||
|
||
expected = DataFrame( | ||
{ | ||
"B": pd.array( | ||
[complex(100, 1), complex(200, 1), complex(255, 1)], dtype="complex128" | ||
), | ||
}, | ||
index=Index(["c1", "c2", "c3"], name="A"), | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
@td.skip_if_no("pyarrow") | ||
def test_agg_lambda_numpy_uint64_to_pyarrow_dtype_conversion(): | ||
# GH#59601 | ||
df = DataFrame( | ||
{ | ||
"A": ["c1", "c2", "c3"], | ||
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"), | ||
} | ||
) | ||
gb = df.groupby("A") | ||
result = gb.agg(lambda x: np.uint64(x.sum())) | ||
|
||
expected = DataFrame( | ||
{ | ||
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"), | ||
}, | ||
index=Index(["c1", "c2", "c3"], name="A"), | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
@td.skip_if_no("pyarrow") | ||
def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion(): | ||
# GH#59601 | ||
import pyarrow as pa | ||
|
||
df = DataFrame( | ||
{ | ||
"A": ["c1", "c2", "c3"], | ||
"B": pd.array([100, 200, 255], dtype="int64[pyarrow]"), | ||
} | ||
) | ||
gb = df.groupby("A") | ||
result = gb.agg(lambda x: {"number": 1}) | ||
|
||
arr = pa.array([{"number": 1}, {"number": 1}, {"number": 1}]) | ||
expected = DataFrame( | ||
{"B": ArrowExtensionArray(arr)}, | ||
index=Index(["c1", "c2", "c3"], name="A"), | ||
) | ||
Comment on lines
+1899
to
+1905
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: When the column starts as a PyArrow dtype and the UDF returns dictionaries, it seems questionable to me whether we should return the corresponding PyArrow dtype. The other option is a NumPy array of object dtype. But both seem like reasonable results, and I imagine the PyArrow result is likely to be more convenient for the user who is using PyArrow dtypes. |
||
|
||
tm.assert_frame_equal(result, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: Can this import go at the top?