From b8b6c79b98f982dfc16322a4f0d29ede33655a3b Mon Sep 17 00:00:00 2001
From: Tom Augspurger
Date: Sun, 19 Jan 2025 14:09:07 -0700
Subject: [PATCH] Fixed dropping the geometry column (#322)

---
 CHANGELOG.md                      |  5 +++++
 dask_geopandas/_expr.py           | 24 ++++++++++++++++++++++++
 dask_geopandas/expr.py            | 20 ++++++++++++++++++++
 dask_geopandas/tests/test_core.py | 15 +++++++++++++++
 pyproject.toml                    |  2 +-
 5 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 dask_geopandas/_expr.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4f8e537..c5d81d1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,11 @@ Packaging:
 - `dask>=2025.1.0` is now required.
 - `python>=3.10` is now required.
 
+Bug fixes:
+
+- Fixed `GeoDataFrame.drop` returning a `GeoDataFrame`
+  instead of a `DataFrame`, when dropping the geometry
+  column (#321).
 
 Version 0.4.2 (September 24, 2024)
 ----------------------------------
diff --git a/dask_geopandas/_expr.py b/dask_geopandas/_expr.py
new file mode 100644
index 0000000..7c8989a
--- /dev/null
+++ b/dask_geopandas/_expr.py
@@ -0,0 +1,24 @@
+from typing import Literal
+
+import dask.dataframe.dask_expr as dx
+
+import geopandas
+
+
+def _drop(df: geopandas.GeoDataFrame, columns, errors):
+    return df.drop(columns=columns, errors=errors)
+
+
+def _validate_axis(axis=0, none_is_zero: bool = True) -> None | Literal[0, 1]:
+    if axis not in (0, 1, "index", "columns", None):
+        raise ValueError(f"No axis named {axis}")
+    # convert to numeric axis
+    numeric_axis: dict[str | None, Literal[0, 1]] = {"index": 0, "columns": 1}
+    if none_is_zero:
+        numeric_axis[None] = 0
+
+    return numeric_axis.get(axis, axis)
+
+
+class Drop(dx.expr.Drop):
+    operation = staticmethod(_drop)
diff --git a/dask_geopandas/expr.py b/dask_geopandas/expr.py
index d148963..1a1f3d3 100644
--- a/dask_geopandas/expr.py
+++ b/dask_geopandas/expr.py
@@ -26,6 +26,7 @@
 
 import dask_geopandas
 
+from ._expr import Drop, _validate_axis
 from .geohash import _geohash
 from .hilbert_distance import _hilbert_distance
 from .morton_distance import _morton_distance
@@ -868,6 +869,25 @@ def explode(self, column=None, ignore_index=False, index_parts=None):
             enforce_metadata=False,
         )
 
+    @derived_from(geopandas.GeoDataFrame)
+    def drop(self, labels=None, axis=0, columns=None, errors="raise"):
+        # https://github.com/geopandas/dask-geopandas/issues/321
+        # Override to avoid an inplace drop, since we need
+        # to convert from a GeoDataFrame to a DataFrame when dropping
+        # the geometry column.
+        if columns is None and labels is None:
+            raise TypeError("must either specify 'columns' or 'labels'")
+
+        axis = _validate_axis(axis)
+
+        if axis == 1:
+            columns = labels or columns
+        elif axis == 0 and columns is None:
+            raise NotImplementedError(
+                "Drop currently only works for axis=1 or when columns is not None"
+            )
+        return new_collection(Drop(self, columns=columns, errors=errors))
+
 
 from_geopandas = dd.from_pandas
 
diff --git a/dask_geopandas/tests/test_core.py b/dask_geopandas/tests/test_core.py
index 3031261..678f869 100644
--- a/dask_geopandas/tests/test_core.py
+++ b/dask_geopandas/tests/test_core.py
@@ -1034,6 +1034,21 @@ def get_chunk(n):
     assert_geodataframe_equal(ddf.compute(), expected)
 
 
+def test_drop():
+    # https://github.com/geopandas/dask-geopandas/issues/321
+    df = dask_geopandas.from_geopandas(
+        geopandas.GeoDataFrame({"col": [1], "geometry": [Point(1, 1)]}), npartitions=1
+    )
+    result = df.drop(columns="geometry")
+    assert type(result) is dd.DataFrame
+
+    result = df.drop(columns="col")
+    assert type(result) is dask_geopandas.GeoDataFrame
+
+    with pytest.raises(ValueError, match="No axis named x"):
+        df.drop(labels="a", axis="x")
+
+
 def test_core_deprecated():
     with pytest.warns(FutureWarning, match="dask_geopandas.core"):
         import dask_geopandas.core  # noqa: F401
diff --git a/pyproject.toml b/pyproject.toml
index c5190d0..4d56413 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -155,6 +155,6 @@ section-order = [
 ]
 
 [tool.ruff.lint.isort.sections]
-"dask" = ["dask", "dask_expr"]
+"dask" = ["dask"]
 "geo" = ["geopandas", "shapely", "pyproj"]
 "testing" = ["pytest", "pandas.testing", "numpy.testing", "geopandas.tests", "geopandas.testing"]