Skip to content

Commit

Permalink
separate reading item for subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh committed Oct 4, 2024
1 parent f84af05 commit f5b3fa2
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions torchgeo/datasets/mmearth.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
subset: str = 'MMEarth',
modalities: Sequence[str] = all_modalities,
modality_bands: dict[str, list[str]] | None = None,
split: str = 'train',
normalization_mode: str = 'z-score',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
) -> None:
Expand Down Expand Up @@ -213,9 +212,6 @@ def __init__(
assert (
subset in self.subsets
), f'Invalid dataset version: {subset}, please choose from {self.subsets}'
assert (
split in self.splits
), f'Invalid split: {split}, please choose from {self.splits}'

self._validate_modalities(modalities)
if modality_bands is None:
Expand All @@ -230,7 +226,7 @@ def __init__(
self.root = root
self.subset = subset
self.normalization_mode = normalization_mode
self.split = split
self.split = 'train'
self.transforms = transforms

self.dataset_filename = f'{self.filenames[subset]}.h5'
Expand Down Expand Up @@ -360,11 +356,29 @@ def __getitem__(self, index: int) -> dict[str, Any]:
dictionary containing the modalities and metadata
of the sample
"""
h5py = lazy_import('h5py')

sample: dict[str, Any] = {}
ds_index = self.indices[index]

# Sample retrieval is delegated to a separate method so that subclasses
# can implement different index sampling strategies.
sample = self._retrieve_sample(ds_index)

if self.transforms is not None:
sample = self.transforms(sample)

return sample

def _retrieve_sample(self, ds_index: int) -> dict[str, Any]:
"""Retrieve a sample from the dataset.

Args:
ds_index: index inside the hdf5 dataset file

Returns:
dictionary containing the modalities and metadata
of the sample
"""
h5py = lazy_import('h5py')
sample: dict[str, Any] = {}
with h5py.File(
os.path.join(self.root, self.filenames[self.subset], self.dataset_filename),
'r',
Expand All @@ -381,17 +395,6 @@ def __getitem__(self, index: int) -> dict[str, Any]:
tensor = self._preprocess_modality(data, modality, l2a)
modality_name = self.modality_category_name.get(modality, '') + modality
sample[modality_name] = tensor
# # separate asc and desc
# # get indices for asc and desc in self.modality_bands[modality]
# def _select_sentinel1_bands(asc_or_desc: str, tensor) -> bool:
# indices = [
# self.all_modality_bands[modality].index(band)
# for band in self.modality_bands[modality]
# if asc_or_desc in band
# ]
# return tensor[indices, ...]
# sample[self.modality_category_name.get('sentinel1') + 'sentinel1_asc'] = _select_sentinel1_bands('asc', tensor)
# sample[self.modality_category_name.get('sentinel1') + 'sentinel1_desc'] = _select_sentinel1_bands('desc', tensor)

# add additional metadata to the sample
sample['lat'] = tile_info['lat']
Expand All @@ -400,9 +403,6 @@ def __getitem__(self, index: int) -> dict[str, Any]:
sample['crs'] = tile_info['CRS']
sample['tile_id'] = name

if self.transforms is not None:
sample = self.transforms(sample)

return sample

def _select_indices_for_modality(self, modality: str) -> list[int]:
Expand Down

0 comments on commit f5b3fa2

Please sign in to comment.