forked from microsoft/torchgeo
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 27 changed files with 742 additions and 8 deletions.
@@ -1,8 +1,10 @@
 # datasets
 h5py==3.12.1
 laspy==2.5.4
 netcdf4==1.7.2
 opencv-python==4.11.0.86
 pandas[parquet]==2.2.3
 pycocotools==2.0.8
 scikit-image==0.25.1
-scipy==1.15.1
+scipy==1.15.2
 xarray==2024.11.0
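
The visible change in this pinned requirements list is the scipy bump from 1.15.1 to 1.15.2. A minimal sketch of how one might sanity-check an environment against the new pin (the expected value is taken from the diff above):

import scipy

# fail fast if the installed version drifted from the pin
assert scipy.__version__ == '1.15.2', f'expected scipy 1.15.2, got {scipy.__version__}'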
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

# define the patch size
PATCH_SIZE = 16

# use a fixed seed so the generated data (and thus the checksums) are reproducible
rg = np.random.RandomState(42)


def create_dummy_sample(fp: str | Path) -> None:
    # create the random S2 bands data; turn the last two bands into binary masks
    band_data = rg.randint(
        low=0, high=10000, dtype=np.int16, size=(15, PATCH_SIZE, PATCH_SIZE)
    )
    band_data[-2:] = (band_data[-2:] > 5000).astype(np.int16)

    data_dict = {
        'band_data': {
            'dims': ('band', 'y', 'x'),
            'data': band_data,
            'attrs': {
                'long_name': [
                    'B1',
                    'B2',
                    'B3',
                    'B4',
                    'B5',
                    'B6',
                    'B7',
                    'B8',
                    'B8A',
                    'B9',
                    'B10',
                    'B11',
                    'B12',
                    'CLOUDLESS_MASK',
                    'FILL_MASK',
                ],
                '_FillValue': -9999,
            },
        },
        'mask_all_g_id': {  # glacier mask (-1 for no-glacier, GLACIER_ID for glacier)
            'dims': ('y', 'x'),
            'data': rg.choice([-1, 8, 9, 30, 35], size=(PATCH_SIZE, PATCH_SIZE)).astype(
                np.int32
            ),
            'attrs': {'_FillValue': -1},
        },
        'mask_debris': {
            'dims': ('y', 'x'),
            'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) > 0.5).astype(np.int8),
            'attrs': {'_FillValue': -1},
        },
    }

    # add the additional variables
    for v in [
        'dem',
        'slope',
        'aspect',
        'planform_curvature',
        'profile_curvature',
        'terrain_ruggedness_index',
        'dhdt',
        'v',
    ]:
        data_dict[v] = {
            'dims': ('y', 'x'),
            'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) * 100).astype(np.float32),
            'attrs': {'_FillValue': -9999},
        }

    # create the xarray dataset and save it
    nc = xr.Dataset.from_dict(data_dict)
    nc.to_netcdf(fp)


def create_splits_df(fp: str | Path) -> pd.DataFrame:
    # create a dataframe with the cross-validation splits for the 4 glaciers
    splits_df = pd.DataFrame(
        {
            'entry_id': ['g_0008', 'g_0009', 'g_0030', 'g_0035'],
            'split_1': ['fold_train', 'fold_train', 'fold_valid', 'fold_test'],
            'split_2': ['fold_train', 'fold_valid', 'fold_train', 'fold_test'],
            'split_3': ['fold_train', 'fold_valid', 'fold_test', 'fold_train'],
            'split_4': ['fold_test', 'fold_valid', 'fold_train', 'fold_train'],
            'split_5': ['fold_test', 'fold_train', 'fold_train', 'fold_valid'],
        }
    )

    splits_df.to_csv(fp, index=False)
    print(f'Splits dataframe saved to {fp}')
    return splits_df


if __name__ == '__main__':
    # prepare the paths
    fp_splits = Path('splits.csv')
    fp_dir_ds_small = Path('dataset_small')
    fp_dir_ds_large = Path('dataset_large')

    # clean up any outputs from a previous run
    fp_splits.unlink(missing_ok=True)
    fp_dir_ds_small.with_suffix('.tar.gz').unlink(missing_ok=True)
    fp_dir_ds_large.with_suffix('.tar.gz').unlink(missing_ok=True)
    shutil.rmtree(fp_dir_ds_small, ignore_errors=True)
    shutil.rmtree(fp_dir_ds_large, ignore_errors=True)

    # create the splits dataframe
    split_df = create_splits_df(fp_splits)

    # create the two dataset versions (small and large) with 1 and 2 patches per glacier, respectively
    for fp_dir, num_patches in zip([fp_dir_ds_small, fp_dir_ds_large], [1, 2]):
        for glacier_id in split_df.entry_id:
            for i in range(num_patches):
                fp = fp_dir / glacier_id / f'{glacier_id}_patch_{i}.nc'
                fp.parent.mkdir(parents=True, exist_ok=True)
                create_dummy_sample(fp=fp)

    # archive the datasets
    for fp_dir in [fp_dir_ds_small, fp_dir_ds_large]:
        shutil.make_archive(str(fp_dir), 'gztar', fp_dir)

    # compute and print the checksums
    for fp in [
        fp_dir_ds_small.with_suffix('.tar.gz'),
        fp_dir_ds_large.with_suffix('.tar.gz'),
        fp_splits,
    ]:
        with open(fp, 'rb') as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        print(f'md5 for {fp}: {md5}')
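
To spot-check what the script produces, one can open a generated patch with xarray; the path below follows the {glacier_id}_patch_{i}.nc pattern used in the main block and assumes the script has just been run:

import xarray as xr

# inspect one of the dummy patches written by create_dummy_sample()
ds = xr.open_dataset('dataset_small/g_0008/g_0008_patch_0.nc')
print(ds['band_data'].shape)  # (15, 16, 16): 13 S2 bands plus the two binary masks
print(list(ds.data_vars))  # band_data, mask_all_g_id, mask_debris, dem, slope, ...

The md5 checksums printed at the end are the values referenced by download_metadata in the test file below.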
14 binary files not shown (presumably the generated .nc patches and the two .tar.gz archives).
@@ -0,0 +1,5 @@
entry_id,split_1,split_2,split_3,split_4,split_5
g_0008,fold_train,fold_train,fold_train,fold_test,fold_test
g_0009,fold_train,fold_valid,fold_valid,fold_valid,fold_train
g_0030,fold_valid,fold_train,fold_test,fold_train,fold_train
g_0035,fold_test,fold_test,fold_train,fold_train,fold_valid
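
Each split column assigns two glaciers to fold_train and one each to fold_valid and fold_test; test_len in the test file below relies on exactly this layout. A small sketch that verifies the invariant from the CSV:

import pandas as pd

df = pd.read_csv('splits.csv')
for col in ['split_1', 'split_2', 'split_3', 'split_4', 'split_5']:
    counts = df[col].value_counts()
    # 2 training glaciers, 1 validation glacier, 1 test glacier per cross-validation iteration
    assert counts['fold_train'] == 2
    assert counts['fold_valid'] == 1 and counts['fold_test'] == 1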
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, DL4GAMAlps, RGBBandsMissingError

pytest.importorskip('xarray', minversion='0.12.3')
pytest.importorskip('netCDF4', minversion='1.5.8')


class TestDL4GAMAlps:
    @pytest.fixture(
        params=zip(
            ['train', 'val', 'test'],
            [1, 3, 5],
            ['small', 'small', 'large'],
            [DL4GAMAlps.rgb_bands, DL4GAMAlps.rgb_nir_swir_bands, DL4GAMAlps.all_bands],
            [None, ['dem'], DL4GAMAlps.valid_extra_features],
        )
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> DL4GAMAlps:
        url = Path('tests', 'data', 'dl4gam_alps')
        download_metadata = {
            'dataset_small': {
                'url': str(url / 'dataset_small.tar.gz'),
                'checksum': '35f85360b943caa8661d9fb573b0f0b5',
            },
            'dataset_large': {
                'url': str(url / 'dataset_large.tar.gz'),
                'checksum': '636be5be35b8bd1e7771e9010503e4bc',
            },
            'splits_csv': {
                'url': str(url / 'splits.csv'),
                'checksum': '973367465c8ab322d0cf544a345b02f5',
            },
        }

        monkeypatch.setattr(DL4GAMAlps, 'download_metadata', download_metadata)
        root = tmp_path
        split, cv_iter, version, bands, extra_features = request.param
        transforms = nn.Identity()
        return DL4GAMAlps(
            root,
            split,
            cv_iter,
            version,
            bands,
            extra_features,
            transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: DL4GAMAlps) -> None:
        x = dataset[0]
        assert isinstance(x, dict)

        var_names = ['image', 'mask_glacier', 'mask_debris', 'mask_clouds_and_shadows']
        if dataset.extra_features:
            var_names += list(dataset.extra_features)
        for v in var_names:
            assert v in x
            assert isinstance(x[v], torch.Tensor)

            # check if all variables have the same spatial dimensions as the image
            assert x['image'].shape[-2:] == x[v].shape[-2:]

        # check the first dimension of the image tensor
        assert x['image'].shape[0] == len(dataset.bands)

    def test_len(self, dataset: DL4GAMAlps) -> None:
        num_glaciers_per_fold = 2 if dataset.split == 'train' else 1
        num_patches_per_glacier = 1 if dataset.version == 'small' else 2
        assert len(dataset) == num_glaciers_per_fold * num_patches_per_glacier

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            DL4GAMAlps(tmp_path)

    def test_already_downloaded_and_extracted(self, dataset: DL4GAMAlps) -> None:
        DL4GAMAlps(root=dataset.root, download=False, version=dataset.version)

    def test_already_downloaded_but_not_yet_extracted(self, tmp_path: Path) -> None:
        fp_archive = Path('tests', 'data', 'dl4gam_alps', 'dataset_small.tar.gz')
        shutil.copyfile(fp_archive, Path(str(tmp_path), fp_archive.name))
        fp_splits = Path('tests', 'data', 'dl4gam_alps', 'splits.csv')
        shutil.copyfile(fp_splits, Path(str(tmp_path), fp_splits.name))
        DL4GAMAlps(root=str(tmp_path), download=False)

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            DL4GAMAlps(split='foo')

    def test_plot(self, dataset: DL4GAMAlps) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        sample = dataset[0]
        sample['prediction'] = torch.clone(sample['mask_glacier'])
        dataset.plot(sample, suptitle='Test with prediction')
        plt.close()

    def test_plot_wrong_bands(self, dataset: DL4GAMAlps) -> None:
        ds = DL4GAMAlps(
            root=dataset.root,
            split=dataset.split,
            cv_iter=dataset.cv_iter,
            version=dataset.version,
            bands=('B3',),
        )
        with pytest.raises(
            RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
        ):
            ds.plot(dataset[0], suptitle='Single Band')
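
For completeness, a minimal usage sketch mirroring the fixture configuration above; the keyword names match those used in test_plot_wrong_bands, but the root path here is an assumption:

from torchgeo.datasets import DL4GAMAlps

# with download=True, the archive and splits.csv are fetched and extracted under root (assumed path)
ds = DL4GAMAlps(root='data', split='train', cv_iter=1, version='small', download=True)
sample = ds[0]
print(sample['image'].shape)  # (len(ds.bands), H, W)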