Skip to content

Commit

Permalink
Misc. typing fixes (#84)
Browse files Browse the repository at this point in the history
* allow mypy scanning more code in tests

This is where we check for type issues.

* add "# type: ignore" when testing invalid types

* fix all typing errors

* fix ArrowArrayHolder.__arrow_c_array__ signature

* dry

* Test CI mypy = 1.x

I have no error locally.
  • Loading branch information
benbovy authored Dec 9, 2024
1 parent 17c355e commit ad27fbe
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 103 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install mypy
run: |
python -m pip install 'mypy<0.990'
python -m pip install 'mypy'
- name: Run mypy
run: |
Expand Down
13 changes: 9 additions & 4 deletions src/generate_spherely_vfunc_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import string
from pathlib import Path

from spherely import EARTH_RADIUS_METERS


VFUNC_TYPE_SPECS = {
"_VFunc_Nin1_Nout1": {"n_in": 1},
"_VFunc_Nin2_Nout1": {"n_in": 2},
"_VFunc_Nin2optradius_Nout1": {"n_in": 2, "radius": "float"},
"_VFunc_Nin1optradius_Nout1": {"n_in": 1, "radius": "float"},
"_VFunc_Nin2optradius_Nout1": {"n_in": 2, "radius": ("float", EARTH_RADIUS_METERS)},
"_VFunc_Nin1optradius_Nout1": {"n_in": 1, "radius": ("float", EARTH_RADIUS_METERS)},
"_VFunc_Nin1optprecision_Nout1": {"n_in": 1, "precision": ("int", 6)},
}

STUB_FILE_PATH = Path(__file__).parent / "spherely.pyi"
Expand Down Expand Up @@ -51,10 +55,11 @@ def _vfunctype_factory(class_name, n_in, **optargs):
"",
]
optarg_str = ", ".join(
f"{arg_name}: {arg_type} = ..." for arg_name, arg_type in optargs.items()
f"{arg_name}: {arg_type} = {arg_value}"
for arg_name, (arg_type, arg_value) in optargs.items()
)

geog_types = ["Geography", "npt.ArrayLike"]
geog_types = ["Geography", "Iterable[Geography]"]
for arg_types in itertools.product(geog_types, repeat=n_in):
arg_str = ", ".join(
f"{arg_name}: {arg_type}"
Expand Down
91 changes: 63 additions & 28 deletions src/spherely.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ from typing import (
Literal,
Protocol,
Sequence,
Tuple,
TypeVar,
overload,
)
Expand Down Expand Up @@ -66,7 +65,7 @@ class Projection:
@staticmethod
def lnglat() -> Projection: ...
@staticmethod
def speudo_mercator() -> Projection: ...
def pseudo_mercator() -> Projection: ...
@staticmethod
def orthographic(longitude: float, latitude: float) -> Projection: ...

Expand All @@ -76,6 +75,11 @@ _NameType = TypeVar("_NameType", bound=str)
_ScalarReturnType = TypeVar("_ScalarReturnType", bound=Any)
_ArrayReturnDType = TypeVar("_ArrayReturnDType", bound=Any)

# TODO: npt.NDArray[Geography] not supported yet
# (see https://github.com/numpy/numpy/issues/24738)
# (unless Geography is passed via Generic[...], see VFunc below)
T_NDArray_Geography = npt.NDArray[Any]

# The following types are auto-generated. Please don't edit them by hand.
# Instead, update the generate_spherely_vfunc_types.py script and run it
# to update the types.
Expand All @@ -87,7 +91,9 @@ class _VFunc_Nin1_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
@overload
def __call__(self, geography: Geography) -> _ScalarReturnType: ...
@overload
def __call__(self, geography: npt.ArrayLike) -> npt.NDArray[_ArrayReturnDType]: ...
def __call__(
self, geography: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin2_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]):
@property
Expand All @@ -96,15 +102,15 @@ class _VFunc_Nin2_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
def __call__(self, a: Geography, b: Geography) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike
self, a: Geography, b: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography
self, a: Iterable[Geography], b: Geography
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike
self, a: Iterable[Geography], b: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin2optradius_Nout1(
Expand All @@ -114,19 +120,19 @@ class _VFunc_Nin2optradius_Nout1(
def __name__(self) -> _NameType: ...
@overload
def __call__(
self, a: Geography, b: Geography, radius: float = ...
self, a: Geography, b: Geography, radius: float = 6371010.0
) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike, radius: float = ...
self, a: Geography, b: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography, radius: float = ...
self, a: Iterable[Geography], b: Geography, radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike, radius: float = ...
self, a: Iterable[Geography], b: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin1optradius_Nout1(
Expand All @@ -135,10 +141,24 @@ class _VFunc_Nin1optradius_Nout1(
@property
def __name__(self) -> _NameType: ...
@overload
def __call__(self, a: Geography, radius: float = ...) -> _ScalarReturnType: ...
def __call__(
self, a: Geography, radius: float = 6371010.0
) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin1optprecision_Nout1(
Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
):
@property
def __name__(self) -> _NameType: ...
@overload
def __call__(self, a: Geography, precision: int = 6) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: npt.ArrayLike, radius: float = ...
self, a: Iterable[Geography], precision: int = 6
) -> npt.NDArray[_ArrayReturnDType]: ...

# /// End types
Expand Down Expand Up @@ -188,12 +208,9 @@ def create_collection(geographies: Iterable[Geography]) -> GeometryCollection: .

# Geography creation (vectorized)

@overload
def points(
longitude: npt.ArrayLike, latitude: npt.ArrayLike
) -> npt.NDArray[np.object_]: ...
@overload
def points(longitude: float, latitude: float) -> PointGeography: ... # type: ignore[misc]
) -> PointGeography | T_NDArray_Geography: ...

# Geography utils

Expand Down Expand Up @@ -234,42 +251,60 @@ boundary: _VFunc_Nin1_Nout1[Literal["boundary"], Geography, Geography]
convex_hull: _VFunc_Nin1_Nout1[
Literal["convex_hull"], PolygonGeography, PolygonGeography
]
distance: _VFunc_Nin2optradius_Nout1[Literal["distance"], float, float]
area: _VFunc_Nin1optradius_Nout1[Literal["area"], float, float]
length: _VFunc_Nin1optradius_Nout1[Literal["length"], float, float]
perimeter: _VFunc_Nin1optradius_Nout1[Literal["perimeter"], float, float]
distance: _VFunc_Nin2optradius_Nout1[Literal["distance"], float, np.float64]
area: _VFunc_Nin1optradius_Nout1[Literal["area"], float, np.float64]
length: _VFunc_Nin1optradius_Nout1[Literal["length"], float, np.float64]
perimeter: _VFunc_Nin1optradius_Nout1[Literal["perimeter"], float, np.float64]

# io functions

to_wkt: _VFunc_Nin1_Nout1[Literal["to_wkt"], str, object]
to_wkt: _VFunc_Nin1optprecision_Nout1[Literal["to_wkt"], str, object]
to_wkb: _VFunc_Nin1_Nout1[Literal["to_wkb"], bytes, object]

@overload
def from_wkt(
a: str,
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> Geography: ...
@overload
def from_wkt(
a: Iterable[str],
a: list[str] | npt.NDArray[np.str_],
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> npt.NDArray[Any]: ...
) -> T_NDArray_Geography: ...
@overload
def from_wkb(
a: bytes,
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> Geography: ...
@overload
def from_wkb(
a: Iterable[bytes],
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> npt.NDArray[Any]: ...
) -> T_NDArray_Geography: ...

class ArrowSchemaExportable(Protocol):
def __arrow_c_schema__(self) -> object: ...

class ArrowArrayExportable(Protocol):
def __arrow_c_array__(
self, requested_schema: object | None = None
) -> Tuple[object, object]: ...
) -> tuple[object, object]: ...

class ArrowArrayHolder(ArrowArrayExportable): ...

def to_geoarrow(
input: npt.ArrayLike,
input: Geography | T_NDArray_Geography,
/,
*,
output_schema: ArrowSchemaExportable | None = None,
output_schema: ArrowSchemaExportable | str | None = None,
projection: Projection = Projection.lnglat(),
planar: bool = False,
tessellate_tolerance: float = 100.0,
Expand All @@ -284,4 +319,4 @@ def from_geoarrow(
tessellate_tolerance: float = 100.0,
projection: Projection = Projection.lnglat(),
geometry_encoding: str | None = None,
) -> npt.NDArray[Any]: ...
) -> T_NDArray_Geography: ...
12 changes: 6 additions & 6 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_distance_with_custom_radius() -> None:
assert actual == pytest.approx(np.pi / 2)


def test_area():
def test_area() -> None:
# scalar
geog = spherely.create_polygon([(0, 0), (90, 0), (0, 90), (0, 0)])
result = spherely.area(geog, radius=1)
Expand Down Expand Up @@ -191,11 +191,11 @@ def test_area():
"POLYGON EMPTY",
],
)
def test_area_empty(geog):
def test_area_empty(geog) -> None:
assert spherely.area(spherely.from_wkt(geog)) == 0


def test_length():
def test_length() -> None:
geog = spherely.create_linestring([(0, 0), (1, 0)])
result = spherely.length(geog, radius=1)
assert isinstance(result, float)
Expand All @@ -218,11 +218,11 @@ def test_length():
"POLYGON ((0 0, 0 1, 1 0, 0 0))",
],
)
def test_length_invalid(geog):
def test_length_invalid(geog) -> None:
assert spherely.length(spherely.from_wkt(geog)) == 0.0


def test_perimeter():
def test_perimeter() -> None:
geog = spherely.create_polygon([(0, 0), (0, 90), (90, 90), (90, 0), (0, 0)])
result = spherely.perimeter(geog, radius=1)
assert isinstance(result, float)
Expand All @@ -239,5 +239,5 @@ def test_perimeter():
@pytest.mark.parametrize(
"geog", ["POINT (0 0)", "POINT EMPTY", "LINESTRING (0 0, 1 0)", "POLYGON EMPTY"]
)
def test_perimeter_invalid(geog):
def test_perimeter_invalid(geog) -> None:
assert spherely.perimeter(spherely.from_wkt(geog)) == 0.0
20 changes: 10 additions & 10 deletions tests/test_boolean_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
),
],
)
def test_union(geog1, geog2, expected):
def test_union(geog1, geog2, expected) -> None:
result = spherely.union(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert str(result) == expected

Expand All @@ -47,12 +47,12 @@ def test_union_polygon():
),
],
)
def test_intersection(geog1, geog2, expected):
def test_intersection(geog1, geog2, expected) -> None:
result = spherely.intersection(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert str(result) == expected


def test_intersection_empty():
def test_intersection_empty() -> None:
result = spherely.intersection(poly1, spherely.from_wkt("POLYGON EMPTY"))
# assert spherely.is_empty(result)
assert str(result) == "GEOMETRYCOLLECTION EMPTY"
Expand All @@ -66,7 +66,7 @@ def test_intersection_empty():
assert str(result) == "GEOMETRYCOLLECTION EMPTY"


def test_intersection_lines():
def test_intersection_lines() -> None:
result = spherely.intersection(
spherely.from_wkt("LINESTRING (-45 0, 45 0)"),
spherely.from_wkt("LINESTRING (0 -10, 0 10)"),
Expand All @@ -75,7 +75,7 @@ def test_intersection_lines():
assert spherely.distance(result, spherely.from_wkt("POINT (0 0)")) == 0


def test_intersection_polygons():
def test_intersection_polygons() -> None:
result = spherely.intersection(poly1, poly2)
# TODO precision could be higher with snap level
precision = 2 if Version(spherely.__s2geography_version__) < Version("0.2.0") else 1
Expand All @@ -85,7 +85,7 @@ def test_intersection_polygons():
)


def test_intersection_polygon_model():
def test_intersection_polygon_model() -> None:
poly = spherely.from_wkt("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))")
point = spherely.from_wkt("POINT (0 0)")

Expand All @@ -107,12 +107,12 @@ def test_intersection_polygon_model():
),
],
)
def test_difference(geog1, geog2, expected):
def test_difference(geog1, geog2, expected) -> None:
result = spherely.difference(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert spherely.equals(result, spherely.from_wkt(expected))


def test_difference_polygons():
def test_difference_polygons() -> None:
result = spherely.difference(poly1, poly2)
expected_near = spherely.area(poly1) - spherely.area(
spherely.from_wkt("POLYGON ((5 5, 10 5, 10 10, 5 10, 5 5))")
Expand All @@ -133,14 +133,14 @@ def test_difference_polygons():
),
],
)
def test_symmetric_difference(geog1, geog2, expected):
def test_symmetric_difference(geog1, geog2, expected) -> None:
result = spherely.symmetric_difference(
spherely.from_wkt(geog1), spherely.from_wkt(geog2)
)
assert spherely.equals(result, spherely.from_wkt(expected))


def test_symmetric_difference_polygons():
def test_symmetric_difference_polygons() -> None:
result = spherely.symmetric_difference(poly1, poly2)
expected_near = 2 * (
spherely.area(poly1)
Expand Down
Loading

0 comments on commit ad27fbe

Please sign in to comment.