Commit 788454f: Merge branch 'main' into benv2

nilsleh authored Feb 19, 2025
2 parents 70fc4c7 + 8055d88
Showing 27 changed files with 742 additions and 8 deletions.
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
@@ -285,6 +285,10 @@ Digital Typhoon

.. autoclass:: DigitalTyphoon

DL4GAM
^^^^^^

.. autoclass:: DL4GAMAlps

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -14,6 +14,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB
`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared
`DL4GAM`_,S,Sentinel-2,"CC-BY-4.0","2,251 or 11,440",2,256x256,10,MSI
`ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
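
For orientation, a minimal usage sketch of the new dataset. The keyword names and values below are inferred from the test fixture added in this commit (tests/datasets/test_dl4gam.py), so treat this as a sketch rather than the canonical API:

from torchgeo.datasets import DL4GAMAlps

# Keyword names inferred from the tests in this diff; 'data/dl4gam_alps'
# is a hypothetical local path.
ds = DL4GAMAlps(
    root='data/dl4gam_alps',
    split='train',               # 'train', 'val' or 'test'
    cv_iter=1,                   # which of the five cross-validation splits
    version='small',             # 'small' (2,251 patches) or 'large' (11,440)
    bands=DL4GAMAlps.rgb_bands,
    download=True,
    checksum=True,
)
sample = ds[0]                   # dict with 'image', 'mask_glacier', 'mask_debris', ...
print(sample['image'].shape)     # (len(bands), H, W)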
6 changes: 3 additions & 3 deletions package-lock.json

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ maintainers = [
]
keywords = ["pytorch", "deep learning", "machine learning", "remote sensing", "satellite imagery", "earth observation", "geospatial"]
classifiers = [
-"Development Status :: 3 - Alpha",
+"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
@@ -91,6 +91,8 @@ datasets = [
"h5py>=3.6",
# laspy 2+ required for laspy.read
"laspy>=2",
# netcdf4 1.5.8+ required for Python 3.10 wheels
"netcdf4>=1.5.8",
# opencv-python 4.5.4+ required for Python 3.10 wheels
"opencv-python>=4.5.4",
# pandas 2+ required for parquet extra
@@ -101,6 +103,8 @@ datasets = [
"scikit-image>=0.19",
# scipy 1.7.2+ required for Python 3.10 wheels
"scipy>=1.7.2",
# xarray 0.12.3+ required for pandas 1.3.3 support
"xarray>=0.12.3",
]
docs = [
# ipywidgets 7+ required by nbsphinx
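
Judging by the new DL4GAM test data below, the two new pins in the datasets extra are there because the dataset's patches ship as netCDF files: xarray provides the reading API and netcdf4 the backend. A minimal sketch of what this enables ('patch.nc' is a placeholder file name):

import xarray as xr

# xarray delegates the actual file I/O to the netcdf4 backend here.
ds = xr.open_dataset('patch.nc', engine='netcdf4')
print(ds.data_vars)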
4 changes: 3 additions & 1 deletion requirements/datasets.txt
@@ -1,8 +1,10 @@
# datasets
h5py==3.12.1
laspy==2.5.4
netcdf4==1.7.2
opencv-python==4.11.0.86
pandas[parquet]==2.2.3
pycocotools==2.0.8
scikit-image==0.25.1
-scipy==1.15.1
+scipy==1.15.2
xarray==2024.11.0
2 changes: 2 additions & 0 deletions requirements/min-reqs.old
@@ -25,11 +25,13 @@ typing-extensions==4.5.0
# datasets
h5py==3.6.0
laspy==2.0.0
netCDF4==1.5.8
opencv-python==4.5.4.58
pycocotools==2.0.7
pyarrow==15.0.0 # Remove when we upgrade min version of pandas to `pandas[parquet]>=2`
scikit-image==0.19.0
scipy==1.7.2
xarray==0.12.3

# tests
pytest==7.3.0
2 changes: 1 addition & 1 deletion requirements/required.txt
@@ -11,7 +11,7 @@ matplotlib==3.10.0
numpy==2.2.3
pandas==2.2.3
pillow==11.1.0
-pyproj==3.7.0
+pyproj==3.7.1
rasterio==1.4.3
rtree==1.3.0
segmentation-models-pytorch==0.4.0
143 changes: 143 additions & 0 deletions tests/data/dl4gam_alps/data.py
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

# define the patch size
PATCH_SIZE = 16

# create a random generator
rg = np.random.RandomState(42)


def create_dummy_sample(fp: str | Path) -> None:
    # create random S2 band data; turn the last two bands into binary masks
band_data = rg.randint(
low=0, high=10000, dtype=np.int16, size=(15, PATCH_SIZE, PATCH_SIZE)
)
band_data[-2:] = (band_data[-2:] > 5000).astype(np.int16)

data_dict = {
'band_data': {
'dims': ('band', 'y', 'x'),
'data': band_data,
'attrs': {
'long_name': [
'B1',
'B2',
'B3',
'B4',
'B5',
'B6',
'B7',
'B8',
'B8A',
'B9',
'B10',
'B11',
'B12',
'CLOUDLESS_MASK',
'FILL_MASK',
],
'_FillValue': -9999,
},
},
'mask_all_g_id': { # glaciers mask (with -1 for no-glacier and GLACIER_ID for glacier)
'dims': ('y', 'x'),
'data': rg.choice([-1, 8, 9, 30, 35], size=(PATCH_SIZE, PATCH_SIZE)).astype(
np.int32
),
'attrs': {'_FillValue': -1},
},
'mask_debris': {
'dims': ('y', 'x'),
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) > 0.5).astype(np.int8),
'attrs': {'_FillValue': -1},
},
}

# add the additional variables
for v in [
'dem',
'slope',
'aspect',
'planform_curvature',
'profile_curvature',
'terrain_ruggedness_index',
'dhdt',
'v',
]:
data_dict[v] = {
'dims': ('y', 'x'),
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) * 100).astype(np.float32),
'attrs': {'_FillValue': -9999},
}

# create the xarray dataset and save it
nc = xr.Dataset.from_dict(data_dict)
nc.to_netcdf(fp)


def create_splits_df(fp: str | Path) -> pd.DataFrame:
# create a dataframe with the splits for the 4 glaciers
splits_df = pd.DataFrame(
{
'entry_id': ['g_0008', 'g_0009', 'g_0030', 'g_0035'],
'split_1': ['fold_train', 'fold_train', 'fold_valid', 'fold_test'],
'split_2': ['fold_train', 'fold_valid', 'fold_train', 'fold_test'],
'split_3': ['fold_train', 'fold_valid', 'fold_test', 'fold_train'],
'split_4': ['fold_test', 'fold_valid', 'fold_train', 'fold_train'],
'split_5': ['fold_test', 'fold_train', 'fold_train', 'fold_valid'],
}
)

    splits_df.to_csv(fp, index=False)
    print(f'Splits dataframe saved to {fp}')
return splits_df


if __name__ == '__main__':
# prepare the paths
fp_splits = Path('splits.csv')
fp_dir_ds_small = Path('dataset_small')
fp_dir_ds_large = Path('dataset_large')

# cleanup
fp_splits.unlink(missing_ok=True)
fp_dir_ds_small.with_suffix('.tar.gz').unlink(missing_ok=True)
fp_dir_ds_large.with_suffix('.tar.gz').unlink(missing_ok=True)
shutil.rmtree(fp_dir_ds_small, ignore_errors=True)
shutil.rmtree(fp_dir_ds_large, ignore_errors=True)

# create the splits dataframe
split_df = create_splits_df(fp_splits)

    # create the two dataset versions (small and large), with 1 and 2 patches per glacier, respectively
for fp_dir, num_patches in zip([fp_dir_ds_small, fp_dir_ds_large], [1, 2]):
for glacier_id in split_df.entry_id:
for i in range(num_patches):
fp = fp_dir / glacier_id / f'{glacier_id}_patch_{i}.nc'
fp.parent.mkdir(parents=True, exist_ok=True)
create_dummy_sample(fp=fp)

# archive the datasets
for fp_dir in [fp_dir_ds_small, fp_dir_ds_large]:
shutil.make_archive(str(fp_dir), 'gztar', fp_dir)

# compute checksums
for fp in [
fp_dir_ds_small.with_suffix('.tar.gz'),
fp_dir_ds_large.with_suffix('.tar.gz'),
fp_splits,
]:
with open(fp, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'md5 for {fp}: {md5}')
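
A quick way to sanity-check one of the generated patches, assuming the script above has been run inside tests/data/dl4gam_alps:

import xarray as xr

# Read back one dummy patch written by create_dummy_sample() above.
ds = xr.open_dataset('dataset_small/g_0008/g_0008_patch_0.nc')
assert ds['band_data'].shape == (15, 16, 16)  # 13 S2 bands + 2 binary masks
assert set(ds.data_vars) >= {'mask_all_g_id', 'mask_debris', 'dem'}
print(ds['band_data'].attrs['long_name'])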
Binary file added tests/data/dl4gam_alps/dataset_large.tar.gz
Binary file added tests/data/dl4gam_alps/dataset_small.tar.gz
5 changes: 5 additions & 0 deletions tests/data/dl4gam_alps/splits.csv
@@ -0,0 +1,5 @@
entry_id,split_1,split_2,split_3,split_4,split_5
g_0008,fold_train,fold_train,fold_train,fold_test,fold_test
g_0009,fold_train,fold_valid,fold_valid,fold_valid,fold_train
g_0030,fold_valid,fold_train,fold_test,fold_train,fold_train
g_0035,fold_test,fold_test,fold_train,fold_train,fold_valid
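
The CSV encodes five cross-validation iterations over the four dummy glaciers. A short sketch of how the glacier IDs for one split can be recovered (the selection logic is an assumption about what the dataset class does internally):

import pandas as pd

# Training glaciers for cross-validation iteration 1.
df = pd.read_csv('splits.csv')
train_ids = df.loc[df['split_1'] == 'fold_train', 'entry_id'].tolist()
print(train_ids)  # ['g_0008', 'g_0009']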
125 changes: 125 additions & 0 deletions tests/datasets/test_dl4gam.py
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, DL4GAMAlps, RGBBandsMissingError

pytest.importorskip('xarray', minversion='0.12.3')
pytest.importorskip('netCDF4', minversion='1.5.8')


class TestDL4GAMAlps:
@pytest.fixture(
params=zip(
['train', 'val', 'test'],
[1, 3, 5],
['small', 'small', 'large'],
[DL4GAMAlps.rgb_bands, DL4GAMAlps.rgb_nir_swir_bands, DL4GAMAlps.all_bands],
[None, ['dem'], DL4GAMAlps.valid_extra_features],
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> DL4GAMAlps:
url = Path('tests', 'data', 'dl4gam_alps')
download_metadata = {
'dataset_small': {
'url': str(url / 'dataset_small.tar.gz'),
'checksum': '35f85360b943caa8661d9fb573b0f0b5',
},
'dataset_large': {
'url': str(url / 'dataset_large.tar.gz'),
'checksum': '636be5be35b8bd1e7771e9010503e4bc',
},
'splits_csv': {
'url': str(url / 'splits.csv'),
'checksum': '973367465c8ab322d0cf544a345b02f5',
},
}

monkeypatch.setattr(DL4GAMAlps, 'download_metadata', download_metadata)
root = tmp_path
split, cv_iter, version, bands, extra_features = request.param
transforms = nn.Identity()
return DL4GAMAlps(
root,
split,
cv_iter,
version,
bands,
extra_features,
transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: DL4GAMAlps) -> None:
x = dataset[0]
assert isinstance(x, dict)

var_names = ['image', 'mask_glacier', 'mask_debris', 'mask_clouds_and_shadows']
if dataset.extra_features:
var_names += list(dataset.extra_features)
for v in var_names:
assert v in x
assert isinstance(x[v], torch.Tensor)

# check if all variables have the same spatial dimensions as the image
assert x['image'].shape[-2:] == x[v].shape[-2:]

# check the first dimension of the image tensor
assert x['image'].shape[0] == len(dataset.bands)

def test_len(self, dataset: DL4GAMAlps) -> None:
num_glaciers_per_fold = 2 if dataset.split == 'train' else 1
num_patches_per_glacier = 1 if dataset.version == 'small' else 2
assert len(dataset) == num_glaciers_per_fold * num_patches_per_glacier

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
DL4GAMAlps(tmp_path)

def test_already_downloaded_and_extracted(self, dataset: DL4GAMAlps) -> None:
DL4GAMAlps(root=dataset.root, download=False, version=dataset.version)

def test_already_downloaded_but_not_yet_extracted(self, tmp_path: Path) -> None:
fp_archive = Path('tests', 'data', 'dl4gam_alps', 'dataset_small.tar.gz')
shutil.copyfile(fp_archive, Path(str(tmp_path), fp_archive.name))
fp_splits = Path('tests', 'data', 'dl4gam_alps', 'splits.csv')
shutil.copyfile(fp_splits, Path(str(tmp_path), fp_splits.name))
DL4GAMAlps(root=str(tmp_path), download=False)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
DL4GAMAlps(split='foo')

def test_plot(self, dataset: DL4GAMAlps) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask_glacier'])
dataset.plot(sample, suptitle='Test with prediction')
plt.close()

def test_plot_wrong_bands(self, dataset: DL4GAMAlps) -> None:
ds = DL4GAMAlps(
root=dataset.root,
split=dataset.split,
cv_iter=dataset.cv_iter,
version=dataset.version,
bands=('B3',),
)
with pytest.raises(
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):
ds.plot(dataset[0], suptitle='Single Band')
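
Note that the fixture's params=zip(...) pairs its five lists element-wise, so the suite runs exactly three configurations rather than a full cross-product. The equivalent expansion, for illustration:

from torchgeo.datasets import DL4GAMAlps

# zip() yields three (split, cv_iter, version, bands, extra_features) tuples.
combos = list(
    zip(
        ['train', 'val', 'test'],
        [1, 3, 5],
        ['small', 'small', 'large'],
        [DL4GAMAlps.rgb_bands, DL4GAMAlps.rgb_nir_swir_bands, DL4GAMAlps.all_bands],
        [None, ['dem'], DL4GAMAlps.valid_extra_features],
    )
)
assert len(combos) == 3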
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -37,6 +37,7 @@
from .deepglobelandcover import DeepGlobeLandCover
from .dfc2022 import DFC2022
from .digital_typhoon import DigitalTyphoon
from .dl4gam import DL4GAMAlps
from .eddmaps import EDDMapS
from .enmap import EnMAP
from .enviroatlas import EnviroAtlas
@@ -206,6 +207,7 @@
'ChesapeakeWV',
'CloudCoverDetection',
'CropHarvest',
'DL4GAMAlps',
'DatasetNotFoundError',
'DeepGlobeLandCover',
'DependencyNotFoundError',
4 changes: 2 additions & 2 deletions torchgeo/datasets/digital_typhoon.py
@@ -87,8 +87,8 @@ class DigitalTyphoon(NonGeoDataset):
url = 'https://hf.co/datasets/torchgeo/digital_typhoon/resolve/cf2f9ef89168d31cb09e42993d35b068688fe0df/WP.tar.gz{0}'

md5sums: ClassVar[dict[str, str]] = {
-    'aa': '3af98052aed17e0ddb1e94caca2582e2',
-    'ab': '2c5d25455ac8aef1de33fe6456ab2c8d',
+    'aa': '9e77a5f74783f7909dee0fb936053b17',
+    'ab': '46aebdcba6e4e2df1619e4a3d7e556bb',
}

min_input_clamp = 170.0
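
A hedged sketch for verifying re-downloaded shards against the updated checksums above; the local file names are assumptions based on the {0} suffix in the url template ('aa', 'ab'):

import hashlib

md5sums = {
    'aa': '9e77a5f74783f7909dee0fb936053b17',
    'ab': '46aebdcba6e4e2df1619e4a3d7e556bb',
}
for suffix, expected in md5sums.items():
    # e.g. WP.tar.gzaa, WP.tar.gzab once downloaded locally (assumed names)
    with open(f'WP.tar.gz{suffix}', 'rb') as f:
        assert hashlib.md5(f.read()).hexdigest() == expected, suffix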
