Skip to content

Commit

Permalink
Merge pull request #31 from TomAugspurger/feature/arrow-types
Browse files Browse the repository at this point in the history
Optionally use pyarrow types in to_geodataframe
  • Loading branch information
Tom Augspurger authored Mar 29, 2024
2 parents 3901d33 + 9c60219 commit dfd384a
Show file tree
Hide file tree
Showing 5 changed files with 348 additions and 159 deletions.
153 changes: 120 additions & 33 deletions stac_geoparquet/stac_geoparquet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""
Generate geoparquet from a sequence of STAC items.
"""

from __future__ import annotations
import collections

from typing import Sequence, Any
from typing import Sequence, Any, Literal
import warnings

import pystac
import geopandas
import pandas as pd
import pyarrow as pa
import numpy as np
import shapely.geometry

Expand All @@ -16,7 +20,7 @@
from stac_geoparquet.utils import fix_empty_multipolygon

STAC_ITEM_TYPES = ["application/json", "application/geo+json"]

DTYPE_BACKEND = Literal["numpy_nullable", "pyarrow"]
SELF_LINK_COLUMN = "self_link"


Expand All @@ -31,7 +35,10 @@ def _fix_array(v):


def to_geodataframe(
items: Sequence[dict[str, Any]], add_self_link: bool = False
items: Sequence[dict[str, Any]],
add_self_link: bool = False,
dtype_backend: DTYPE_BACKEND | None = None,
datetime_precision: str = "ns",
) -> geopandas.GeoDataFrame:
"""
Convert a sequence of STAC items to a :class:`geopandas.GeoDataFrame`.
Expand All @@ -42,19 +49,72 @@ def to_geodataframe(
Parameters
----------
items: A sequence of STAC items.
add_self_link: Add the absolute link (if available) to the source STAC Item as a separate column named "self_link"
add_self_link: bool, default False
Add the absolute link (if available) to the source STAC Item
as a separate column named "self_link"
dtype_backend: {'pyarrow', 'numpy_nullable'}, optional
The dtype backend to use for storing arrays.
By default, this will use 'numpy_nullable' and emit a
FutureWarning that the default will change to 'pyarrow' in
the next release.
Set to 'numpy_nullable' to silence the warning and accept the
old behavior.
Set to 'pyarrow' to silence the warning and accept the new behavior.
There are some difference in the output as well: with
``dtype_backend="pyarrow"``, struct-like fields will explicitly
contain null values for fields that appear in only some of the
records. For example, given an ``assets`` like::
{
"a": {
"href": "a.tif",
},
"b": {
"href": "b.tif",
"title": "B",
}
}
The ``assets`` field of the output for the first row with
``dtype_backend="numpy_nullable"`` will be a Python dictionary with
just ``{"href": "a.tiff"}``.
With ``dtype_backend="pyarrow"``, this will be a pyarrow struct
with fields ``{"href": "a.tif", "title", None}``. pyarrow will
infer that the struct field ``asset.title`` is nullable.
datetime_precision: str, default "ns"
The precision to use for the datetime columns. For example,
"us" is microsecond and "ns" is nanosecond.
Returns
-------
The converted GeoDataFrame.
"""
items2 = []
items2 = collections.defaultdict(list)

for item in items:
item2 = {k: v for k, v in item.items() if k != "properties"}
keys = set(item) - {"properties", "geometry"}

for k in keys:
items2[k].append(item[k])

item_geometry = item["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry)

items2["geometry"].append(item_geometry)

for k, v in item["properties"].items():
if k in item2:
raise ValueError("k", k)
item2[k] = v
if k in item:
msg = f"Key '{k}' appears in both 'properties' and the top level."
raise ValueError(msg)
items2[k].append(v)

if add_self_link:
self_href = None
for link in item["links"]:
Expand All @@ -65,23 +125,11 @@ def to_geodataframe(
):
self_href = link["href"]
break
item2[SELF_LINK_COLUMN] = self_href
items2.append(item2)

# Filter out missing geoms in MultiPolygons
# https://github.com/shapely/shapely/issues/1407
# geometry = [shapely.geometry.shape(x["geometry"]) for x in items2]

geometry = []
for item2 in items2:
item_geometry = item2["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry) # type: ignore
geometry.append(item_geometry)

gdf = geopandas.GeoDataFrame(items2, geometry=geometry, crs="WGS84")
items2[SELF_LINK_COLUMN].append(self_href)

for column in [
# TODO: Ideally we wouldn't have to hard-code this list.
# Could we get it from the JSON schema.
DATETIME_COLUMNS = {
"datetime", # common metadata
"start_datetime",
"end_datetime",
Expand All @@ -90,9 +138,43 @@ def to_geodataframe(
"expires", # timestamps extension
"published",
"unpublished",
]:
if column in gdf.columns:
gdf[column] = pd.to_datetime(gdf[column], format="ISO8601")
}

items2["geometry"] = geopandas.array.from_shapely(items2["geometry"])

if dtype_backend is None:
msg = (
"The default argument for 'dtype_backend' will change from "
"'numpy_nullable' to 'pyarrow'. To keep the previous default "
"specify ``dtype_backend='numpy_nullable'``. To accept the future "
"behavior specify ``dtype_backend='pyarrow'."
)
warnings.warn(FutureWarning(msg))
dtype_backend = "numpy_nullable"

if dtype_backend == "pyarrow":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
dt = pd.to_datetime(v, format="ISO8601").as_unit(datetime_precision)
items2[k] = pd.arrays.ArrowExtensionArray(pa.array(dt))

elif k != "geometry":
items2[k] = pd.arrays.ArrowExtensionArray(pa.array(v))

elif dtype_backend == "numpy_nullable":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
items2[k] = pd.to_datetime(v, format="ISO8601").as_unit(
datetime_precision
)

if k in {"type", "stac_version", "id", "collection", SELF_LINK_COLUMN}:
items2[k] = pd.array(v, dtype="string")
else:
msg = f"Invalid 'dtype_backend={dtype_backend}'."
raise TypeError(msg)

gdf = geopandas.GeoDataFrame(items2, geometry="geometry", crs="WGS84")

columns = [
"type",
Expand All @@ -111,10 +193,6 @@ def to_geodataframe(
columns.remove(col)

gdf = pd.concat([gdf[columns], gdf.drop(columns=columns)], axis="columns")
for k in ["type", "stac_version", "id", "collection", SELF_LINK_COLUMN]:
if k in gdf:
gdf[k] = gdf[k].astype("string")

return gdf


Expand Down Expand Up @@ -144,12 +222,16 @@ def to_dict(record: dict) -> dict:

if k == SELF_LINK_COLUMN:
continue
elif k == "assets":
item[k] = {k2: v2 for k2, v2 in v.items() if v2 is not None}
elif k in top_level_keys:
item[k] = v
else:
properties[k] = v

item["geometry"] = shapely.geometry.mapping(item["geometry"])
if item["geometry"]:
item["geometry"] = shapely.geometry.mapping(item["geometry"])

item["properties"] = properties

return item
Expand All @@ -175,6 +257,11 @@ def to_item_collection(df: geopandas.GeoDataFrame) -> pystac.ItemCollection:
include=["datetime64[ns, UTC]", "datetime64[ns]"]
).columns
for k in datelike:
# %f isn't implemented in pyarrow
# https://github.com/apache/arrow/issues/20146
if isinstance(df2[k].dtype, pd.ArrowDtype):
df2[k] = df2[k].astype("datetime64[ns, utc]")

df2[k] = (
df2[k].dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ").fillna("").replace({"": None})
)
Expand Down
44 changes: 36 additions & 8 deletions stac_geoparquet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,27 @@


@functools.singledispatch
def assert_equal(result: Any, expected: Any) -> bool:
def assert_equal(result: Any, expected: Any, ignore_none: bool = False) -> bool:
raise TypeError(f"Invalid type {type(result)}")


@assert_equal.register(pystac.ItemCollection)
def assert_equal_ic(
result: pystac.ItemCollection, expected: pystac.ItemCollection
result: pystac.ItemCollection,
expected: pystac.ItemCollection,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert len(result) == len(expected)
assert result.extra_fields == expected.extra_fields
for a, b in zip(result.items, expected.items):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)


@assert_equal.register(pystac.Item)
def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
def assert_equal_item(
result: pystac.Item, expected: pystac.Item, ignore_none: bool = False
) -> None:
assert type(result) == type(expected)
assert result.id == expected.id
assert shapely.geometry.shape(result.geometry) == shapely.geometry.shape(
Expand All @@ -41,20 +45,44 @@ def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
expected_links = sorted(expected.links, key=lambda x: x.href)
assert len(result_links) == len(expected_links)
for a, b in zip(result_links, expected_links):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)

assert set(result.assets) == set(expected.assets)
for k in result.assets:
assert_equal(result.assets[k], expected.assets[k])
assert_equal(result.assets[k], expected.assets[k], ignore_none=ignore_none)


@assert_equal.register(pystac.Link)
@assert_equal.register(pystac.Asset)
def assert_link_equal(
result: pystac.Link | pystac.Asset, expected: pystac.Link | pystac.Asset
result: pystac.Link | pystac.Asset,
expected: pystac.Link | pystac.Asset,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert result.to_dict() == expected.to_dict()
resultd = result.to_dict()
expectedd = expected.to_dict()

left = {}

if ignore_none:
for k, v in resultd.items():
if v is None and k not in expectedd:
pass
elif isinstance(v, list) and k in expectedd:
out = []
for val in v:
if isinstance(val, dict):
out.append({k: v2 for k, v2 in val.items() if v2 is not None})
else:
out.append(val)
left[k] = out
else:
left[k] = v
else:
left = resultd

assert left == expectedd


def fix_empty_multipolygon(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pgstac_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_naip_item():
expected.remove_links(rel=pystac.RelType.SELF)
result.remove_links(rel=pystac.RelType.SELF)

assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_sentinel2_l2a():
Expand All @@ -139,7 +139,7 @@ def test_sentinel2_l2a():
result.remove_links(rel=pystac.RelType.SELF)

expected.remove_links(rel=pystac.RelType.LICENSE)
assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_generate_endpoints():
Expand Down
Loading

0 comments on commit dfd384a

Please sign in to comment.