Commit 5d65a7f

stash

1 parent f723198 commit 5d65a7f
3 files changed: +264 -281 lines changed

torchgeo/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -55,6 +55,7 @@
     NonGeoClassificationDataset,
     NonGeoDataset,
     RasterDataset,
+    RioXarrayDataset,
     UnionDataset,
     VectorDataset,
 )
@@ -101,7 +102,6 @@
 from .quakeset import QuakeSet
 from .reforestree import ReforesTree
 from .resisc45 import RESISC45
-from .rioxr import RioXarrayDataset
 from .rwanda_field_boundary import RwandaFieldBoundary
 from .satlas import SatlasPretrain
 from .seasonet import SeasoNet
@@ -258,6 +258,7 @@
     'RGBBandsMissingError',
     'RasterDataset',
     'ReforesTree',
+    'RioXarrayDataset',
     'RwandaFieldBoundary',
     'SSL4EOLBenchmark',
     'SatlasPretrain',

torchgeo/datasets/geo.py

Lines changed: 262 additions & 0 deletions
@@ -14,6 +14,11 @@
 from collections.abc import Callable, Iterable, Sequence
 from typing import Any, ClassVar, cast
 
+import xarray as xr
+from rasterio.crs import CRS
+from rioxarray.merge import merge_arrays
+from rtree.index import Index, Property
+
 import fiona
 import fiona.transform
 import numpy as np
@@ -1238,3 +1243,260 @@ def res(self, new_res: float) -> None:
         self._res = new_res
         self.datasets[0].res = new_res
         self.datasets[1].res = new_res
+
+
+class RioXarrayDataset(GeoDataset):
+    """Wrapper for geographical datasets stored as xarray Datasets.
+
+    Supports in-memory geographical ``xarray.DataArray`` and ``xarray.Dataset``
+    objects. Relies on rioxarray.
+
+    .. versionadded:: 0.7.0
+    """
+
+    filename_glob = '*'
+    filename_regex = '.*'
+
+    is_image = True
+
+    spatial_x_name = 'x'
+    spatial_y_name = 'y'
+
+    transform = None
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The dtype of the dataset (overrides the dtype of the data file via a cast).
+
+        Returns:
+            the dtype of the dataset
+        """
+        if self.is_image:
+            return torch.float32
+        else:
+            return torch.long
+
+    def harmonize_format(self, ds: xr.Dataset) -> xr.Dataset:
+        """Convert the dataset to the standard format.
+
+        Args:
+            ds: dataset or array to harmonize
+
+        Returns:
+            the harmonized dataset or array
+        """
+        # rioxarray expects spatial dimensions to be named x and y
+        ds.rio.set_spatial_dims(self.spatial_x_name, self.spatial_y_name, inplace=True)
+
+        # if x coords go from 0 to 360, convert to -180 to 180
+        if ds[self.spatial_x_name].max() > 180:
+            ds = ds.assign_coords(
+                {self.spatial_x_name: (ds[self.spatial_x_name] + 180) % 360 - 180}
+            )
+
+        # if y coords go from 0 to 180, convert to -90 to 90
+        if ds[self.spatial_y_name].max() > 90:
+            ds = ds.assign_coords(
+                {self.spatial_y_name: ds[self.spatial_y_name] % 180 - 90}
+            )
+
+        # expect ascending coordinate values
+        ds = ds.sortby(self.spatial_x_name, ascending=True)
+        ds = ds.sortby(self.spatial_y_name, ascending=True)
+        return ds
+
+    def __init__(
+        self,
+        paths: Path | Iterable[Path] = 'data',
+        data_variables: list[str] | None = None,
+        # crs: CRS | None = None,
+        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ) -> None:
+        """Initialize a new Dataset instance.
+
+        Args:
+            paths: one or more root directories to search or files to load
+            data_variables: data variables that should be gathered from the collection
+                of xarray datasets
+            transforms: a function/transform that takes an input sample
+                and returns a transformed version
+
+        Raises:
+            FileNotFoundError: if no files are found in ``paths``
+        """
+        super().__init__(transforms)
+
+        self.paths = paths
+
+        if data_variables:
+            self.data_variables = data_variables
+        else:
+            data_variables_to_collect: list[str] = []
+
+        self.transforms = transforms
+
+        # Create an R-tree to index the dataset
+        self.index = Index(interleaved=False, properties=Property(dimension=3))
+
+        # Populate the dataset index
+        i = 0
+        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
+        # GeoDataset.files expands `paths` using `filename_glob`
+        for filepath in self.files:
+            match = re.match(filename_regex, os.path.basename(filepath))
+            if match is not None:
+                with xr.open_dataset(filepath, decode_times=True) as ds:
+                    ds = self.harmonize_format(ds)
+
+                    try:
+                        (minx, miny, maxx, maxy) = ds.rio.bounds()
+                    except AttributeError:
+                        # or take the shape of the data variable?
+                        continue
+
+                    if hasattr(ds, 'time'):
+                        try:
+                            indices = ds.indexes['time'].to_datetimeindex()
+                        except AttributeError:
+                            indices = ds.indexes['time']
+
+                        mint = indices.min().to_pydatetime().timestamp()
+                        maxt = indices.max().to_pydatetime().timestamp()
+                    else:
+                        mint = 0
+                        maxt = sys.maxsize
+
+                    coords = (minx, maxx, miny, maxy, mint, maxt)
+                    self.index.insert(i, coords, filepath)
+                    i += 1
+
+                    # collect all possible data variables if data_variables is None
+                    if not data_variables:
+                        data_variables_to_collect.extend(list(ds.data_vars))
+
+        if i == 0:
+            msg = f"No {self.__class__.__name__} data was found in `paths='{self.paths}'`"
+            raise FileNotFoundError(msg)
+
+        if not data_variables:
+            self.data_variables = list(set(data_variables_to_collect))
+
+        # if not crs:
+        #     self._crs = 'EPSG:4326'
+        # else:
+        #     self._crs = cast(CRS, crs)
+        self.res = 1.0
+
+    def _infer_spatial_coordinate_names(self, ds: xr.Dataset) -> tuple[str, str]:
+        """Infer the names of the spatial coordinates.
+
+        Args:
+            ds: Dataset or DataArray of which to infer the spatial coordinates
+
+        Returns:
+            x and y coordinate names
+
+        Raises:
+            ValueError: if the spatial coordinate units are not found
+        """
+        x_name = None
+        y_name = None
+        for coord_name, coord in ds.coords.items():
+            if hasattr(coord, 'units'):
+                units = coord.units.lower()
+                if any(x in units for x in ('degrees_north', 'degree_north')):
+                    y_name = coord_name
+                elif any(x in units for x in ('degrees_east', 'degree_east')):
+                    x_name = coord_name
+
+        if not x_name or not y_name:
+            raise ValueError('Spatial coordinate units not found in Dataset.')
+
+        return x_name, y_name
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image/mask and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image/mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index
+        """
+        hits = self.index.intersection(tuple(query), objects=True)
+        items = [hit.object for hit in hits]
+
+        if not items:
+            raise IndexError(
+                f'query: {query} not found in index with bounds: {self.bounds}'
+            )
+
+        data_arrays: list[xr.DataArray] = []
+        for item in items:
+            with xr.open_dataset(item, decode_cf=True) as ds:
+                ds = self.harmonize_format(ds)
+
+                # select time dimension
+                if hasattr(ds, 'time'):
+                    try:
+                        ds['time'] = ds.indexes['time'].to_datetimeindex()
+                    except AttributeError:
+                        ds['time'] = ds.indexes['time']
+                    ds = ds.sel(
+                        time=slice(
+                            datetime.fromtimestamp(query.mint),
+                            datetime.fromtimestamp(query.maxt),
+                        )
+                    )
+
+                for variable in self.data_variables:
+                    if hasattr(ds, variable):
+                        da = ds[variable]
+                        # if not da.rio.crs:
+                        #     da.rio.write_crs(self._crs, inplace=True)
+                        # elif da.rio.crs != self._crs:
+                        #     da = da.rio.reproject(self._crs)
+
+                        # clip_box ignores the time dimension
+                        clipped = da.rio.clip_box(
+                            minx=query.minx,
+                            miny=query.miny,
+                            maxx=query.maxx,
+                            maxy=query.maxy,
+                        )
+                        # rioxarray expects this order
+                        clipped = clipped.transpose(
+                            'time', self.spatial_y_name, self.spatial_x_name, ...
+                        )
+
+                        # set proper transform  # TODO: not working
+                        # clipped.rio.write_transform(self.transform)
+                        data_arrays.append(clipped.squeeze())
+
+        merged_data = torch.from_numpy(
+            merge_arrays(
+                data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)
+            ).data
+        )
+        merged_data = merged_data.to(self.dtype)
+
+        sample = {'bbox': query}
+        if self.is_image:
+            sample['image'] = merged_data
+        else:
+            sample['mask'] = merged_data
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
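
Below is a minimal usage sketch of the RioXarrayDataset added above, for anyone who wants to try the stash locally. The directory 'data/era5' and the variable name 't2m' are placeholders, and the sketch assumes a collection of CF-convention NetCDF files with geographic x/y coordinates; since this is a work-in-progress commit, the final API may still change.

from torchgeo.datasets import BoundingBox, RioXarrayDataset

# Placeholder path and data variable; any CF-style NetCDF collection should work.
dataset = RioXarrayDataset(paths='data/era5', data_variables=['t2m'])

# Query a longitude/latitude window over the dataset's full temporal extent.
query = BoundingBox(
    minx=-10, maxx=10, miny=35, maxy=55,
    mint=dataset.bounds.mint, maxt=dataset.bounds.maxt,
)
sample = dataset[query]
print(sample['image'].shape)  # tensor clipped to the query box

The query follows torchgeo's (minx, maxx, miny, maxy, mint, maxt) convention, so samples could also be drawn through the existing GeoSamplers.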
