diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index c60b08f6666..c2eb9bab961 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -202,6 +202,11 @@ ADVANCE
 
 .. autoclass:: ADVANCE
 
+AI4ArcticSeaIce
+^^^^^^^^^^^^^^^
+
+.. autoclass:: AI4ArcticSeaIce
+
 Benin Cashew Plantations
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index f91f6b0e967..0f452ee265d 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -1,5 +1,6 @@
 Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `ADVANCE`_,C,"Google Earth, Freesound","CC-BY-4.0","5,075",13,512x512,0.5,RGB
+`AI4ArcticSeaIce`_,S,"Sentinel-1","CC-BY-4.0","520",6--11,"~5000x5000",80,"HH,HV"
 `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
 `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
 `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
diff --git a/tests/data/ai4arctic_sea_ice/data.py b/tests/data/ai4arctic_sea_ice/data.py
new file mode 100644
index 00000000000..826e73e397c
--- /dev/null
+++ b/tests/data/ai4arctic_sea_ice/data.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+import tarfile
+from datetime import datetime, timedelta
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def create_dummy_nc_file(filepath: str, is_reference: bool = False) -> None:
+    """Create a dummy netCDF file matching the original dataset structure."""
+    # Define dimensions
+    dims = {
+        'sar_lines': 12,
+        'sar_samples': 9,
+        'sar_sample_2dgrid_points': 3,
+        'sar_line_2dgrid_points': 4,
+        '2km_grid_lines': 5,
+        '2km_grid_samples': 6,
+    }
+
+    # Create variables with realistic dummy data
+    data_vars = {
+        # SAR variables (full resolution)
+        'nersc_sar_primary': (
+            ('sar_lines', 'sar_samples'),
+            np.random.normal(-20, 5, (dims['sar_lines'], dims['sar_samples'])).astype(
+                np.float32
+            ),
+        ),
+        'nersc_sar_secondary': (
+            ('sar_lines', 'sar_samples'),
+            np.random.normal(-25, 5, (dims['sar_lines'], dims['sar_samples'])).astype(
+                np.float32
+            ),
+        ),
+        # Grid coordinates
+        'sar_grid2d_latitude': (
+            ('sar_sample_2dgrid_points', 'sar_line_2dgrid_points'),
+            np.random.uniform(
+                60,
+                80,
+                (dims['sar_sample_2dgrid_points'], dims['sar_line_2dgrid_points']),
+            ).astype(np.float64),
+        ),
+        'sar_grid2d_longitude': (
+            ('sar_sample_2dgrid_points', 'sar_line_2dgrid_points'),
+            np.random.uniform(
+                -60,
+                0,
+                (dims['sar_sample_2dgrid_points'], dims['sar_line_2dgrid_points']),
+            ).astype(np.float64),
+        ),
+        # Weather variables (2km grid)
+        'u10m_rotated': (
+            ('2km_grid_lines', '2km_grid_samples'),
+            np.random.normal(
+                0, 5, (dims['2km_grid_lines'], dims['2km_grid_samples'])
+            ).astype(np.float32),
+        ),
+        'v10m_rotated': (
+            ('2km_grid_lines', '2km_grid_samples'),
+            np.random.normal(
+                0, 5, (dims['2km_grid_lines'], dims['2km_grid_samples'])
+            ).astype(np.float32),
+        ),
+        # AMSR2 brightness temperatures (only the 6.9 and 7.3 GHz channels,
+        # h/v polarizations, are generated for the dummy data)
+        **{
+            f'btemp_{freq}{pol}': (
+                ('2km_grid_lines', '2km_grid_samples'),
+                np.random.normal(
+                    250, 20, (dims['2km_grid_lines'], dims['2km_grid_samples'])
+                ).astype(np.float32),
+            )
+            for freq in ['6_9', '7_3']
+            for pol in ['h', 'v']
+        },
+        # Add distance map
+        'distance_map': (
+            ('sar_lines', 'sar_samples'),
+            np.random.uniform(0, 10, (dims['sar_lines'], dims['sar_samples'])).astype(
+                np.float32
+            ),
+            {
+                'long_name': 'Distance to land zones numbered with ids ranging from 0 to N',
+                'zonal_range_description': '\ndist_id; dist_range_km\n0; land\n1; 0 -> 0.5\n2; 0.5 -> 1\n3; 1 -> 2\n4; 2 -> 4\n5; 4 -> 8\n6; 8 -> 16\n7; 16 -> 32\n8; 32 -> 64\n9; 64 -> 128\n10; >128',
+            },
+        ),
+    }
+
+    # Add target variables if this is a reference file
+    if is_reference:
+        data_vars.update(
+            {
+                'SOD': (
+                    ('sar_lines', 'sar_samples'),
+                    np.random.randint(
+                        0, 6, (dims['sar_lines'], dims['sar_samples'])
+                    ).astype(np.uint8),
+                ),
+                'SIC': (
+                    ('sar_lines', 'sar_samples'),
+                    np.random.randint(
+                        0, 11, (dims['sar_lines'], dims['sar_samples'])
+                    ).astype(np.uint8),
+                ),
+                'FLOE': (
+                    ('sar_lines', 'sar_samples'),
+                    np.random.randint(
+                        0, 7, (dims['sar_lines'], dims['sar_samples'])
+                    ).astype(np.uint8),
+                ),
+            }
+        )
+
+    # Create dataset with the attributes of the original scenes
+    ds = xr.Dataset(
+        data_vars=data_vars,
+        attrs={
+            'scene_id': os.path.basename(filepath),
+            'original_id': f'S1A_EW_GRDM_1SDH_{os.path.basename(filepath)}',
+            'ice_service': 'dmi' if 'dmi' in filepath else 'cis',
+            'flip': 0,
+            'pixel_spacing': 80,
+        },
+    )
+
+    # Save to netCDF file
+    os.makedirs(os.path.dirname(filepath), exist_ok=True)
+    ds.to_netcdf(filepath)
+
+
+def create_metadata_csv(
+    root_dir: str, n_train: int = 3, n_test: int = 2
+) -> pd.DataFrame:
+    """Create the metadata CSV file."""
+    records = []
+
+    # Generate dates
+    base_date = datetime(2021, 1, 1)
+    dates = [base_date + timedelta(days=i) for i in range(n_train + n_test)]
+
+    # Create train records
+    for i in range(n_train):
+        date_str = dates[i].strftime('%Y%m%dT%H%M%S')
+        service = 'dmi' if i % 2 == 0 else 'cis'
+        path = f'train/{date_str}_{service}_prep.nc'
+        records.append(
+            {
+                'input_path': path,
+                'reference_path': None,
+                'date': dates[i],
+                'ice_service': service,
+                'split': 'train',
+                'region_id': 'SGRDIFOXE' if service == 'cis' else 'North_RIC',
+            }
+        )
+
+    # Create test records
+    for i in range(n_test):
+        date_str = dates[n_train + i].strftime('%Y%m%dT%H%M%S')
+        service = 'dmi' if i % 2 == 0 else 'cis'
+        input_path = f'test/{date_str}_{service}_prep.nc'
+        ref_path = f'test/{date_str}_{service}_prep_reference.nc'
+        records.append(
+            {
+                'input_path': input_path,
+                'reference_path': ref_path,
+                'date': dates[n_train + i],
+                'ice_service': service,
+                'split': 'test',
+                'region_id': 'SGRDIFOXE' if service == 'cis' else 'North_RIC',
+            }
+        )
+
+    # Create DataFrame and save
+    df = pd.DataFrame(records)
+    df.to_csv(os.path.join(root_dir, 'metadata.csv'), index=False)
+    return df
+
+
+def main() -> None:
+    """Create the complete dummy dataset."""
+    root_dir = '.'
+    n_train = 3
+    n_test = 2
+
+    # Create metadata
+    df = create_metadata_csv(root_dir, n_train, n_test)
+
+    # Create train files (train scenes store the labels in the same file)
+    train_files = df[df['split'] == 'train']['input_path']
+    for f in train_files:
+        create_dummy_nc_file(os.path.join(root_dir, f), is_reference=True)
+
+    # Create test files (labels live in a separate *_reference.nc file)
+    test_files = df[df['split'] == 'test']
+    for _, row in test_files.iterrows():
+        create_dummy_nc_file(
+            os.path.join(root_dir, row['input_path']), is_reference=False
+        )
+        create_dummy_nc_file(
+            os.path.join(root_dir, row['reference_path']), is_reference=True
+        )
+
+    # Create and split the train tarball
+    shutil.make_archive('train', 'gztar', '.', 'train')
+
+    with open('train.tar.gz', 'rb') as f:
+        content = f.read()
+
+    # Split into two chunks
+    chunk1 = content[: len(content) // 2]
+    chunk2 = content[len(content) // 2 :]
+
+    with open('train.tar.gzaa', 'wb') as g:
+        g.write(chunk1)
+    with open('train.tar.gzab', 'wb') as g:
+        g.write(chunk2)
+
+    # Remove the original tarball
+    os.remove('train.tar.gz')
+
+    with tarfile.open('test.tar.gz', 'w:gz') as tar:
+        tar.add('test')
+
+    # Compute MD5 checksums
+    def md5(fname: str) -> str:
+        hash_md5 = hashlib.md5()
+        with open(fname, 'rb') as f:
+            for chunk in iter(lambda: f.read(4096), b''):
+                hash_md5.update(chunk)
+        return hash_md5.hexdigest()
+
+    print(f'MD5 checksum train.tar.gzaa: {md5("train.tar.gzaa")}')
+    print(f'MD5 checksum train.tar.gzab: {md5("train.tar.gzab")}')
+    print(f'MD5 checksum test.tar.gz: {md5("test.tar.gz")}')
+    print(f'MD5 checksum metadata.csv: {md5("metadata.csv")}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/data/ai4arctic_sea_ice/metadata.csv b/tests/data/ai4arctic_sea_ice/metadata.csv
new file mode 100644
index 00000000000..61b73f0dbb9
--- /dev/null
+++ b/tests/data/ai4arctic_sea_ice/metadata.csv
@@ -0,0 +1,6 @@
+input_path,reference_path,date,ice_service,split,region_id
+train/20210101T000000_dmi_prep.nc,,2021-01-01,dmi,train,North_RIC
+train/20210102T000000_cis_prep.nc,,2021-01-02,cis,train,SGRDIFOXE
+train/20210103T000000_dmi_prep.nc,,2021-01-03,dmi,train,North_RIC
+test/20210104T000000_dmi_prep.nc,test/20210104T000000_dmi_prep_reference.nc,2021-01-04,dmi,test,North_RIC
+test/20210105T000000_cis_prep.nc,test/20210105T000000_cis_prep_reference.nc,2021-01-05,cis,test,SGRDIFOXE
diff --git a/tests/data/ai4arctic_sea_ice/test.tar.gz b/tests/data/ai4arctic_sea_ice/test.tar.gz
new file mode 100644
index 00000000000..970ed7e1f72
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/test.tar.gz differ
diff --git a/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep.nc b/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep.nc
new file mode 100644
index 00000000000..ee956c2e447
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep_reference.nc b/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep_reference.nc
new file mode 100644
index 00000000000..35b206103ef
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/test/20210104T000000_dmi_prep_reference.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep.nc b/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep.nc
new file mode 100644
index 00000000000..ef79d0ac121
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep_reference.nc b/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep_reference.nc
new file mode 100644
index 00000000000..4d2aaec6e49
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/test/20210105T000000_cis_prep_reference.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/train.tar.gzaa b/tests/data/ai4arctic_sea_ice/train.tar.gzaa
new file mode 100644
index 00000000000..01d86337102
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/train.tar.gzaa differ
diff --git a/tests/data/ai4arctic_sea_ice/train.tar.gzab b/tests/data/ai4arctic_sea_ice/train.tar.gzab
new file mode 100644
index 00000000000..4ff54ff4dd8
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/train.tar.gzab differ
diff --git a/tests/data/ai4arctic_sea_ice/train/20210101T000000_dmi_prep.nc b/tests/data/ai4arctic_sea_ice/train/20210101T000000_dmi_prep.nc
new file mode 100644
index 00000000000..3453ec06aa0
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/train/20210101T000000_dmi_prep.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/train/20210102T000000_cis_prep.nc b/tests/data/ai4arctic_sea_ice/train/20210102T000000_cis_prep.nc
new file mode 100644
index 00000000000..1d3c8a072fd
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/train/20210102T000000_cis_prep.nc differ
diff --git a/tests/data/ai4arctic_sea_ice/train/20210103T000000_dmi_prep.nc b/tests/data/ai4arctic_sea_ice/train/20210103T000000_dmi_prep.nc
new file mode 100644
index 00000000000..85158f84acb
Binary files /dev/null and b/tests/data/ai4arctic_sea_ice/train/20210103T000000_dmi_prep.nc differ
diff --git a/tests/datasets/test_ai4arctic_sea_ice.py b/tests/datasets/test_ai4arctic_sea_ice.py
new file mode 100644
index 00000000000..8e321c1be9c
--- /dev/null
+++ b/tests/datasets/test_ai4arctic_sea_ice.py
@@ -0,0 +1,93 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+
+from torchgeo.datasets import AI4ArcticSeaIce, DatasetNotFoundError
+
+pytest.importorskip('xarray', minversion='2023.9')
+pytest.importorskip('netCDF4', minversion='1.5.4')
+
+valid_amsr2_vars = ('btemp_6_9h', 'btemp_6_9v', 'btemp_7_3h', 'btemp_7_3v')
+valid_weather_vars = ('u10m_rotated', 'v10m_rotated')
+
+
+class TestAI4ArcticSeaIce:
+    @pytest.fixture(
+        params=zip(
+            ['train', 'train', 'test', 'test'],
+            ['SOD', 'SIC', 'FLOE', 'SIC'],
+            [None, 'distance_map', None, 'distance_map'],
+            [valid_amsr2_vars, None, valid_amsr2_vars, None],
+            [valid_weather_vars, None, valid_weather_vars, None],
+        )
+    )
+    def dataset(
+        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+    ) -> AI4ArcticSeaIce:
+        url = os.path.join('tests', 'data', 'ai4arctic_sea_ice', '{}')
+        monkeypatch.setattr(AI4ArcticSeaIce, 'url', url)
+        files = [
+            {'name': 'train.tar.gzaa', 'md5': '399952b2603d0d508a30909357e6956a'},
+            {'name': 'train.tar.gzab', 'md5': 'a998c852a2f418394f97cb1f99716489'},
+            {'name': 'test.tar.gz', 'md5': 'b81e53b4c402a64d53854f02f66ce938'},
+            {'name': 'metadata.csv', 'md5': 'd1222877af76d3fe9620678c930d70f0'},
+        ]
+        monkeypatch.setattr(AI4ArcticSeaIce, 'files', files)
+        monkeypatch.setattr(AI4ArcticSeaIce, 'valid_amsr2_vars', valid_amsr2_vars)
+        monkeypatch.setattr(AI4ArcticSeaIce, 'valid_weather_vars', valid_weather_vars)
+        root = tmp_path
+        split, target_var, geo_var, amsr2_var, weather_var = request.param
+        transforms = nn.Identity()
+        return AI4ArcticSeaIce(
+            root,
+            split=split,
+            target_var=target_var,
+            geo_var=geo_var,
+            amsr2_vars=amsr2_var,
+            weather_vars=weather_var,
+            transforms=transforms,
+            download=True,
+            checksum=False,
+        )
+
+    def test_getitem(self, dataset: AI4ArcticSeaIce) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x['image'], torch.Tensor)
+        assert isinstance(x['mask'], torch.Tensor)
+
+    def test_len(self, dataset: AI4ArcticSeaIce) -> None:
+        if dataset.split == 'train':
+            assert len(dataset) == 3
+        else:
+            assert len(dataset) == 2
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+            AI4ArcticSeaIce(tmp_path)
+
+    def test_already_downloaded_and_extracted(self, dataset: AI4ArcticSeaIce) -> None:
+        AI4ArcticSeaIce(root=dataset.root, download=False)
+
+    def test_invalid_split(self) -> None:
+        with pytest.raises(AssertionError):
+            AI4ArcticSeaIce(split='foo')
+
+    def test_plot(self, dataset: AI4ArcticSeaIce) -> None:
+        dataset.plot(dataset[0], suptitle='Test')
+        plt.close()
+
+        sample = dataset[0]
+        sample['prediction'] = torch.clone(sample['mask'])
+        dataset.plot(sample, suptitle='Test with prediction')
+        plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 0e522c09976..ec2fc2d419c 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -6,6 +6,7 @@
 from .advance import ADVANCE
 from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
 from .agrifieldnet import AgriFieldNet
+from .ai4arctic_sea_ice import AI4ArcticSeaIce
 from .airphen import Airphen
 from .astergdem import AsterGDEM
 from .benin_cashews import BeninSmallHolderCashews
@@ -151,6 +152,7 @@ __all__ = (
     'ADVANCE',
+    'AI4ArcticSeaIce',
     'CDL',
     'COWC',
     'DFC2022',
diff --git a/torchgeo/datasets/ai4arctic_sea_ice.py b/torchgeo/datasets/ai4arctic_sea_ice.py
new file mode 100644
index 00000000000..f3f157819df
--- /dev/null
+++ b/torchgeo/datasets/ai4arctic_sea_ice.py
@@ -0,0 +1,507 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""AI4Arctic Sea Ice Dataset."""
+
+import os
+from collections.abc import Callable, Sequence
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
+from matplotlib.patches import Patch
+from torch import Tensor
+
+from .errors import DatasetNotFoundError
+from .geo import NonGeoDataset
+from .utils import Path, download_url, extract_archive, lazy_import
+
+
+class AI4ArcticSeaIce(NonGeoDataset):
+    """AI4Arctic Sea Ice Dataset.
+
+    The `AI4Arctic Sea Ice Challenge Dataset
+    <https://data.dtu.dk/articles/dataset/Ready-To-Train_AI4Arctic_Sea_Ice_Challenge_Dataset/21316608>`_
+    contains Sentinel-1 SAR imagery, passive microwave radiometer observations
+    from AMSR2, and numerical weather prediction data from the ECMWF Reanalysis
+    v5 (ERA5) dataset, all gridded to match the Sentinel-1 SAR scenes
+    geometrically. As label data, the dataset contains ice charts manually
+    produced by ice analysts at the Greenland Ice Service and the Canadian Ice
+    Service. This is the "Ready-To-Train" version of the dataset (Version 3).
+
+    Dataset features:
+
+    * dual-polarization SAR (HH, HV) imagery for each scene
+    * Sea Ice Concentration (SIC): the percentage ratio of sea ice to open water
+      for an area, discretized into 11 bins of 10% ranging from 0% to 100%
+    * Stage Of Development (SOD): the type of sea ice, a proxy for ice thickness
+      and ease of traversal, with 6 classes
+    * Floe size (FLOE): the size category of distinct ice floes, with 7 classes
+
+    Dataset format:
+
+    * each sample scene is stored in a separate .nc file
+    * scenes have varying pixel dimensions of up to ~5000 x 5000 px
+    * 80 m resolution
+
+    Geographical variables:
+
+    * distance-to-land layer (distance_map)
+
+    SAR variables:
+
+    * Sentinel-1 backscatter intensity (dB) in HH polarization (nersc_sar_primary)
+    * Sentinel-1 backscatter intensity (dB) in HV polarization (nersc_sar_secondary)
+    * Sentinel-1 incidence angle (sar_incidenceangle)
+
+    Weather variables:
+
+    * eastward wind component at 10 m (u10m_rotated)
+    * northward wind component at 10 m (v10m_rotated)
+    * ERA5 2 m air temperature (t2m)
+    * ERA5 skin temperature (skt)
+    * ERA5 total column water vapor (tcwv)
+    * ERA5 total column liquid water (tclw)
+
+    Advanced Microwave Scanning Radiometer 2 (AMSR2) variables:
+
+    * 6.9 GHz brightness temperature (btemp_6_9h, btemp_6_9v)
+    * 7.3 GHz brightness temperature (btemp_7_3h, btemp_7_3v)
+    * 10.7 GHz brightness temperature (btemp_10_7h, btemp_10_7v)
+    * 18.7 GHz brightness temperature (btemp_18_7h, btemp_18_7v)
+    * 23.8 GHz brightness temperature (btemp_23_8h, btemp_23_8v)
+    * 36.5 GHz brightness temperature (btemp_36_5h, btemp_36_5v)
+    * 89.0 GHz brightness temperature (btemp_89_0h, btemp_89_0v)
+
+    Sea Ice Concentration (SIC) classes:
+
+    * 0: 0%
+    * 1: 0-10%
+    * 2: 10-20%
+    * 3: 20-30%
+    * 4: 30-40%
+    * 5: 40-50%
+    * 6: 50-60%
+    * 7: 60-70%
+    * 8: 70-80%
+    * 9: 80-90%
+    * 10: 90-100%
+
+    Stage of Development (SOD) classes:
+
+    * 0: Open water
+    * 1: New ice
+    * 2: Young ice
+    * 3: Thin first-year ice
+    * 4: Thick first-year ice
+    * 5: Old ice (older than 1 year)
+
+    Floe size (FLOE) classes:
+
+    * 0: Open water
+    * 1: Cake ice
+    * 2: Small floe
+    * 3: Medium floe
+    * 4: Big floe
+    * 5: Vast floe
+    * 6: Bergs (variants of icebergs and glacier ice)
+
+    Scenes are provided by two ice services, the Danish Meteorological Institute
+    (DMI) and the Canadian Ice Service (CIS):
+
+    * ``*_dmi_prep.nc``: input data prepared by DMI
+    * ``*_cis_prep.nc``: input data prepared by CIS
+    * ``*_dmi_prep_reference.nc``: reference file containing SIC, SOD, and FLOE
+    * ``*_cis_prep_reference.nc``: reference file containing SIC, SOD, and FLOE
+
+    If you use this dataset in your research, please cite the following dataset:
+
+    * https://data.dtu.dk/articles/dataset/Ready-To-Train_AI4Arctic_Sea_Ice_Challenge_Dataset/21316608
+
+    .. note::
+
+       This dataset requires the following additional libraries to be installed:
+
+       * `xarray <https://pypi.org/project/xarray/>`_
+       * `netCDF4 <https://pypi.org/project/netCDF4/>`_
+
+    .. versionadded:: 0.7
+    """
+
+    url = 'https://huggingface.co/datasets/torchgeo/ai4artic-sea-ice-challenge/resolve/main/{}'
+
+    files = [
+        {'name': 'metadata.csv', 'md5': '4b610118c2d182325ec7599434b37deb'},
+        {'name': 'train.tar.gzaa', 'md5': '847ea12d0a5100f0a00af4bb110404b4'},
+        {'name': 'train.tar.gzab', 'md5': '3f4770c586487dc681d1d216c7003f2c'},
+        {'name': 'test.tar.gz', 'md5': 'bca98ec6734783aa6f005382549a0d21'},
+    ]
+
+    splits = ('train', 'test')
+
+    # https://github.com/astokholm/AI4ArcticSeaIceChallenge/blob/4d5e3bc85e681f6c56821d96f2ebfcf4ed58b495/utils.py#L68
+    SIC_GROUPS = {
+        0: 0,
+        1: 10,
+        2: 20,
+        3: 30,
+        4: 40,
+        5: 50,
+        6: 60,
+        7: 70,
+        8: 80,
+        9: 90,
+        10: 100,
+    }
+
+    SOD_GROUPS = {
+        0: 'Open water',
+        1: 'New ice',
+        2: 'Young ice',
+        3: 'Thin FYI',
+        4: 'Thick FYI',
+        5: 'Old ice',
+    }
+
+    FLOE_GROUPS = {
+        0: 'Open water',
+        1: 'Cake ice',
+        2: 'Small floe',
+        3: 'Medium floe',
+        4: 'Big floe',
+        5: 'Vast floe',
+        6: 'Bergs',
+    }
+
+    valid_sar_vars = ('nersc_sar_primary', 'nersc_sar_secondary', 'sar_incidenceangle')
+    valid_geo_vars = ('distance_map',)
+    valid_amsr2_vars = (
+        'btemp_6_9h',
+        'btemp_6_9v',
+        'btemp_7_3h',
+        'btemp_7_3v',
+        'btemp_10_7h',
+        'btemp_10_7v',
+        'btemp_18_7h',
+        'btemp_18_7v',
+        'btemp_23_8h',
+        'btemp_23_8v',
+        'btemp_36_5h',
+        'btemp_36_5v',
+        'btemp_89_0h',
+        'btemp_89_0v',
+    )
+    valid_weather_vars = ('u10m_rotated', 'v10m_rotated', 't2m', 'skt', 'tcwv', 'tclw')
+
+    valid_target_vars = ('SOD', 'SIC', 'FLOE')
+
+    def __init__(
+        self,
+        root: Path = 'data',
+        split: str = 'train',
+        target_var: str = 'SOD',
+        geo_var: str | None = None,
+        amsr2_vars: Sequence[str] | None = None,
+        weather_vars: Sequence[str] | None = None,
+        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize the AI4Arctic Sea Ice dataset.
+
+        Args:
+            root: root directory where the dataset can be found
+            split: the split of the dataset, either 'train' or 'test'
+            target_var: target variable to use as the label mask
+            geo_var: geographical variable to include in the sample, the only
+                option is 'distance_map'
+            amsr2_vars: AMSR2 channels to include in the sample
+            weather_vars: weather variables to include in the sample
+            transforms: a function/transform that takes an input sample dictionary
+                and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: If *split* is not one of 'train' or 'test', or if
+                the selected variables are not valid.
+            DatasetNotFoundError: If the dataset is not found and *download* is False.
+            DependencyNotFoundError: If xarray is not installed.
+        """
+        assert split in self.splits, (
+            f"Split '{split}' not supported, must be one of {self.splits}"
+        )
+        assert target_var in self.valid_target_vars, (
+            f'Invalid target variable selected. Must be one of {self.valid_target_vars}'
+        )
+        if geo_var is not None:
+            assert geo_var in self.valid_geo_vars, (
+                f'Invalid geographical variable selected. Must be one of {self.valid_geo_vars}'
+            )
+
+        if amsr2_vars is not None:
+            assert all(var in self.valid_amsr2_vars for var in amsr2_vars), (
+                f'Invalid AMSR2 variables selected. Must be a subset of {self.valid_amsr2_vars}'
+            )
+
+        if weather_vars is not None:
+            assert all(var in self.valid_weather_vars for var in weather_vars), (
+                f'Invalid weather variables selected. Must be a subset of {self.valid_weather_vars}'
+            )
+
+        self.target_var = target_var
+        self.geo_var = geo_var
+        self.amsr2_vars = amsr2_vars
+        self.weather_vars = weather_vars
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+
+        # raise DependencyNotFoundError early if xarray is missing
+        lazy_import('xarray')
+
+        self._verify()
+
+        # read the sample metadata for the requested split
+        self.metadata_df = pd.read_csv(os.path.join(self.root, 'metadata.csv'))
+        self.metadata_df = self.metadata_df[
+            self.metadata_df['split'] == self.split
+        ].reset_index(drop=True)
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
+        return len(self.metadata_df)
+
+    def __getitem__(self, idx: int) -> dict[str, Tensor]:
+        """Get the sample at the given index.
+
+        Args:
+            idx: index of the sample to return
+
+        Returns:
+            A dictionary containing the sample data, split into the following keys:
+
+            * 'image': SAR data, HH and HV stacked in that order
+            * 'geo': geographical data
+            * 'amsr2': AMSR2 data
+            * 'weather': weather data
+            * 'mask': the chosen target variable
+        """
+        df_row = self.metadata_df.iloc[idx]
+
+        # load data
+        sample = self._load_data(os.path.join(self.root, df_row['input_path']))
+
+        # load target
+        sample['mask'] = self._load_label(os.path.join(self.root, df_row['input_path']))
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def _load_data(self, path: str) -> dict[str, Tensor]:
+        """Load the data from the given path.
+
+        Args:
+            path: path to the data file
+
+        Returns:
+            A dictionary containing the data, split into the following keys:
+
+            * 'image': SAR data, HH and HV stacked in that order
+            * 'geo': geographical data
+            * 'amsr2': AMSR2 data
+            * 'weather': weather data
+        """
+        xr = lazy_import('xarray')
+        sample: dict[str, Tensor] = {}
+
+        input_data = xr.open_dataset(path)
+
+        # load the SAR channels; NaN values in the SAR data have value 2
+        hh = torch.from_numpy(input_data['nersc_sar_primary'].values)
+        hv = torch.from_numpy(input_data['nersc_sar_secondary'].values)
+        sample['image'] = torch.stack([hh, hv], dim=0)
+
+        if self.geo_var is not None:
+            sample['geo'] = torch.from_numpy(input_data[self.geo_var].values)
+
+        if self.amsr2_vars is not None:
+            data = np.stack([input_data[var].values for var in self.amsr2_vars])
+            sample['amsr2'] = torch.from_numpy(data)
+
+        if self.weather_vars is not None:
+            data = np.stack([input_data[var].values for var in self.weather_vars])
+            sample['weather'] = torch.from_numpy(data)
+
+        input_data.close()
+
+        return sample
+
+    def _load_label(self, path: str) -> Tensor:
+        """Load the label from the given path.
+
+        Args:
+            path: path to the label file
+
+        Returns:
+            A tensor containing the label data
+        """
+        xr = lazy_import('xarray')
+
+        # in the test split, labels are stored in a separate reference file
+        if self.split == 'test':
+            path = path.replace('.nc', '_reference.nc')
+
+        target_data = xr.open_dataset(path)
+        # NaN values in the target data have value 255
+        tensor = torch.from_numpy(target_data[self.target_var].values).long()
+        target_data.close()
+
+        return tensor
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # check if the metadata file and all scene files listed in it exist
+        exists = []
+        metadata_path = os.path.join(self.root, 'metadata.csv')
+        if os.path.exists(metadata_path):
+            df = pd.read_csv(metadata_path)
+            for _, row in df.iterrows():
+                exists.append(
+                    os.path.exists(os.path.join(self.root, row['input_path']))
+                )
+        else:
+            exists.append(False)
+
+        if all(exists):
+            return
+
+        # check for the presence of the tarball files and extract them if found
+        exists = [
+            os.path.exists(os.path.join(self.root, file['name'])) for file in self.files
+        ]
+        if all(exists):
+            self._extract_data()
+            return
+
+        if not self.download:
+            raise DatasetNotFoundError(self)
+
+        self._download_data()
+        self._extract_data()
+
+    def _download_data(self) -> None:
+        """Download the dataset."""
+        for file in self.files:
+            download_url(
+                self.url.format(file['name']),
+                self.root,
+                md5=file['md5'] if self.checksum else None,
+            )
+
+    def _extract_data(self) -> None:
+        """Extract the dataset."""
+        # concatenate the split train tarballs back together
+        chunk_size = 2**15  # same as torchvision
+        path = os.path.join(self.root, 'train.tar.gz')
+        with open(path, 'wb') as f:
+            for split in ['aa', 'ab']:
+                with open(os.path.join(self.root, f'train.tar.gz{split}'), 'rb') as g:
+                    while chunk := g.read(chunk_size):
+                        f.write(chunk)
+        extract_archive(path, self.root)
+
+        # extract the test tarball
+        extract_archive(os.path.join(self.root, 'test.tar.gz'), self.root)
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`AI4ArcticSeaIce.__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        if 'prediction' in sample:
+            ncols = 3
+        else:
+            ncols = 2
+
+        # copy the class mapping so the class-level dict is not mutated,
+        # then add an entry for the 255 NaN fill value
+        class_mapping = dict(getattr(self, f'{self.target_var}_GROUPS'))
+        class_mapping[255] = 'NaN'
+
+        num_classes = len(class_mapping)
+
+        fig, axs = plt.subplots(1, ncols, figsize=(15, 7))
+
+        # plot the SAR image (HH channel) with robust normalization
+        hh_image = sample['image'][0].numpy()
+        vmin, vmax = np.nanpercentile(hh_image, (2, 98))
+        axs[0].imshow(hh_image, cmap='gray', vmin=vmin, vmax=vmax)
+        axs[0].axis('off')
+        if show_titles:
+            axs[0].set_title('SAR HH Channel')
+
+        # create a discrete colormap; the last color is used for NaN values
+        colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
+        cmap = ListedColormap(colors)
+
+        # plot the mask; the 255 NaN fill value is clipped to the last color
+        mask = sample['mask'].numpy()
+        axs[1].imshow(mask, cmap=cmap, vmin=0, vmax=num_classes)
+        if show_titles:
+            axs[1].set_title(f'{self.target_var} Mask')
+        axs[1].axis('off')
+
+        if 'prediction' in sample:
+            prediction = sample['prediction'].numpy()
+            axs[2].imshow(prediction, cmap=cmap, vmin=0, vmax=num_classes)
+            if show_titles:
+                axs[2].set_title('Prediction Mask')
+            axs[2].axis('off')
+
+        # create a legend with the class names
+        legend_elements = [
+            Patch(facecolor=colors[i], label=list(class_mapping.values())[i])
+            for i in range(num_classes)
+        ]
+        fig.legend(
+            handles=legend_elements,
+            loc='center right',
+            bbox_to_anchor=(0.98, 0.5),
+            title=self.target_var,
+        )
+
+        if suptitle is not None:
+            fig.suptitle(suptitle)
+
+        return fig
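
For quick local verification of the new dataset, here is a minimal usage sketch (not part of the diff). It assumes the archives have already been downloaded and extracted under ``root``; the keyword values simply mirror the defaults and valid options defined in the class above::

    from torchgeo.datasets import AI4ArcticSeaIce

    ds = AI4ArcticSeaIce(
        root='data',
        split='test',
        target_var='SIC',
        amsr2_vars=('btemp_6_9h', 'btemp_6_9v'),
        download=False,
    )
    sample = ds[0]
    print(sample['image'].shape)  # (2, H, W): stacked HH, HV SAR channels
    print(sample['amsr2'].shape)  # (2, h, w): selected AMSR2 channels (2 km grid)
    print(sample['mask'].shape)   # (H, W): SIC class indices, 255 marks NaN
    fig = ds.plot(sample, suptitle='AI4Arctic sample')
    fig.savefig('sample.png')

Note that the 'amsr2' key only appears because ``amsr2_vars`` is passed; with the default ``None``, that key (like 'geo' and 'weather') is omitted from the sample entirely.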