diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index 8053c17437c5e..6070852a9ba83 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -10,6 +10,7 @@
 """
 from __future__ import annotations
 
+from pandas.core.dtypes.common import is_datetime64_dtype
 import collections
 from collections import abc
 from collections.abc import (
@@ -9955,20 +9956,111 @@ def explode(
             columns = column
         else:
             raise ValueError("column must be a scalar, tuple, or list thereof")
 
         df = self.reset_index(drop=True)
         if len(columns) == 1:
-            result = df[columns[0]].explode()
+            col = columns[0]
+            orig_dtype = df[col].dtype
+
+            # Explode positionally, tracking the source row of each element,
+            # so the original dtype (including any datetime64 unit) can be
+            # restored on the result.
+            exploded_values = []
+            exploded_index = []
+
+            for i, val in enumerate(df[col]):
+                if is_list_like(val) and not isinstance(val, (str, bytes)):
+                    if len(val) > 0:
+                        for item in val:
+                            exploded_values.append(item)
+                            exploded_index.append(i)
+                    else:
+                        # Empty list-likes explode to a single missing value,
+                        # matching Series.explode semantics.
+                        exploded_values.append(
+                            np.datetime64("NaT")
+                            if is_datetime64_dtype(orig_dtype)
+                            else np.nan
+                        )
+                        exploded_index.append(i)
+                elif isna(val):
+                    exploded_values.append(
+                        np.datetime64("NaT")
+                        if is_datetime64_dtype(orig_dtype)
+                        else np.nan
+                    )
+                    exploded_index.append(i)
+                else:
+                    exploded_values.append(val)
+                    exploded_index.append(i)
+
+            exploded_series = Series(
+                np.array(
+                    exploded_values,
+                    dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None,
+                ),
+                index=exploded_index,
+                name=col,
+            )
+
+            result = df.drop(columns, axis=1).iloc[exploded_series.index].copy()
+            result[col] = exploded_series.values
         else:
             mylen = lambda x: len(x) if (is_list_like(x) and len(x) > 0) else 1
             counts0 = self[columns[0]].apply(mylen)
             for c in columns[1:]:
                 if not all(counts0 == self[c].apply(mylen)):
                     raise ValueError("columns must have matching element counts")
-            result = DataFrame({c: df[c].explode() for c in columns})
-            result = df.drop(columns, axis=1).join(result)
+
+            # Replicate each source row index once per exploded element; the
+            # element counts were validated above to match across columns.
+            exploded_columns = {}
+            exploded_index = []
+
+            for i in range(len(df)):
+                row_count = mylen(df[columns[0]].iloc[i])
+                for _ in range(row_count):
+                    exploded_index.append(i)
+
+            for col in columns:
+                orig_dtype = df[col].dtype
+                values = []
+                for val in df[col]:
+                    if is_list_like(val) and not isinstance(val, (str, bytes)):
+                        if len(val) > 0:
+                            values.extend(val)
+                        else:
+                            # Empty list-likes occupy a single missing slot so
+                            # values stays aligned with exploded_index.
+                            values.append(
+                                np.datetime64("NaT")
+                                if is_datetime64_dtype(orig_dtype)
+                                else np.nan
+                            )
+                    elif isna(val):
+                        values.append(
+                            np.datetime64("NaT")
+                            if is_datetime64_dtype(orig_dtype)
+                            else np.nan
+                        )
+                    else:
+                        values.append(val)
+                exploded_columns[col] = Series(
+                    np.array(
+                        values,
+                        dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None,
+                    ),
+                    index=exploded_index,
+                )
+
+            result = df.drop(columns, axis=1).iloc[exploded_index].copy()
+            for col in columns:
+                result[col] = exploded_columns[col].values
+
+        # Restore the requested index and the original column order.
         if ignore_index:
             result.index = default_index(len(result))
         else:
             result.index = self.index.take(result.index)
+        result = result.reindex(columns=self.columns)
         return result.__finalize__(self, method="explode")
diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py
index e4ad2493f9bb9..3d4716f74e251 100644
--- a/pandas/tests/series/methods/test_explode.py
+++ b/pandas/tests/series/methods/test_explode.py
@@ -175,3 +175,41 @@ def test_explode_pyarrow_non_list_type(ignore_index):
     result = ser.explode(ignore_index=ignore_index)
     expected = pd.Series([1, 2, 3], dtype="int64[pyarrow]", index=[0, 1, 2])
     tm.assert_series_equal(result, expected)
+
+
+def test_explode_preserves_datetime_unit():
+    # Create a datetime64[ms] array manually
+    dt64_ms = np.array(
+        [
+            "2020-01-01T00:00:00.000",
+            "2020-01-01T01:00:00.000",
+            "2020-01-01T02:00:00.000",
+        ],
+        dtype="datetime64[ms]",
+    )
+    s = pd.Series([dt64_ms])
+
+    # Explode the Series
+    result = s.explode()
+
+    # Ensure the dtype (including the unit) is preserved
+    assert result.dtype == dt64_ms.dtype, (
+        f"Expected dtype {dt64_ms.dtype}, got {result.dtype}"
+    )
+
+
+def test_single_column_explode_preserves_datetime_unit():
+    # Use a freq expressed in ms to match unit="ms" (3600000ms == 1 hour)
+    rng = pd.date_range("2020-01-01T00:00:00Z", periods=3, freq="3600000ms", unit="ms")
+    s = pd.Series([rng])
+    result = s.explode()
+    assert result.dtype == rng.dtype
+
+
+def test_multi_column_explode_preserves_datetime_unit():
+    rng1 = pd.date_range("2020-01-01", periods=2, freq="3600000ms", unit="ms")
+    rng2 = pd.date_range("2020-01-01", periods=2, freq="3600000ms", unit="ms")
+    df = pd.DataFrame({"A": [rng1], "B": [rng2]})
+    result = df.explode(["A", "B"])
+    assert result["A"].dtype == rng1.dtype
+    assert result["B"].dtype == rng2.dtype
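
A minimal usage sketch of the behavior this patch targets, mirroring
test_multi_column_explode_preserves_datetime_unit above. The expected dtype in
the final comment assumes the patched DataFrame.explode is in effect; the
snippet is illustrative only and not part of the patch.

    import pandas as pd

    # Two columns whose cells hold non-nanosecond (ms-unit) datetime arrays.
    rng = pd.date_range("2020-01-01", periods=2, freq="3600000ms", unit="ms")
    df = pd.DataFrame({"A": [rng], "B": [rng]})

    result = df.explode(["A", "B"])
    print(result["A"].dtype)  # with this patch: datetime64[ms], unit preserved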