Skip to content

Commit

Permalink
Add BRIGHT dataset (#2520)
Browse files Browse the repository at this point in the history
* bright

* bright tests

* bright

* run ruff

* mypy and docs

* ruff on data.py

* ruff on bright

* docs

* ruff

* rm datamodule

* coverage

* request

* Update docs/api/datasets/non_geo_datasets.csv

Co-authored-by: Adam J. Stewart <[email protected]>

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
nilsleh and adamjstewart authored Jan 28, 2025
1 parent 662a883 commit e42c404
Show file tree
Hide file tree
Showing 22 changed files with 560 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ BioMassters

.. autoclass:: BioMassters

BRIGHT
^^^^^^

.. autoclass:: BRIGHTDFC2025

CaBuAr
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
`BRIGHT`_,CD,"MAXAR, NAIP, Capella, Umbra","CC-BY-4.0 AND CC-BY-NC-4.0",3239,4,,"0.1-1","RGB,SAR"
`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
`CaFFe`_,S,"Sentinel-1, TerraSAR-X, TanDEM-X, ENVISAT, ERS-1/2, ALOS PALSAR, and RADARSAT-1","CC-BY-4.0","19092","2 or 4","512x512",6-20,"SAR"
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
Expand Down
117 changes: 117 additions & 0 deletions tests/data/bright/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio

ROOT = '.'
DATA_DIR = 'dfc25_track2_trainval'

TRAIN_FILE = 'train_setlevel.txt'
HOLDOUT_FILE = 'holdout_setlevel.txt'
VAL_FILE = 'val_setlevel.txt'

TRAIN_IDS = [
'bata-explosion_00000049',
'bata-explosion_00000014',
'bata-explosion_00000047',
]
HOLDOUT_IDS = ['turkey-earthquake_00000413']
VAL_IDS = ['val-disaster_00000001', 'val-disaster_00000002']

SIZE = 32


def make_dirs() -> None:
paths = [
os.path.join(ROOT, DATA_DIR),
os.path.join(ROOT, DATA_DIR, 'train', 'pre-event'),
os.path.join(ROOT, DATA_DIR, 'train', 'post-event'),
os.path.join(ROOT, DATA_DIR, 'train', 'target'),
os.path.join(ROOT, DATA_DIR, 'val', 'pre-event'),
os.path.join(ROOT, DATA_DIR, 'val', 'post-event'),
os.path.join(ROOT, DATA_DIR, 'val', 'target'),
]
for p in paths:
os.makedirs(p, exist_ok=True)


def write_list_file(filename: str, ids: list[str]) -> None:
file_path = os.path.join(ROOT, DATA_DIR, filename)
with open(file_path, 'w') as f:
for sid in ids:
f.write(f'{sid}\n')


def write_tif(filepath: str, channels: int) -> None:
data = np.random.randint(0, 255, (channels, SIZE, SIZE), dtype=np.uint8)
# transform = from_origin(0, 0, 1, 1)
crs = 'epsg:4326'
with rasterio.open(
filepath,
'w',
driver='GTiff',
height=SIZE,
width=SIZE,
count=channels,
crs=crs,
dtype=data.dtype,
compress='lzw',
# transform=transform,
) as dst:
dst.write(data)


def populate_data(ids: list[str], dir_name: str, with_target: bool = True) -> None:
for sid in ids:
pre_path = os.path.join(
ROOT, DATA_DIR, dir_name, 'pre-event', f'{sid}_pre_disaster.tif'
)
write_tif(pre_path, channels=3)
post_path = os.path.join(
ROOT, DATA_DIR, dir_name, 'post-event', f'{sid}_post_disaster.tif'
)
write_tif(post_path, channels=1)
if with_target:
target_path = os.path.join(
ROOT, DATA_DIR, dir_name, 'target', f'{sid}_building_damage.tif'
)
write_tif(target_path, channels=1)


def main() -> None:
make_dirs()

# Write the ID lists to text files
write_list_file(TRAIN_FILE, TRAIN_IDS)
write_list_file(HOLDOUT_FILE, HOLDOUT_IDS)
write_list_file(VAL_FILE, VAL_IDS)

# Generate TIF files for the train (with target) and val (no target) splits
populate_data(TRAIN_IDS, 'train', with_target=True)
populate_data(HOLDOUT_IDS, 'train', with_target=True)
populate_data(VAL_IDS, 'val', with_target=False)

# zip and compute md5
zip_filename = os.path.join(ROOT, 'dfc25_track2_trainval')
shutil.make_archive(zip_filename, 'zip', ROOT, DATA_DIR)

def md5(fname: str) -> str:
hash_md5 = hashlib.md5()
with open(fname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_md5.update(chunk)
return hash_md5.hexdigest()

md5sum = md5(zip_filename + '.zip')
print(f'MD5 checksum: {md5sum}')


if __name__ == '__main__':
main()
Binary file added tests/data/bright/dfc25_track2_trainval.zip
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
turkey-earthquake_00000413
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/data/bright/dfc25_track2_trainval/train_setlevel.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
bata-explosion_00000049
bata-explosion_00000014
bata-explosion_00000047
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/bright/dfc25_track2_trainval/val_setlevel.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val-disaster_00000001
val-disaster_00000002
89 changes: 89 additions & 0 deletions tests/datasets/test_bright.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import BRIGHTDFC2025, DatasetNotFoundError


class TestBRIGHTDFC2025:
@pytest.fixture(params=['train', 'val', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> BRIGHTDFC2025:
md5 = '7b0e24d45fb2d9a4f766196702586414'
monkeypatch.setattr(BRIGHTDFC2025, 'md5', md5)
url = os.path.join('tests', 'data', 'bright', 'dfc25_track2_trainval.zip')
monkeypatch.setattr(BRIGHTDFC2025, 'url', url)
root = tmp_path
split = request.param
transforms = nn.Identity()
return BRIGHTDFC2025(root, split, transforms, download=True, checksum=True)

def test_getitem(self, dataset: BRIGHTDFC2025) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image_pre'], torch.Tensor)
assert x['image_pre'].shape[0] == 3
assert isinstance(x['image_post'], torch.Tensor)
assert x['image_post'].shape[0] == 3
assert x['image_pre'].shape[-2:] == x['image_post'].shape[-2:]
if dataset.split != 'test':
assert isinstance(x['mask'], torch.Tensor)
assert x['image_pre'].shape[-2:] == x['mask'].shape[-2:]

def test_len(self, dataset: BRIGHTDFC2025) -> None:
if dataset.split == 'train':
assert len(dataset) == 3
elif dataset.split == 'val':
assert len(dataset) == 1
else:
assert len(dataset) == 2

def test_already_downloaded(self, dataset: BRIGHTDFC2025) -> None:
BRIGHTDFC2025(root=dataset.root)

def test_not_yet_extracted(self, tmp_path: Path) -> None:
filename = 'dfc25_track2_trainval.zip'
dir = os.path.join('tests', 'data', 'bright')
shutil.copyfile(
os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
)
BRIGHTDFC2025(root=str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
BRIGHTDFC2025(split='foo')

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
BRIGHTDFC2025(tmp_path)

def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'dfc25_track2_trainval.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
BRIGHTDFC2025(root=tmp_path, checksum=True)

def test_plot(self, dataset: BRIGHTDFC2025) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

if dataset.split != 'test':
sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask'])
dataset.plot(sample, suptitle='Prediction')
plt.close()

del sample['mask']
dataset.plot(sample, suptitle='Only Prediction')
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .biomassters import BioMassters
from .bright import BRIGHTDFC2025
from .cabuar import CaBuAr
from .caffe import CaFFe
from .cbf import CanadianBuildingFootprints
Expand Down Expand Up @@ -152,6 +153,7 @@

__all__ = (
'ADVANCE',
'BRIGHTDFC2025',
'CDL',
'COWC',
'DFC2022',
Expand Down
Loading

0 comments on commit e42c404

Please sign in to comment.