Skip to content

Commit aff5aaa

Browse files
committed
Fix tests
1 parent 4b343f2 commit aff5aaa

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

tests/datasets/test_ssl4eo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_getitem(self, dataset: SSL4EOL) -> None:
6565
assert isinstance(x['image'], torch.Tensor)
6666
assert (
6767
x['image'].size(0)
68-
== dataset.seasons * dataset.metadata[dataset.split]['num_bands']
68+
== dataset.seasons * len(dataset.metadata[dataset.split]['all_bands'])
6969
)
7070

7171
def test_len(self, dataset: SSL4EOL) -> None:

torchgeo/datasets/landsat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ class Landsat7(Landsat):
226226

227227
filename_glob = 'LE07_*_{}.*'
228228

229-
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
229+
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7')
230230
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
231231

232232
wavelengths: ClassVar[dict[str, float]] = {

torchgeo/datasets/ssl4eo.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,42 @@ class SSL4EOL(SSL4EO):
103103
"""
104104

105105
class _Metadata(TypedDict):
106-
num_bands: int
106+
all_bands: list[str]
107107
rgb_bands: list[int]
108108

109109
metadata: ClassVar[dict[str, _Metadata]] = {
110-
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
111-
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
112-
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
113-
'oli_tirs_toa': {'num_bands': 11, 'rgb_bands': [3, 2, 1]},
114-
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
110+
'tm_toa': {
111+
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
112+
'rgb_bands': [2, 1, 0],
113+
},
114+
'etm_toa': {
115+
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B6', 'B7', 'B8'],
116+
'rgb_bands': [2, 1, 0],
117+
},
118+
'etm_sr': {
119+
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'],
120+
'rgb_bands': [2, 1, 0],
121+
},
122+
'oli_tirs_toa': {
123+
'all_bands': [
124+
'B1',
125+
'B2',
126+
'B3',
127+
'B4',
128+
'B5',
129+
'B6',
130+
'B7',
131+
'B8',
132+
'B9',
133+
'B10',
134+
'B11',
135+
],
136+
'rgb_bands': [3, 2, 1],
137+
},
138+
'oli_sr': {
139+
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
140+
'rgb_bands': [3, 2, 1],
141+
},
115142
}
116143

117144
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
@@ -212,8 +239,8 @@ def __init__(
212239
base = Landsat8
213240

214241
self.wavelengths = []
215-
for band in range(1, self.metadata[split]['num_bands'] + 1):
216-
self.wavelengths.append(base.wavelengths[f'B{band}'])
242+
for band in self.metadata[split]['all_bands']:
243+
self.wavelengths.append(base.wavelengths[band])
217244

218245
self.scenes = sorted(os.listdir(self.subdir))
219246

@@ -236,7 +263,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
236263
ts = []
237264
wavelengths = []
238265
for subdir in subdirs:
239-
mint, maxt = disambiguate_timestamp(subdir[:-8], Landsat.date_format)
266+
mint, maxt = disambiguate_timestamp(subdir[-8:], Landsat.date_format)
240267
directory = os.path.join(root, subdir)
241268
filename = os.path.join(directory, 'all_bands.tif')
242269
with rasterio.open(filename) as f:
@@ -338,7 +365,7 @@ def plot(
338365
fig, axes = plt.subplots(
339366
ncols=self.seasons, squeeze=False, figsize=(4 * self.seasons, 4)
340367
)
341-
num_bands = self.metadata[self.split]['num_bands']
368+
num_bands = len(self.metadata[self.split]['all_bands'])
342369
rgb_bands = self.metadata[self.split]['rgb_bands']
343370

344371
for i in range(self.seasons):

0 commit comments

Comments
 (0)