diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8147863f095..2b9abf0e11b 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -331,6 +331,11 @@ HySpecNet-11k
 
 .. autoclass:: HySpecNet11k
 
+iSAID
+^^^^^
+
+.. autoclass:: ISAID
+
 IDTReeS
 ^^^^^^^
 
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index 1defcb032bd..4d53ced59ff 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -25,6 +25,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI
 `IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
 `Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
+`iSAID`_,"OD,I",Aerial,"CC-BY-NC-4.0","2,806",15,"varies","varies",RGB
 `LandCover.ai`_,S,Aerial,"CC-BY-NC-SA-4.0","10,674",5,512x512,0.25--0.5,RGB
 `LEVIR-CD`_,CD,Google Earth,-,637,2,"1,024x1,024",0.5,RGB
 `LEVIR-CD+`_,CD,Google Earth,-,985,2,"1,024x1,024",0.5,RGB
diff --git a/tests/data/isaid/data.py b/tests/data/isaid/data.py
new file mode 100644
index 00000000000..6563dfffd5c
--- /dev/null
+++ b/tests/data/isaid/data.py
@@ -0,0 +1,135 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import json
+import os
+import shutil
+import tarfile
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+
+
+def create_dummy_image(path: Path, size: tuple[int, int] = (64, 64)) -> None:
+    """Create dummy RGB image."""
+    img = np.random.randint(0, 255, (*size, 3), dtype=np.uint8)
+    Image.fromarray(img).save(path)
+
+
+def create_coco_annotations(split: str, num_images: int) -> dict:
+    """Create COCO format annotations."""
+    return {
+        'info': {'year': 2023, 'version': '1.0'},
+        'images': [
+            {'id': i, 'file_name': f'P{i:04d}.png', 'height': 64, 'width': 64}
+            for i in range(num_images)
+        ],
+        'annotations': [
+            {
+                'id': i,
+                'image_id': i // 2,  # 2 annotations per image
+                'category_id': i % 15,
+                'segmentation': [[10, 10, 20, 10, 20, 20, 10, 20]],
+                'area': 100,
+                'bbox': [10, 10, 10, 10],
+                'iscrowd': 0,
+            }
+            for i in range(num_images * 2)
+        ],
+        'categories': [
+            {'id': i, 'name': name}
+            for i, name in enumerate(
+                [
+                    'plane',
+                    'ship',
+                    'storage tank',
+                    'baseball diamond',
+                    'tennis court',
+                    'basketball court',
+                    'ground track field',
+                    'harbor',
+                    'bridge',
+                    'vehicle',
+                    'helicopter',
+                    'roundabout',
+                    'swimming pool',
+                    'soccer ball field',
+                    'container crane',
+                ]
+            )
+        ],
+    }
+
+
+def create_test_data(root: Path) -> None:
+    """Create iSAID test dataset."""
+    splits = {'train': 3, 'val': 2}
+
+    for split, num_samples in splits.items():
+        if os.path.exists(root / split):
+            shutil.rmtree(root / split)
+
+        # Create directories
+        for subdir in ['images', 'Annotations', 'Instance_masks', 'Semantic_masks']:
+            (root / split / subdir).mkdir(parents=True, exist_ok=True)
+
+        # Create images and masks
+        for i in range(num_samples):
+            # RGB image
+            create_dummy_image(root / split / 'images' / f'P{i:04d}.png')
+
+            # Instance mask (R+G*256+B*256^2 encoding)
+            instance_mask = np.zeros((64, 64, 3), dtype=np.uint8)
+            instance_mask[10:20, 10:20, 0] = i + 1  # R channel for unique IDs
+            Image.fromarray(instance_mask).save(
+                root / split / 'Instance_masks' / f'P{i:04d}.png'
+            )
+
+            # Semantic mask (similar encoding for class IDs)
+            semantic_mask = np.zeros((64, 64, 3), dtype=np.uint8)
+            semantic_mask[10:20, 10:20, 0] = 1  # Class ID 1
+            Image.fromarray(semantic_mask).save(
+                root / split / 'Semantic_masks' / f'P{i:04d}.png'
+            )
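+            # Note (editorial): with the R+G*256+B*256^2 encoding noted above,
+            # a mask pixel's integer ID decodes as R + G * 256 + B * 256**2;
+            # these dummy masks only populate the R channel, so IDs stay < 256.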
+
+        # Create COCO annotations
+        annotations = create_coco_annotations(split, num_samples)
+        with open(root / split / 'Annotations' / f'iSAID_{split}.json', 'w') as f:
+            json.dump(annotations, f)
+
+        # Create image tar
+        img_tar = f'dotav1_images_{split}.tar.gz'
+        with tarfile.open(root / img_tar, 'w:gz') as tar:
+            tar.add(root / split / 'images', arcname=os.path.join(split, 'images'))
+
+    # Create the annotation tars only after all splits exist: each archive
+    # bundles the annotation directories of every split, so this cannot run
+    # inside the per-split loop above
+    for split in splits:
+        ann_tar = f'isaid_annotations_{split}.tar.gz'
+        with tarfile.open(root / ann_tar, 'w:gz') as tar:
+            for other in splits:
+                for subdir in ['Annotations', 'Instance_masks', 'Semantic_masks']:
+                    tar.add(root / other / subdir, arcname=os.path.join(other, subdir))
+
+    def md5(fname: Path) -> str:
+        hash_md5 = hashlib.md5()
+        with open(fname, 'rb') as f:
+            for chunk in iter(lambda: f.read(4096), b''):
+                hash_md5.update(chunk)
+        return hash_md5.hexdigest()
+
+    # Print MD5 checksums
+    for split in splits:
+        print(
+            f'MD5 for dotav1_images_{split}.tar.gz: '
+            f'{md5(root / f"dotav1_images_{split}.tar.gz")}'
+        )
+        print(
+            f'MD5 for isaid_annotations_{split}.tar.gz: '
+            f'{md5(root / f"isaid_annotations_{split}.tar.gz")}'
+        )
+
+
+if __name__ == '__main__':
+    root = Path('.')
+    create_test_data(root)
diff --git a/tests/data/isaid/dotav1_images_train.tar.gz b/tests/data/isaid/dotav1_images_train.tar.gz
new file mode 100644
index 00000000000..0991de0e84b
Binary files /dev/null and b/tests/data/isaid/dotav1_images_train.tar.gz differ
diff --git a/tests/data/isaid/dotav1_images_val.tar.gz b/tests/data/isaid/dotav1_images_val.tar.gz
new file mode 100644
index 00000000000..4ac9f97bbaf
Binary files /dev/null and b/tests/data/isaid/dotav1_images_val.tar.gz differ
diff --git a/tests/data/isaid/isaid_annotations_train.tar.gz b/tests/data/isaid/isaid_annotations_train.tar.gz
new file mode 100644
index 00000000000..93b2ae902e6
Binary files /dev/null and b/tests/data/isaid/isaid_annotations_train.tar.gz differ
diff --git a/tests/data/isaid/isaid_annotations_val.tar.gz b/tests/data/isaid/isaid_annotations_val.tar.gz
new file mode 100644
index 00000000000..422c6042d90
Binary files /dev/null and b/tests/data/isaid/isaid_annotations_val.tar.gz differ
diff --git a/tests/data/isaid/train/Annotations/iSAID_train.json b/tests/data/isaid/train/Annotations/iSAID_train.json
new file mode 100644
index 00000000000..bc9c49d7763
--- /dev/null
+++ b/tests/data/isaid/train/Annotations/iSAID_train.json
@@ -0,0 +1 @@
+{"info": {"year": 2023, "version": "1.0"}, "images": [{"id": 0, "file_name": "P0000.png", "height": 64, "width": 64}, {"id": 1, "file_name": "P0001.png", "height": 64, "width": 64}, {"id": 2, "file_name": "P0002.png", "height": 64, "width": 64}], "annotations": [{"id": 0, "image_id": 0, "category_id": 0, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 1, "image_id": 0, "category_id": 1, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 2, "image_id": 1, "category_id": 2, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 3, "image_id": 1, "category_id": 3, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 4, "image_id": 2, "category_id": 4, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 5, "image_id": 2, "category_id": 5, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}], "categories": [{"id": 0, "name": "plane"}, {"id": 1, "name": "ship"}, {"id": 2, "name": "storage tank"}, {"id": 3, "name": "baseball diamond"}, {"id": 4, "name": "tennis court"}, {"id": 5, "name": "basketball court"}, {"id": 6, "name": "ground track field"}, {"id": 7, "name": "harbor"}, {"id": 8, "name": "bridge"}, {"id": 9, "name": "vehicle"}, {"id": 10, "name": "helicopter"}, {"id": 11, "name": "roundabout"}, {"id": 12, "name": "swimming pool"}, {"id": 13, "name": "soccer ball field"}, {"id": 14, "name": "container crane"}]}
"iscrowd": 0}, {"id": 5, "image_id": 2, "category_id": 5, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}], "categories": [{"id": 0, "name": "plane"}, {"id": 1, "name": "ship"}, {"id": 2, "name": "storage tank"}, {"id": 3, "name": "baseball diamond"}, {"id": 4, "name": "tennis court"}, {"id": 5, "name": "basketball court"}, {"id": 6, "name": "ground track field"}, {"id": 7, "name": "harbor"}, {"id": 8, "name": "bridge"}, {"id": 9, "name": "vehicle"}, {"id": 10, "name": "helicopter"}, {"id": 11, "name": "roundabout"}, {"id": 12, "name": "swimming pool"}, {"id": 13, "name": "soccer ball field"}, {"id": 14, "name": "container crane"}]} \ No newline at end of file diff --git a/tests/data/isaid/train/Instance_masks/P0000.png b/tests/data/isaid/train/Instance_masks/P0000.png new file mode 100644 index 00000000000..89de3058e86 Binary files /dev/null and b/tests/data/isaid/train/Instance_masks/P0000.png differ diff --git a/tests/data/isaid/train/Instance_masks/P0001.png b/tests/data/isaid/train/Instance_masks/P0001.png new file mode 100644 index 00000000000..a8692f7abcc Binary files /dev/null and b/tests/data/isaid/train/Instance_masks/P0001.png differ diff --git a/tests/data/isaid/train/Instance_masks/P0002.png b/tests/data/isaid/train/Instance_masks/P0002.png new file mode 100644 index 00000000000..d8ef53e6d5e Binary files /dev/null and b/tests/data/isaid/train/Instance_masks/P0002.png differ diff --git a/tests/data/isaid/train/Semantic_masks/P0000.png b/tests/data/isaid/train/Semantic_masks/P0000.png new file mode 100644 index 00000000000..89de3058e86 Binary files /dev/null and b/tests/data/isaid/train/Semantic_masks/P0000.png differ diff --git a/tests/data/isaid/train/Semantic_masks/P0001.png b/tests/data/isaid/train/Semantic_masks/P0001.png new file mode 100644 index 00000000000..89de3058e86 Binary files /dev/null and b/tests/data/isaid/train/Semantic_masks/P0001.png differ diff --git a/tests/data/isaid/train/Semantic_masks/P0002.png b/tests/data/isaid/train/Semantic_masks/P0002.png new file mode 100644 index 00000000000..89de3058e86 Binary files /dev/null and b/tests/data/isaid/train/Semantic_masks/P0002.png differ diff --git a/tests/data/isaid/train/images/P0000.png b/tests/data/isaid/train/images/P0000.png new file mode 100644 index 00000000000..f12b278a43a Binary files /dev/null and b/tests/data/isaid/train/images/P0000.png differ diff --git a/tests/data/isaid/train/images/P0001.png b/tests/data/isaid/train/images/P0001.png new file mode 100644 index 00000000000..f3a910b9613 Binary files /dev/null and b/tests/data/isaid/train/images/P0001.png differ diff --git a/tests/data/isaid/train/images/P0002.png b/tests/data/isaid/train/images/P0002.png new file mode 100644 index 00000000000..8a67e1fc367 Binary files /dev/null and b/tests/data/isaid/train/images/P0002.png differ diff --git a/tests/data/isaid/val/Annotations/iSAID_val.json b/tests/data/isaid/val/Annotations/iSAID_val.json new file mode 100644 index 00000000000..9f77d848851 --- /dev/null +++ b/tests/data/isaid/val/Annotations/iSAID_val.json @@ -0,0 +1 @@ +{"info": {"year": 2023, "version": "1.0"}, "images": [{"id": 0, "file_name": "P0000.png", "height": 64, "width": 64}, {"id": 1, "file_name": "P0001.png", "height": 64, "width": 64}], "annotations": [{"id": 0, "image_id": 0, "category_id": 0, "segmentation": [[10, 10, 20, 10, 20, 20, 10, 20]], "area": 100, "bbox": [10, 10, 10, 10], "iscrowd": 0}, {"id": 1, "image_id": 0, "category_id": 1, "segmentation": [[10, 10, 20, 10, 20, 
\ No newline at end of file
diff --git a/tests/data/isaid/val/Instance_masks/P0000.png b/tests/data/isaid/val/Instance_masks/P0000.png
new file mode 100644
index 00000000000..89de3058e86
Binary files /dev/null and b/tests/data/isaid/val/Instance_masks/P0000.png differ
diff --git a/tests/data/isaid/val/Instance_masks/P0001.png b/tests/data/isaid/val/Instance_masks/P0001.png
new file mode 100644
index 00000000000..a8692f7abcc
Binary files /dev/null and b/tests/data/isaid/val/Instance_masks/P0001.png differ
diff --git a/tests/data/isaid/val/Semantic_masks/P0000.png b/tests/data/isaid/val/Semantic_masks/P0000.png
new file mode 100644
index 00000000000..89de3058e86
Binary files /dev/null and b/tests/data/isaid/val/Semantic_masks/P0000.png differ
diff --git a/tests/data/isaid/val/Semantic_masks/P0001.png b/tests/data/isaid/val/Semantic_masks/P0001.png
new file mode 100644
index 00000000000..89de3058e86
Binary files /dev/null and b/tests/data/isaid/val/Semantic_masks/P0001.png differ
diff --git a/tests/data/isaid/val/images/P0000.png b/tests/data/isaid/val/images/P0000.png
new file mode 100644
index 00000000000..83476529ac5
Binary files /dev/null and b/tests/data/isaid/val/images/P0000.png differ
diff --git a/tests/data/isaid/val/images/P0001.png b/tests/data/isaid/val/images/P0001.png
new file mode 100644
index 00000000000..483b958c70b
Binary files /dev/null and b/tests/data/isaid/val/images/P0001.png differ
diff --git a/tests/datasets/test_isaid.py b/tests/datasets/test_isaid.py
new file mode 100644
index 00000000000..448cb3b0d38
--- /dev/null
+++ b/tests/datasets/test_isaid.py
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
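+
+# Editorial note: these tests monkeypatch the download URLs and checksums so
+# that the dummy archives generated by tests/data/isaid/data.py stand in for
+# the real dataset.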
+
+import os
+import shutil
+from pathlib import Path
+
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+
+from torchgeo.datasets import ISAID, DatasetNotFoundError
+
+pytest.importorskip('pycocotools')
+
+
+class TestISAID:
+    @pytest.fixture(params=['train', 'val'])
+    def dataset(
+        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+    ) -> ISAID:
+        url = os.path.join('tests', 'data', 'isaid', '{}')
+        monkeypatch.setattr(ISAID, 'img_url', url)
+        monkeypatch.setattr(ISAID, 'label_url', url)
+
+        img_files = {
+            'train': {
+                'filename': 'dotav1_images_train.tar.gz',
+                'md5': 'a38ad9832066e2ca6d30b8eec65f9ce8',
+            },
+            'val': {
+                'filename': 'dotav1_images_val.tar.gz',
+                'md5': '154babe8091484bd85c6340f43cea1ea',
+            },
+        }
+        monkeypatch.setattr(ISAID, 'img_files', img_files)
+
+        label_files = {
+            'train': {
+                'filename': 'isaid_annotations_train.tar.gz',
+                'md5': 'f4de0f6b38f1b11b121dc01c880aeb2a',
+            },
+            'val': {
+                'filename': 'isaid_annotations_val.tar.gz',
+                'md5': '88eccdf9744c201248266b9a784ffeab',
+            },
+        }
+        monkeypatch.setattr(ISAID, 'label_files', label_files)
+
+        root = tmp_path
+        split = request.param
+
+        transforms = nn.Identity()
+
+        return ISAID(root, split, transforms=transforms, download=True, checksum=True)
+
+    def test_getitem(self, dataset: ISAID) -> None:
+        for i in range(len(dataset)):
+            x = dataset[i]
+            assert isinstance(x, dict)
+            assert isinstance(x['image'], torch.Tensor)
+            assert isinstance(x['masks'], torch.Tensor)
+            assert isinstance(x['boxes'], torch.Tensor)
+
+    def test_len(self, dataset: ISAID) -> None:
+        if dataset.split == 'train':
+            assert len(dataset) == 3
+        else:
+            assert len(dataset) == 2
+
+    def test_already_downloaded(self, dataset: ISAID) -> None:
+        ISAID(root=dataset.root, download=True)
+
+    def test_not_yet_extracted(self, tmp_path: Path) -> None:
+        files = [
+            'dotav1_images_train.tar.gz',
+            'dotav1_images_val.tar.gz',
+            'isaid_annotations_train.tar.gz',
+            'isaid_annotations_val.tar.gz',
+        ]
+        for path in files:
+            shutil.copyfile(
+                os.path.join('tests', 'data', 'isaid', path),
+                os.path.join(str(tmp_path), path),
+            )
+
+        ISAID(root=tmp_path)
+
+    def test_invalid_split(self) -> None:
+        with pytest.raises(AssertionError):
+            ISAID(split='foo')
+
+    def test_corrupted(self, tmp_path: Path) -> None:
+        with open(os.path.join(tmp_path, 'dotav1_images_train.tar.gz'), 'w') as f:
+            f.write('bad')
+        with pytest.raises(RuntimeError, match='Archive'):
+            ISAID(root=tmp_path, checksum=True)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+            ISAID(tmp_path)
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 1d644c6fc69..84e488df384 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -68,6 +68,7 @@
 from .inaturalist import INaturalist
 from .inria import InriaAerialImageLabeling
 from .iobench import IOBench
+from .isaid import ISAID
 from .l7irish import L7Irish
 from .l8biome import L8Biome
 from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
@@ -163,6 +164,7 @@
     'FAIR1M',
     'GBIF',
     'GID15',
+    'ISAID',
     'LEVIRCD',
     'MDAS',
     'NAIP',
diff --git a/torchgeo/datasets/isaid.py b/torchgeo/datasets/isaid.py
new file mode 100644
index 00000000000..6183700e143
--- /dev/null
+++ b/torchgeo/datasets/isaid.py
@@ -0,0 +1,340 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""iSAID dataset."""
+
+import os
+from collections.abc import Callable
+from typing import Any, ClassVar
+
+import numpy as np
+import torch
+from PIL import Image
+from torch import Tensor
+
+from .errors import DatasetNotFoundError
+from .geo import NonGeoDataset
+from .utils import (
+    Path,
+    check_integrity,
+    download_and_extract_archive,
+    extract_archive,
+    lazy_import,
+)
+
+
+def convert_coco_poly_to_mask(
+    segmentations: list[list[list[float]]], height: int, width: int
+) -> Tensor:
+    """Convert COCO polygons to a mask tensor.
+
+    Args:
+        segmentations: polygon coordinates, one list of flat
+            [x1, y1, x2, y2, ...] polygons per annotation
+        height: image height
+        width: image width
+
+    Returns:
+        the binary masks stacked into a single tensor
+
+    Raises:
+        DependencyNotFoundError: If pycocotools is not installed.
+    """
+    pycocotools = lazy_import('pycocotools')
+    masks = []
+    for polygons in segmentations:
+        rles = pycocotools.mask.frPyObjects(polygons, height, width)
+        mask = pycocotools.mask.decode(rles)
+        mask = torch.as_tensor(mask, dtype=torch.uint8)
+        # merge the per-polygon masks of one annotation into a single mask
+        mask = mask.any(dim=2)
+        masks.append(mask)
+    masks_tensor = torch.stack(masks, dim=0)
+    return masks_tensor
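+
+# Illustrative sketch (editorial, not part of the API): converting one square
+# polygon into a binary mask. ``frPyObjects`` expects flat [x1, y1, x2, ...]
+# coordinate lists, one per polygon; note that the ``pycocotools.mask``
+# submodule must already be imported, which constructing the dataset ensures
+# via ``pycocotools.coco``:
+#
+#   segs = [[[10.0, 10.0, 20.0, 10.0, 20.0, 20.0, 10.0, 20.0]]]
+#   masks = convert_coco_poly_to_mask(segs, height=64, width=64)
+#   masks.shape  # torch.Size([1, 64, 64])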
+ +"""iSAID dataset.""" + +import os +from collections.abc import Callable +from typing import Any, ClassVar + +import numpy as np +import torch +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import ( + Path, + check_integrity, + download_and_extract_archive, + extract_archive, + lazy_import, +) + + +def convert_coco_poly_to_mask( + segmentations: list[int], height: int, width: int +) -> Tensor: + """Convert coco polygons to mask tensor. + + Args: + segmentations (List[int]): polygon coordinates + height (int): image height + width (int): image width + + Returns: + Tensor: Mask tensor + + Raises: + DependencyNotFoundError: If pycocotools is not installed. + """ + pycocotools = lazy_import('pycocotools') + masks = [] + for polygons in segmentations: + rles = pycocotools.mask.frPyObjects(polygons, height, width) + mask = pycocotools.mask.decode(rles) + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + masks_tensor = torch.stack(masks, dim=0) + return masks_tensor + + +class ConvertCocoAnnotations: + """Callable for converting the boxes, masks and labels into tensors. + + This is a modified version of ConvertCocoPolysToMask() from torchvision found in + https://github.com/pytorch/vision/blob/v0.14.0/references/detection/coco_utils.py + """ + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Converts MS COCO fields (boxes, masks & labels) from list of ints to tensors. + + Args: + sample: Sample + + Returns: + Processed sample + """ + image = sample['image'] + _, h, w = image.size() + target = sample['label'] + + image_id = target['image_id'] + image_id = torch.tensor([image_id]) + + anno = target['annotations'] + + anno = [obj for obj in anno if obj['iscrowd'] == 0] + + bboxes = [obj['bbox'] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + categories = [obj['category_id'] for obj in anno] + classes = torch.tensor(categories, dtype=torch.int64) + + segmentations = [obj['segmentation'] for obj in anno] + + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + + target = {'boxes': boxes, 'labels': classes, 'image_id': image_id} + if masks.nelement() > 0: + masks = masks[keep] + target['masks'] = masks + + # for conversion to coco api + area = torch.tensor([obj['area'] for obj in anno]) + iscrowd = torch.tensor([obj['iscrowd'] for obj in anno]) + target['area'] = area + target['iscrowd'] = iscrowd + return {'image': image, 'label': target} + + +class ISAID(NonGeoDataset): + """iSAID dataset. + + The `iSAID `_ dataset is a large-scale instance segmentation dataset for aerial imagery. + It builds upon the DOTA V1 dataset, but includes instance-level annotations for 15 object categories. 
+    """
+
+    img_url = 'https://huggingface.co/datasets/torchgeo/dota/resolve/main/{}'
+
+    img_files: ClassVar[dict[str, dict[str, str]]] = {
+        'train': {'filename': 'dotav1_images_train.tar.gz', 'md5': ''},
+        'val': {'filename': 'dotav1_images_val.tar.gz', 'md5': ''},
+    }
+
+    label_url = 'https://huggingface.co/datasets/torchgeo/isaid/resolve/main/{}'
+
+    label_files: ClassVar[dict[str, dict[str, str]]] = {
+        'train': {'filename': 'isaid_annotations_train.tar.gz', 'md5': ''},
+        'val': {'filename': 'isaid_annotations_val.tar.gz', 'md5': ''},
+    }
+
+    classes: ClassVar[dict[int, str]] = {
+        0: 'plane',
+        1: 'ship',
+        2: 'storage tank',
+        3: 'baseball diamond',
+        4: 'tennis court',
+        5: 'basketball court',
+        6: 'ground track field',
+        7: 'harbor',
+        8: 'bridge',
+        9: 'vehicle',
+        10: 'helicopter',
+        11: 'roundabout',
+        12: 'swimming pool',
+        13: 'soccer ball field',
+        14: 'container crane',
+    }
+
+    valid_splits = ('train', 'val')
+
+    def __init__(
+        self,
+        root: Path = 'data',
+        split: str = 'train',
+        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new iSAID dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: one of "train" or "val"
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: if ``split`` argument is invalid
+            DatasetNotFoundError: If dataset is not found and *download* is False.
+            DependencyNotFoundError: If pycocotools is not installed.
+        """
+        assert split in self.valid_splits, (
+            f"Invalid split '{split}', please use one of {self.valid_splits}"
+        )
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+
+        self._verify()
+
+        pc = lazy_import('pycocotools.coco')
+        self.coco = pc.COCO(
+            os.path.join(
+                self.root, self.split, 'Annotations', f'iSAID_{self.split}.json'
+            )
+        )
+        self.coco_convert = ConvertCocoAnnotations()
+        self.ids = sorted(self.coco.imgs.keys())
+
+    def __len__(self) -> int:
+        """Return the length of the dataset."""
+        return len(self.ids)
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        # image ids in the bundled annotations are 0-based while ``id_`` is
+        # kept 1-based, hence the ``id_ - 1`` lookups below
+        id_ = index % len(self) + 1
+
+        sample: dict[str, Any] = {
+            'image': self._load_image(id_),
+            'label': self._load_mask(id_),
+        }
+
+        sample = self.coco_convert(sample)
+        sample['labels'] = sample['label']['labels']
+        sample['boxes'] = sample['label']['boxes']
+        sample['masks'] = sample['label']['masks']
+        del sample['label']
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def _load_mask(self, id_: int) -> dict[str, Any]:
+        """Load the annotations for a single image.
+
+        Args:
+            id_: image ID for coco
+
+        Returns:
+            dict with the COCO image ID and its annotations
+        """
+        annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1))
+
+        target = dict(image_id=id_, annotations=annot)
+        return target
+
+    def _load_image(self, id_: int) -> Tensor:
+        """Load a single image.
+
+        Args:
+            id_: image ID for coco
+
+        Returns:
+            image tensor
+        """
+        filename = os.path.join(
+            self.root, self.split, 'images', self.coco.imgs[id_ - 1]['file_name']
+        )
+        image = Image.open(filename).convert('RGB')
+        return torch.from_numpy(np.array(image).transpose(2, 0, 1)).float()
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # check presence of directories
+        dirs = ['images', 'Annotations', 'Instance_masks', 'Semantic_masks']
+        exists = [
+            os.path.exists(os.path.join(self.root, self.split, dir)) for dir in dirs
+        ]
+
+        if all(exists):
+            return
+
+        # check compressed files
+        exists = []
+        files = [
+            self.img_files[self.split]['filename'],
+            self.label_files[self.split]['filename'],
+        ]
+        md5s = [self.img_files[self.split]['md5'], self.label_files[self.split]['md5']]
+        for file, md5 in zip(files, md5s):
+            if os.path.exists(os.path.join(self.root, file)):
+                if self.checksum and not check_integrity(
+                    os.path.join(self.root, file), md5
+                ):
+                    raise RuntimeError(f'Archive {file} is found but corrupted')
+                exists.append(True)
+                extract_archive(os.path.join(self.root, file), self.root)
+            else:
+                exists.append(False)
+
+        if all(exists):
+            return
+
+        if not self.download:
+            raise DatasetNotFoundError(self)
+
+        # download the dataset: images come from img_url and annotations from
+        # label_url, each verified against its own checksum
+        urls = [self.img_url, self.label_url]
+        for url, file, md5 in zip(urls, files, md5s):
+            download_and_extract_archive(
+                url.format(file), self.root, md5=md5 if self.checksum else None
+            )
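
Note (editorial sketch, not part of the diff): the dataset returns 'image',
'boxes', 'labels', and 'masks' tensors per sample but does not yet provide a
plot method like other TorchGeo datasets. A minimal visualization along these
lines could be a follow-up (matplotlib and a local copy of the data assumed):

    import matplotlib.pyplot as plt
    from torchgeo.datasets import ISAID

    ds = ISAID(root='data', split='val')
    sample = ds[0]

    fig, ax = plt.subplots()
    # _load_image returns a float CxHxW tensor with values in [0, 255]
    ax.imshow(sample['image'].permute(1, 2, 0).byte().numpy())
    # boxes are [x1, y1, x2, y2] after ConvertCocoAnnotations
    for box in sample['boxes']:
        x1, y1, x2, y2 = box.tolist()
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='red')
        )
    plt.show()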