14
14
from collections .abc import Callable , Iterable , Sequence
15
15
from typing import Any , ClassVar , cast
16
16
17
+ import xarray as xr
18
+ from rasterio .crs import CRS
19
+ from rioxarray .merge import merge_arrays
20
+ from rtree .index import Index , Property
21
+
17
22
import fiona
18
23
import fiona .transform
19
24
import numpy as np
@@ -1238,3 +1243,260 @@ def res(self, new_res: float) -> None:
1238
1243
self ._res = new_res
1239
1244
self .datasets [0 ].res = new_res
1240
1245
self .datasets [1 ].res = new_res
1246
+
1247
+
1248
+
1249
class RioXarrayDataset(GeoDataset):
    """Wrapper for geographical datasets stored as Xarray Datasets.

    In-memory geographical xarray.DataArray and xarray.Dataset.

    Relies on rioxarray.

    .. versionadded:: 0.7.0
    """

    filename_glob = "*"
    filename_regex = ".*"

    is_image = True

    # Names rioxarray should treat as the spatial dimensions.
    spatial_x_name = "x"
    spatial_y_name = "y"

    transform = None

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the dataset (overrides the dtype of the data file via a cast).

        Returns:
            the dtype of the dataset
        """
        if self.is_image:
            return torch.float32
        else:
            return torch.long

    def harmonize_format(self, ds):
        """Convert the dataset to the standard format.

        Args:
            ds: dataset or array to harmonize

        Returns:
            the harmonized dataset or array
        """
        # rioxarray expects spatial dimensions to be named x and y
        ds.rio.set_spatial_dims(self.spatial_x_name, self.spatial_y_name, inplace=True)

        # if x coords go from 0 to 360, convert to -180 to 180
        if ds[self.spatial_x_name].min() > 180:
            ds = ds.assign_coords(
                {self.spatial_x_name: ds[self.spatial_x_name] % 360 - 180}
            )

        # if y coords go from 0 to 180, convert to -90 to 90
        # (bug fix: inspect the y coordinate here; previously x was checked again)
        if ds[self.spatial_y_name].min() > 90:
            ds = ds.assign_coords(
                {self.spatial_y_name: ds[self.spatial_y_name] % 180 - 90}
            )
        # expect ascending coordinate values
        ds = ds.sortby(self.spatial_x_name, ascending=True)
        ds = ds.sortby(self.spatial_y_name, ascending=True)
        return ds

    def __init__(
        self,
        paths: Path | Iterable[Path] = 'data',
        data_variables: list[str] | None = None,
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            paths: one or more root directories to search or files to load
            data_variables: data variables that should be gathered from the
                collection of xarray datasets; if ``None``, all data variables
                found in the files are collected
            transforms: a function/transform that takes an input sample
                and returns a transformed version

        Raises:
            FileNotFoundError: if files are not found in ``paths``
        """
        super().__init__(transforms)

        self.paths = paths

        if data_variables:
            self.data_variables = data_variables
        else:
            data_variables_to_collect: list[str] = []

        self.transforms = transforms

        # Create an R-tree to index the dataset (3 dims: x, y, time)
        self.index = Index(interleaved=False, properties=Property(dimension=3))

        # Normalize ``paths`` to a list of root directories/files.
        # Bug fix: the original code referenced an undefined name ``root``.
        if isinstance(paths, (str, os.PathLike)):
            roots = [paths]
        else:
            roots = list(paths)

        # Populate the dataset index
        i = 0
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for root in roots:
            pathname = os.path.join(root, self.filename_glob)
            for filepath in glob.iglob(pathname, recursive=True):
                match = re.match(filename_regex, os.path.basename(filepath))
                if match is None:
                    continue
                with xr.open_dataset(filepath, decode_times=True) as ds:
                    ds = self.harmonize_format(ds)

                    try:
                        (minx, miny, maxx, maxy) = ds.rio.bounds()
                    except AttributeError:
                        # no georeferencing information available: skip file
                        continue

                    if hasattr(ds, "time"):
                        try:
                            # CFTimeIndex -> DatetimeIndex when the calendar
                            # allows it; plain DatetimeIndex lacks the method
                            indices = ds.indexes["time"].to_datetimeindex()
                        except AttributeError:
                            indices = ds.indexes["time"]

                        mint = indices.min().to_pydatetime().timestamp()
                        maxt = indices.max().to_pydatetime().timestamp()
                    else:
                        # no time axis: cover the full temporal extent
                        mint = 0
                        maxt = sys.maxsize
                    coords = (minx, maxx, miny, maxy, mint, maxt)
                    self.index.insert(i, coords, filepath)
                    i += 1

                    # collect all possible data variables if none were given
                    if not data_variables:
                        data_variables_to_collect.extend(list(ds.data_vars))

        if i == 0:
            # Bug fix: removed a leftover ``pdb.set_trace()`` debugger trap
            msg = f"No {self.__class__.__name__} data was found in `paths='{self.paths}'`"
            raise FileNotFoundError(msg)

        if not data_variables:
            self.data_variables = list(set(data_variables_to_collect))

        self.res = 1.0

    def _infer_spatial_coordinate_names(self, ds) -> tuple[str, str]:
        """Infer the names of the spatial coordinates.

        Args:
            ds: Dataset or DataArray of which to infer the spatial coordinates

        Returns:
            x and y coordinate names

        Raises:
            ValueError: if no CF-style latitude/longitude units are found
        """
        x_name = None
        y_name = None
        for coord_name, coord in ds.coords.items():
            if hasattr(coord, "units"):
                units = coord.units.lower()
                if any(u in units for u in ("degrees_north", "degree_north")):
                    y_name = coord_name
                elif any(u in units for u in ("degrees_east", "degree_east")):
                    x_name = coord_name

        if not x_name or not y_name:
            raise ValueError("Spatial Coordinate Units not found in Dataset.")

        return x_name, y_name

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(tuple(query), objects=True)
        items = [hit.object for hit in hits]

        if not items:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )

        data_arrays: list["np.typing.NDArray"] = []
        for item in items:
            with xr.open_dataset(item, decode_cf=True) as ds:
                ds = self.harmonize_format(ds)
                # select time dimension
                if hasattr(ds, "time"):
                    try:
                        ds["time"] = ds.indexes["time"].to_datetimeindex()
                    except AttributeError:
                        ds["time"] = ds.indexes["time"]
                    ds = ds.sel(
                        time=slice(
                            datetime.fromtimestamp(query.mint),
                            datetime.fromtimestamp(query.maxt),
                        )
                    )

                for variable in self.data_variables:
                    if not hasattr(ds, variable):
                        continue
                    da = ds[variable]
                    # clip box ignores time dimension
                    clipped = da.rio.clip_box(
                        minx=query.minx,
                        miny=query.miny,
                        maxx=query.maxx,
                        maxy=query.maxy,
                    )
                    # rioxarray expects (time, y, x, ...) ordering; only
                    # include "time" when the variable actually has it
                    # (bug fix: the unconditional transpose failed for
                    # variables without a time dimension)
                    if "time" in clipped.dims:
                        clipped = clipped.transpose(
                            "time", self.spatial_y_name, self.spatial_x_name, ...
                        )
                    else:
                        clipped = clipped.transpose(
                            self.spatial_y_name, self.spatial_x_name, ...
                        )
                    data_arrays.append(clipped.squeeze())

        # Bug fix: removed a leftover ``pdb.set_trace()`` debugger trap here
        merged_data = torch.from_numpy(
            merge_arrays(
                data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)
            ).data
        )
        sample = {"bbox": query}

        merged_data = merged_data.to(self.dtype)
        if self.is_image:
            sample["image"] = merged_data
        else:
            sample["mask"] = merged_data

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
0 commit comments