Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BigEarthNet Version2 #2531

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ BigEarthNet
^^^^^^^^^^^

.. autoclass:: BigEarthNet
.. autoclass:: BigEarthNetV2

BioMassters
^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions requirements/datasets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pandas[parquet]==2.2.3
pycocotools==2.0.8
scikit-image==0.25.0
scipy==1.15.1
zstandard==0.23.0
Binary file added tests/data/bigearthnetV2/BigEarthNet-S1.tar.zst
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/bigearthnetV2/BigEarthNet-S2.tar.zst
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/bigearthnetV2/Reference_Maps.tar.zst
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
174 changes: 174 additions & 0 deletions tests/data/bigearthnetV2/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
import tarfile
from pathlib import Path

import numpy as np
import pandas as pd
import rasterio
import zstandard as zstd

Check failure on line 13 in tests/data/bigearthnetV2/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/data/bigearthnetV2/data.py:6:1: I001 Import block is un-sorted or un-formatted

# Constants
# Spatial size (pixels) of every generated patch; BigEarthNet v2 patches are 120x120.
IMG_SIZE = 120
# All output (directories, metadata, archives) is written relative to this root.
ROOT_DIR = '.'

# Sample patch definitions: one patch per split (train/val/test).
# 's2_base'/'s1_base' are the parent tile folders, 's2_name'/'s1_name' the
# per-patch folder and file-name prefixes, mirroring the real dataset layout.
SAMPLE_PATCHES = [
    {
        's2_name': 'S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57',
        's2_base': 'S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP',
        's1_name': 'S1A_IW_GRDH_1SDV_20170613T165043_33UUP_61_39',
        's1_base': 'S1A_IW_GRDH_1SDV_20170613T165043',
        'split': 'train',
    },
    {
        's2_name': 'S2B_MSIL2A_20170615T102019_N9999_R122_T32TNS_45_23',
        's2_base': 'S2B_MSIL2A_20170615T102019_N9999_R122_T32TNS',
        's1_name': 'S1A_IW_GRDH_1SDV_20170615T170156_32TNS_77_12',
        's1_base': 'S1A_IW_GRDH_1SDV_20170615T170156',
        'split': 'val',
    },
    {
        's2_name': 'S2A_MSIL2A_20170618T101021_N9999_R022_T32TQR_89_34',
        's2_base': 'S2A_MSIL2A_20170618T101021_N9999_R022_T32TQR',
        's1_name': 'S1A_IW_GRDH_1SDV_20170618T165722_32TQR_92_45',
        's1_base': 'S1A_IW_GRDH_1SDV_20170618T165722',
        'split': 'test',
    },
]

# Sentinel-1 polarisation bands (one GeoTIFF each).
S1_BANDS = ['VV', 'VH']
# Sentinel-2 spectral bands (one GeoTIFF each).
S2_BANDS = [
    'B01',
    'B02',
    'B03',
    'B04',
    'B05',
    'B06',
    'B07',
    'B08',
    'B8A',
    'B09',
    'B11',
    'B12',
]


def create_directory_structure() -> None:
    """Recreate the three top-level dataset directories from scratch.

    Any pre-existing directory of the same name under ``ROOT_DIR`` is
    removed first so each run of the generator starts clean.
    """
    for name in ('BigEarthNet-S1', 'BigEarthNet-S2', 'Reference_Maps'):
        target = Path(ROOT_DIR) / name
        if target.exists():
            shutil.rmtree(target)
        target.mkdir(parents=True, exist_ok=True)


def create_dummy_image(path: str, shape: tuple[int, int], dtype: str) -> None:
    """Write a single-band GeoTIFF filled with random values.

    ``dtype`` selects the value range: ``'s1'`` produces negative int16
    values (backscatter-like), ``'s2'`` produces int16 values in
    [0, 10000) (reflectance-like), and anything else produces uint8
    values in [0, 19) for the reference map.
    """
    if dtype == 's1':
        pixels = np.random.randint(-25, 0, shape).astype(np.int16)
    elif dtype == 's2':
        pixels = np.random.randint(0, 10000, shape).astype(np.int16)
    else:
        # Reference map: one class id per pixel.
        pixels = np.random.randint(0, 19, shape).astype(np.uint8)

    profile = {
        'driver': 'GTiff',
        'height': shape[0],
        'width': shape[1],
        'count': 1,
        'dtype': pixels.dtype,
        'crs': '+proj=utm +zone=32 +datum=WGS84 +units=m +no_defs',
        'transform': rasterio.transform.from_origin(0, 0, 10, 10),
    }
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(pixels, 1)


def generate_sample(patch_info: dict) -> None:
    """Generate S1 imagery, S2 imagery and a reference map for one patch."""
    # Sentinel-1: one GeoTIFF per polarisation band.
    s1_folder = os.path.join(
        ROOT_DIR, 'BigEarthNet-S1', patch_info['s1_base'], patch_info['s1_name']
    )
    os.makedirs(s1_folder, exist_ok=True)
    for band in S1_BANDS:
        tif_path = os.path.join(s1_folder, f'{patch_info["s1_name"]}_{band}.tif')
        create_dummy_image(tif_path, (IMG_SIZE, IMG_SIZE), 's1')

    # Sentinel-2: one GeoTIFF per spectral band.
    s2_folder = os.path.join(
        ROOT_DIR, 'BigEarthNet-S2', patch_info['s2_base'], patch_info['s2_name']
    )
    os.makedirs(s2_folder, exist_ok=True)
    for band in S2_BANDS:
        tif_path = os.path.join(s2_folder, f'{patch_info["s2_name"]}_{band}.tif')
        create_dummy_image(tif_path, (IMG_SIZE, IMG_SIZE), 's2')

    # Reference map: a single segmentation GeoTIFF keyed by the S2 patch name.
    ref_folder = os.path.join(
        ROOT_DIR, 'Reference_Maps', patch_info['s2_base'], patch_info['s2_name']
    )
    os.makedirs(ref_folder, exist_ok=True)
    ref_path = os.path.join(ref_folder, f'{patch_info["s2_name"]}_reference_map.tif')
    create_dummy_image(ref_path, (IMG_SIZE, IMG_SIZE), 'reference')


def create_metadata() -> None:
    """Write the per-patch metadata table as ``metadata.parquet``.

    One row per sample patch: patch id, matching S1 name, split, and
    three distinct random label indices out of the 19 classes.
    """
    rows = [
        {
            'patch_id': patch['s2_name'],
            's1_name': patch['s1_name'],
            'split': patch['split'],
            'labels': np.random.choice(range(19), size=3, replace=False).tolist(),
        }
        for patch in SAMPLE_PATCHES
    ]
    pd.DataFrame.from_records(rows).to_parquet(
        os.path.join(ROOT_DIR, 'metadata.parquet')
    )


def compress_directory(dirname: str) -> None:
    """Archive a directory as ``<dirname>.tar.zst``.

    The intermediate uncompressed tar file is deleted once the
    zstd-compressed copy has been written.
    """
    tar_name = os.path.join(ROOT_DIR, f'{dirname}.tar')
    with tarfile.open(tar_name, 'w') as archive:
        archive.add(os.path.join(ROOT_DIR, dirname), arcname=dirname)

    # Whole-buffer compression is fine here: the test archives are tiny.
    with open(tar_name, 'rb') as src:
        raw = src.read()
    with open(f'{tar_name}.zst', 'wb') as dst:
        dst.write(zstd.ZstdCompressor().compress(raw))

    os.remove(tar_name)


def main() -> None:
    """Build the fake BigEarthNet v2 fixture: images, metadata, archives."""
    create_directory_structure()

    for sample in SAMPLE_PATCHES:
        generate_sample(sample)

    create_metadata()

    # Pack each modality directory into a .tar.zst archive.
    for archive_dir in ('BigEarthNet-S1', 'BigEarthNet-S2', 'Reference_Maps'):
        compress_directory(archive_dir)


if __name__ == '__main__':
    main()
Binary file added tests/data/bigearthnetV2/metadata.parquet
Binary file not shown.
159 changes: 158 additions & 1 deletion tests/datasets/test_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
from torchgeo.datasets import BigEarthNet, BigEarthNetV2, DatasetNotFoundError


class TestBigEarthNet:
Expand Down Expand Up @@ -140,3 +140,160 @@ def test_plot(self, dataset: BigEarthNet) -> None:
x['prediction'] = x['label'].clone()
dataset.plot(x)
plt.close()


class TestBigEarthNetV2:
    """Unit tests for the BigEarthNetV2 dataset."""

    @pytest.fixture(
        params=zip(['all', 's1', 's2'], [19, 19, 19], ['train', 'val', 'test'])
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> BigEarthNetV2:
        """Build a dataset instance backed by the small test archives."""
        data_dir = os.path.join('tests', 'data', 'bigearthnetV2')
        metadata = {
            's1': {
                'url': os.path.join(data_dir, 'BigEarthNet-S1.tar.zst'),
                'md5': 'a55eaa2cdf6a917e296bd6601ec1e348',
                'filename': 'BigEarthNet-S1.tar.zst',
                'directory': 'BigEarthNet-S1',
            },
            's2': {
                'url': os.path.join(data_dir, 'BigEarthNet-S2.tar.zst'),
                'md5': '2245ed2d1a93f6ce637d839bc856396e',
                'filename': 'BigEarthNet-S2.tar.zst',
                'directory': 'BigEarthNet-S2',
            },
            'maps': {
                'url': os.path.join(data_dir, 'Reference_Maps.tar.zst'),
                'md5': '95d85a222fa983faddcac51a19f28917',
                'filename': 'Reference_Maps.tar.zst',
                'directory': 'Reference_Maps',
            },
            'metadata': {
                'url': os.path.join(data_dir, 'metadata.parquet'),
                'md5': '5f6b7f8b9d4b8e4c4e9a4c9b8d9e4f8b',
                'filename': 'metadata.parquet',
            },
        }
        monkeypatch.setattr(BigEarthNetV2, 'metadata_locs', metadata)

        bands, num_classes, split = request.param

        root = tmp_path
        transforms = nn.Identity()
        return BigEarthNetV2(
            root, split, bands, num_classes, transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: BigEarthNetV2) -> None:
        """Test loading data."""
        x = dataset[0]

        if dataset.bands in ['s2', 'all']:
            if dataset.bands == 's2':
                assert x['image'].shape == (12, 120, 120)
            else:
                assert x['image_s2'].shape == (12, 120, 120)

        if dataset.bands in ['s1', 'all']:
            if dataset.bands == 's1':
                assert x['image'].shape == (2, 120, 120)
            else:
                assert x['image_s1'].shape == (2, 120, 120)

        assert x['mask'].shape == (1, 120, 120)
        assert x['label'].shape == (dataset.num_classes,)

        assert x['mask'].dtype == torch.int64
        assert x['label'].dtype == torch.int64
        if 'image' in x:
            assert x['image'].dtype == torch.float32
        if 'image_s1' in x:
            assert x['image_s1'].dtype == torch.float32
        if 'image_s2' in x:
            assert x['image_s2'].dtype == torch.float32

    def test_len(self, dataset: BigEarthNetV2) -> None:
        """Test dataset length."""
        # The test data contains exactly one sample per split, so the
        # expected length is the same regardless of dataset.split.
        assert len(dataset) == 1

    def test_already_downloaded(self, dataset: BigEarthNetV2, tmp_path: Path) -> None:
        """Test that a second instantiation reuses already-extracted data."""
        BigEarthNetV2(
            root=tmp_path,
            bands=dataset.bands,
            split=dataset.split,
            num_classes=dataset.num_classes,
            download=True,
        )

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """Test error handling when data not present."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            BigEarthNetV2(tmp_path)

    def test_already_downloaded_not_extracted(
        self, dataset: BigEarthNetV2, tmp_path: Path
    ) -> None:
        """Test extraction when only the compressed archives are present."""
        shutil.copy(dataset.metadata_locs['metadata']['url'], tmp_path)
        if dataset.bands == 'all':
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata_locs['s1']['directory'])
            )
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata_locs['s2']['directory'])
            )
            shutil.copy(dataset.metadata_locs['s1']['url'], tmp_path)
            shutil.copy(dataset.metadata_locs['s2']['url'], tmp_path)
        elif dataset.bands == 's1':
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata_locs['s1']['directory'])
            )
            shutil.copy(dataset.metadata_locs['s1']['url'], tmp_path)
        else:
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata_locs['s2']['directory'])
            )
            shutil.copy(dataset.metadata_locs['s2']['url'], tmp_path)

        BigEarthNetV2(
            root=tmp_path,
            bands=dataset.bands,
            split=dataset.split,
            num_classes=dataset.num_classes,
            download=False,
        )

    def test_invalid_split(self, tmp_path: Path) -> None:
        """Test error on invalid split."""
        with pytest.raises(AssertionError, match='split must be one of'):
            BigEarthNetV2(tmp_path, split='invalid')

    def test_invalid_bands(self, tmp_path: Path) -> None:
        """Test error on invalid bands selection."""
        with pytest.raises(AssertionError):
            BigEarthNetV2(tmp_path, bands='invalid')

    def test_invalid_num_classes(self, tmp_path: Path) -> None:
        """Test error on invalid number of classes."""
        with pytest.raises(AssertionError):
            BigEarthNetV2(tmp_path, num_classes=20)

    def test_plot(self, dataset: BigEarthNetV2) -> None:
        """Test plotting functionality."""
        x = dataset[0].copy()
        dataset.plot(x, suptitle='Test')
        plt.close()

        # Test without titles
        dataset.plot(x, show_titles=False)
        plt.close()

        # Test with prediction
        x['prediction'] = x['label'].clone()
        dataset.plot(x)
        plt.close()
3 changes: 2 additions & 1 deletion torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .airphen import Airphen
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .bigearthnet import BigEarthNet, BigEarthNetV2
from .biomassters import BioMassters
from .cabuar import CaBuAr
from .caffe import CaFFe
Expand Down Expand Up @@ -181,6 +181,7 @@
'AsterGDEM',
'BeninSmallHolderCashews',
'BigEarthNet',
'BigEarthNetV2',
'BioMassters',
'BoundingBox',
'CMSGlobalMangroveCanopy',
Expand Down
Loading
Loading