diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 6c5c57ff176..d01a91dfe70 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -221,6 +221,11 @@ BioMassters .. autoclass:: BioMassters +BRIGHT +^^^^^^ + +.. autoclass:: BRIGHTDFC2025 + CaBuAr ^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index d1d0ff03a9c..1defcb032bd 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI" `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI" +`BRIGHT`_,CD,"MAXAR, NAIP, Capella, Umbra","CC-BY-4.0 AND CC-BY-NC-4.0",3239,4,"0.1-1","RGB,SAR" `CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI `CaFFe`_,S,"Sentinel-1, TerraSAR-X, TanDEM-X, ENVISAT, ERS-1/2, ALOS PALSAR, and RADARSAT-1","CC-BY-4.0","19092","2 or 4","512x512",6-20,"SAR" `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI diff --git a/tests/data/bright/data.py b/tests/data/bright/data.py new file mode 100644 index 00000000000..61a03423509 --- /dev/null +++ b/tests/data/bright/data.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +import rasterio + +ROOT = '.' +DATA_DIR = 'dfc25_track2_trainval' + +TRAIN_FILE = 'train_setlevel.txt' +HOLDOUT_FILE = 'holdout_setlevel.txt' +VAL_FILE = 'val_setlevel.txt' + +TRAIN_IDS = [ + 'bata-explosion_00000049', + 'bata-explosion_00000014', + 'bata-explosion_00000047', +] +HOLDOUT_IDS = ['turkey-earthquake_00000413'] +VAL_IDS = ['val-disaster_00000001', 'val-disaster_00000002'] + +SIZE = 32 + + +def make_dirs() -> None: + paths = [ + os.path.join(ROOT, DATA_DIR), + os.path.join(ROOT, DATA_DIR, 'train', 'pre-event'), + os.path.join(ROOT, DATA_DIR, 'train', 'post-event'), + os.path.join(ROOT, DATA_DIR, 'train', 'target'), + os.path.join(ROOT, DATA_DIR, 'val', 'pre-event'), + os.path.join(ROOT, DATA_DIR, 'val', 'post-event'), + os.path.join(ROOT, DATA_DIR, 'val', 'target'), + ] + for p in paths: + os.makedirs(p, exist_ok=True) + + +def write_list_file(filename: str, ids: list[str]) -> None: + file_path = os.path.join(ROOT, DATA_DIR, filename) + with open(file_path, 'w') as f: + for sid in ids: + f.write(f'{sid}\n') + + +def write_tif(filepath: str, channels: int) -> None: + data = np.random.randint(0, 255, (channels, SIZE, SIZE), dtype=np.uint8) + # transform = from_origin(0, 0, 1, 1) + crs = 'epsg:4326' + with rasterio.open( + filepath, + 'w', + driver='GTiff', + height=SIZE, + width=SIZE, + count=channels, + crs=crs, + dtype=data.dtype, + compress='lzw', + # transform=transform, + ) as dst: + dst.write(data) + + +def populate_data(ids: list[str], dir_name: str, with_target: bool = True) -> None: + for sid in ids: + pre_path = os.path.join( + ROOT, DATA_DIR, dir_name, 'pre-event', f'{sid}_pre_disaster.tif' + ) + write_tif(pre_path, channels=3) + post_path = os.path.join( + ROOT, DATA_DIR, dir_name, 'post-event', f'{sid}_post_disaster.tif' + ) + write_tif(post_path, channels=1) + if with_target: + target_path = os.path.join( + ROOT, DATA_DIR, dir_name, 'target', f'{sid}_building_damage.tif' + ) + write_tif(target_path, channels=1) + + +def main() -> None: + make_dirs() + + # Write the ID lists to text files + write_list_file(TRAIN_FILE, TRAIN_IDS) + write_list_file(HOLDOUT_FILE, HOLDOUT_IDS) + write_list_file(VAL_FILE, VAL_IDS) + + # Generate TIF files for the train (with target) and val (no target) splits + populate_data(TRAIN_IDS, 'train', with_target=True) + populate_data(HOLDOUT_IDS, 'train', with_target=True) + populate_data(VAL_IDS, 'val', with_target=False) + + # zip and compute md5 + zip_filename = os.path.join(ROOT, 'dfc25_track2_trainval') + shutil.make_archive(zip_filename, 'zip', ROOT, DATA_DIR) + + def md5(fname: str) -> str: + hash_md5 = hashlib.md5() + with open(fname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + md5sum = md5(zip_filename + '.zip') + print(f'MD5 checksum: {md5sum}') + + +if __name__ == '__main__': + main() diff --git a/tests/data/bright/dfc25_track2_trainval.zip b/tests/data/bright/dfc25_track2_trainval.zip new file mode 100644 index 00000000000..e42936c035a Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval.zip differ diff --git a/tests/data/bright/dfc25_track2_trainval/holdout_setlevel.txt b/tests/data/bright/dfc25_track2_trainval/holdout_setlevel.txt new file mode 100644 index 00000000000..b549d244076 --- /dev/null +++ b/tests/data/bright/dfc25_track2_trainval/holdout_setlevel.txt @@ -0,0 +1 @@ +turkey-earthquake_00000413 diff --git a/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000014_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000014_post_disaster.tif new file mode 100644 index 00000000000..5130a787952 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000014_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000047_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000047_post_disaster.tif new file mode 100644 index 00000000000..28474f92bcf Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000047_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000049_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000049_post_disaster.tif new file mode 100644 index 00000000000..ce9f9398b26 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/post-event/bata-explosion_00000049_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/post-event/turkey-earthquake_00000413_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/post-event/turkey-earthquake_00000413_post_disaster.tif new file mode 100644 index 00000000000..b035ee7c66d Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/post-event/turkey-earthquake_00000413_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000014_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000014_pre_disaster.tif new file mode 100644 index 00000000000..3c54722bbf8 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000014_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000047_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000047_pre_disaster.tif new file mode 100644 index 00000000000..2eba4547d09 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000047_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000049_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000049_pre_disaster.tif new file mode 100644 index 00000000000..5e672133df9 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/pre-event/bata-explosion_00000049_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train/pre-event/turkey-earthquake_00000413_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/train/pre-event/turkey-earthquake_00000413_pre_disaster.tif new file mode 100644 index 00000000000..328c73db826 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/train/pre-event/turkey-earthquake_00000413_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/train_setlevel.txt b/tests/data/bright/dfc25_track2_trainval/train_setlevel.txt new file mode 100644 index 00000000000..6da4e905850 --- /dev/null +++ b/tests/data/bright/dfc25_track2_trainval/train_setlevel.txt @@ -0,0 +1,3 @@ +bata-explosion_00000049 +bata-explosion_00000014 +bata-explosion_00000047 diff --git a/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000001_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000001_post_disaster.tif new file mode 100644 index 00000000000..dc8dfc4506f Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000001_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000002_post_disaster.tif b/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000002_post_disaster.tif new file mode 100644 index 00000000000..d0233611678 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/val/post-event/val-disaster_00000002_post_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000001_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000001_pre_disaster.tif new file mode 100644 index 00000000000..4dea79477dc Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000001_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000002_pre_disaster.tif b/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000002_pre_disaster.tif new file mode 100644 index 00000000000..aca2ac66351 Binary files /dev/null and b/tests/data/bright/dfc25_track2_trainval/val/pre-event/val-disaster_00000002_pre_disaster.tif differ diff --git a/tests/data/bright/dfc25_track2_trainval/val_setlevel.txt b/tests/data/bright/dfc25_track2_trainval/val_setlevel.txt new file mode 100644 index 00000000000..dc0a4439bff --- /dev/null +++ b/tests/data/bright/dfc25_track2_trainval/val_setlevel.txt @@ -0,0 +1,2 @@ +val-disaster_00000001 +val-disaster_00000002 diff --git a/tests/datasets/test_bright.py b/tests/datasets/test_bright.py new file mode 100644 index 00000000000..b29f534722a --- /dev/null +++ b/tests/datasets/test_bright.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import BRIGHTDFC2025, DatasetNotFoundError + + +class TestBRIGHTDFC2025: + @pytest.fixture(params=['train', 'val', 'test']) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> BRIGHTDFC2025: + md5 = '7b0e24d45fb2d9a4f766196702586414' + monkeypatch.setattr(BRIGHTDFC2025, 'md5', md5) + url = os.path.join('tests', 'data', 'bright', 'dfc25_track2_trainval.zip') + monkeypatch.setattr(BRIGHTDFC2025, 'url', url) + root = tmp_path + split = request.param + transforms = nn.Identity() + return BRIGHTDFC2025(root, split, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: BRIGHTDFC2025) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image_pre'], torch.Tensor) + assert x['image_pre'].shape[0] == 3 + assert isinstance(x['image_post'], torch.Tensor) + assert x['image_post'].shape[0] == 3 + assert x['image_pre'].shape[-2:] == x['image_post'].shape[-2:] + if dataset.split != 'test': + assert isinstance(x['mask'], torch.Tensor) + assert x['image_pre'].shape[-2:] == x['mask'].shape[-2:] + + def test_len(self, dataset: BRIGHTDFC2025) -> None: + if dataset.split == 'train': + assert len(dataset) == 3 + elif dataset.split == 'val': + assert len(dataset) == 1 + else: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: BRIGHTDFC2025) -> None: + BRIGHTDFC2025(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + filename = 'dfc25_track2_trainval.zip' + dir = os.path.join('tests', 'data', 'bright') + shutil.copyfile( + os.path.join(dir, filename), os.path.join(str(tmp_path), filename) + ) + BRIGHTDFC2025(root=str(tmp_path)) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + BRIGHTDFC2025(split='foo') + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + BRIGHTDFC2025(tmp_path) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, 'dfc25_track2_trainval.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): + BRIGHTDFC2025(root=tmp_path, checksum=True) + + def test_plot(self, dataset: BRIGHTDFC2025) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + if dataset.split != 'test': + sample = dataset[0] + sample['prediction'] = torch.clone(sample['mask']) + dataset.plot(sample, suptitle='Prediction') + plt.close() + + del sample['mask'] + dataset.plot(sample, suptitle='Only Prediction') + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 8f238abd916..8177120c2a7 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -11,6 +11,7 @@ from .benin_cashews import BeninSmallHolderCashews from .bigearthnet import BigEarthNet from .biomassters import BioMassters +from .bright import BRIGHTDFC2025 from .cabuar import CaBuAr from .caffe import CaFFe from .cbf import CanadianBuildingFootprints @@ -152,6 +153,7 @@ __all__ = ( 'ADVANCE', + 'BRIGHTDFC2025', 'CDL', 'COWC', 'DFC2022', diff --git a/torchgeo/datasets/bright.py b/torchgeo/datasets/bright.py new file mode 100644 index 00000000000..bf98711a4ed --- /dev/null +++ b/torchgeo/datasets/bright.py @@ -0,0 +1,340 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BRIGHT dataset.""" + +import os +import textwrap +from collections.abc import Callable +from typing import ClassVar + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +import rasterio +import torch +from einops import repeat +from matplotlib import colors +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, check_integrity, download_url, extract_archive + + +class BRIGHTDFC2025(NonGeoDataset): + """BRIGHT DFC2025 dataset. + + The `BRIGHT `__ dataset consists of bi-temporal + high-resolution multimodal images for + building damage assessment. The dataset is part of the 2025 IEEE GRSS Data Fusion Contest. + The pre-disaster images are optical images and the post-disaster images are SAR images, and + targets were manually annotated. The dataset is split into train, val, and test splits, but + the test split does not contain targets in this version. + + More information can be found at the `Challenge website `__. + + Dataset Features: + + * Pre-disaster optical images from MAXAR, NAIP, NOAA Digital Coast Raster Datasets, and the National Plan for Aerial Orthophotography Spain + * Post-disaster SAR images from Capella Space and Umbra + * high image resolution of 0.3-1m + + Dataset Format: + + * Images are in GeoTIFF format with pixel dimensions of 1024x1024 + * Pre-disaster are three channel images + * Post-disaster SAR images are single channel but repeated to have 3 channels + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2501.06019 + + .. versionadded:: 0.7 + """ + + classes = ('background', 'intact', 'damaged', 'destroyed') + + colormap = ( + 'white', # background + 'green', # intact + 'burlywood', # damaged + 'red', # destroyed + ) + + md5 = '2c435bb50345d425390eff59a92134ac' + + url = 'https://huggingface.co/datasets/torchgeo/bright/resolve/d19972f5e682ad684dcde35529a6afad4c719f1b/dfc25_track2_trainval_with_split.zip' + + data_dir = 'dfc25_track2_trainval' + + valid_split = ('train', 'val', 'test') + + # train_setlevels.txt are the training samples + # holdout_setlevels.txt are the validation samples + # val_setlevels.txt are the test samples + split_files: ClassVar[dict[str, str]] = { + 'train': 'train_setlevel.txt', + 'val': 'holdout_setlevel.txt', + 'test': 'val_setlevel.txt', + } + + px_class_values: ClassVar[dict[int, str]] = { + 0: 'background', + 1: 'intact', + 2: 'damaged', + 3: 'destroyed', + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new BRIGHT DFC2025 dataset instance. + + Args: + root: root directory where dataset can be found + split: train/val/test split to load + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + AssertionError: If *split* is not one of 'train', 'val', or 'test. + """ + assert split in self.valid_split, f'Split must be one of {self.valid_split}' + self.root = root + self.split = split + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.sample_paths = self._get_paths() + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and target at that index, pre and post image + are returned under separate image keys + """ + idx_paths = self.sample_paths[index] + + image_pre = self._load_image(idx_paths['image_pre']).float() + image_post = self._load_image(idx_paths['image_post']).float() + # https://github.com/ChenHongruixuan/BRIGHT/blob/11b1ffafa4d30d2df2081189b56864b0de4e3ed7/dfc25_benchmark/dataset/make_data_loader.py#L101 + # post image is stacked to also have 3 channels + image_post = repeat(image_post, 'c h w -> (repeat c) h w', repeat=3) + + sample = {'image_pre': image_pre, 'image_post': image_post} + + if 'target' in idx_paths and self.split != 'test': + target = self._load_image(idx_paths['target']).long() + sample['mask'] = target + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _get_paths(self) -> list[dict[str, str]]: + """Get paths to the dataset files based on specified splits. + + Returns: + a list of dictionaries containing paths to the pre, post, and target images + """ + split_file = self.split_files[self.split] + + file_path = os.path.join(self.root, self.data_dir, split_file) + with open(file_path) as f: + sample_ids = f.readlines() + + if self.split in ('train', 'val'): + dir_split_name = 'train' + else: + dir_split_name = 'val' + + sample_paths = [ + { + 'image_pre': os.path.join( + self.root, + self.data_dir, + dir_split_name, + 'pre-event', + f'{sample_id.strip()}_pre_disaster.tif', + ), + 'image_post': os.path.join( + self.root, + self.data_dir, + dir_split_name, + 'post-event', + f'{sample_id.strip()}_post_disaster.tif', + ), + } + for sample_id in sample_ids + ] + if self.split != 'test': + for sample, sample_id in zip(sample_paths, sample_ids): + sample['target'] = os.path.join( + self.root, + self.data_dir, + dir_split_name, + 'target', + f'{sample_id.strip()}_building_damage.tif', + ) + + return sample_paths + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # check if the text split files exist + if all( + os.path.exists(os.path.join(self.root, self.data_dir, split_file)) + for split_file in self.split_files.values() + ): + # if split txt files exist check whether sample files exist + sample_paths = self._get_paths() + exists = [] + for sample in sample_paths: + exists.append( + all(os.path.exists(path) for name, path in sample.items()) + ) + if all(exists): + return + + # check if .zip files already exists (if so, then extract) + exists = [] + zip_file_path = os.path.join(self.root, self.data_dir + '.zip') + if os.path.exists(zip_file_path): + if self.checksum and not check_integrity(zip_file_path, self.md5): + raise RuntimeError('Dataset found, but corrupted.') + exists.append(True) + extract_archive(zip_file_path, self.root) + else: + exists.append(False) + + if all(exists): + return + + if not self.download: + raise DatasetNotFoundError(self) + + # download and extract the dataset + self._download() + extract_archive(zip_file_path, self.root) + + def _download(self) -> None: + """Download the dataset.""" + download_url( + self.url, + self.root, + self.data_dir + '.zip', + md5=self.md5 if self.checksum else None, + ) + + def __len__(self) -> int: + """Return the number of samples in the dataset. + + Returns: + number of samples in the dataset + """ + return len(self.sample_paths) + + def _load_image(self, path: Path) -> Tensor: + """Load a file from disk. + + Args: + path: path to the file to load + + Returns: + image tensor + """ + with rasterio.open(path) as src: + img = src.read() + tensor: Tensor = torch.from_numpy(img).float() + return tensor + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + ncols = 2 + showing_mask = 'mask' in sample + showing_prediction = 'prediction' in sample + if showing_mask: + ncols += 1 + if showing_prediction: + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(15, 5)) + + axs[0].imshow(sample['image_pre'].permute(1, 2, 0) / 255.0) + axs[0].axis('off') + + axs[1].imshow(sample['image_post'].permute(1, 2, 0) / 255.0) + axs[1].axis('off') + + cmap = colors.ListedColormap(self.colormap) + + if showing_mask: + axs[2].imshow(sample['mask'].squeeze(0), cmap=cmap, interpolation='none') + axs[2].axis('off') + unique_classes = np.unique(sample['mask'].numpy()) + handles = [ + mpatches.Patch( + color=cmap(ordinal), + label='\n'.join( + textwrap.wrap(self.px_class_values[px_class], width=10) + ), + ) + for ordinal, px_class in enumerate(self.px_class_values.keys()) + if ordinal in unique_classes + ] + axs[2].legend(handles=handles, loc='upper right', bbox_to_anchor=(1.4, 1)) + if showing_prediction: + axs[3].imshow( + sample['prediction'].squeeze(0), cmap=cmap, interpolation='none' + ) + axs[3].axis('off') + elif showing_prediction: + axs[2].imshow( + sample['prediction'].squeeze(0), cmap=cmap, interpolation='none' + ) + axs[2].axis('off') + + if show_titles: + axs[0].set_title('Pre-disaster image') + axs[1].set_title('Post-disaster image') + if showing_mask: + axs[2].set_title('Ground truth') + if showing_prediction: + axs[-1].set_title('Prediction') + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig