@@ -103,15 +103,42 @@ class SSL4EOL(SSL4EO):
103
103
"""
104
104
105
105
class _Metadata (TypedDict ):
106
- num_bands : int
106
+ all_bands : list [ str ]
107
107
rgb_bands : list [int ]
108
108
109
109
metadata : ClassVar [dict [str , _Metadata ]] = {
110
- 'tm_toa' : {'num_bands' : 7 , 'rgb_bands' : [2 , 1 , 0 ]},
111
- 'etm_toa' : {'num_bands' : 9 , 'rgb_bands' : [2 , 1 , 0 ]},
112
- 'etm_sr' : {'num_bands' : 6 , 'rgb_bands' : [2 , 1 , 0 ]},
113
- 'oli_tirs_toa' : {'num_bands' : 11 , 'rgb_bands' : [3 , 2 , 1 ]},
114
- 'oli_sr' : {'num_bands' : 7 , 'rgb_bands' : [3 , 2 , 1 ]},
110
+ 'tm_toa' : {
111
+ 'all_bands' : ['B1' , 'B2' , 'B3' , 'B4' , 'B5' , 'B6' , 'B7' ],
112
+ 'rgb_bands' : [2 , 1 , 0 ],
113
+ },
114
+ 'etm_toa' : {
115
+ 'all_bands' : ['B1' , 'B2' , 'B3' , 'B4' , 'B5' , 'B6' , 'B6' , 'B7' , 'B8' ],
116
+ 'rgb_bands' : [2 , 1 , 0 ],
117
+ },
118
+ 'etm_sr' : {
119
+ 'all_bands' : ['B1' , 'B2' , 'B3' , 'B4' , 'B5' , 'B7' ],
120
+ 'rgb_bands' : [2 , 1 , 0 ],
121
+ },
122
+ 'oli_tirs_toa' : {
123
+ 'all_bands' : [
124
+ 'B1' ,
125
+ 'B2' ,
126
+ 'B3' ,
127
+ 'B4' ,
128
+ 'B5' ,
129
+ 'B6' ,
130
+ 'B7' ,
131
+ 'B8' ,
132
+ 'B9' ,
133
+ 'B10' ,
134
+ 'B11' ,
135
+ ],
136
+ 'rgb_bands' : [3 , 2 , 1 ],
137
+ },
138
+ 'oli_sr' : {
139
+ 'all_bands' : ['B1' , 'B2' , 'B3' , 'B4' , 'B5' , 'B6' , 'B7' ],
140
+ 'rgb_bands' : [3 , 2 , 1 ],
141
+ },
115
142
}
116
143
117
144
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
@@ -212,8 +239,8 @@ def __init__(
212
239
base = Landsat8
213
240
214
241
self .wavelengths = []
215
- for band in range ( 1 , self .metadata [split ]['num_bands' ] + 1 ) :
216
- self .wavelengths .append (base .wavelengths [f'B { band } ' ])
242
+ for band in self .metadata [split ]['all_bands' ] :
243
+ self .wavelengths .append (base .wavelengths [band ])
217
244
218
245
self .scenes = sorted (os .listdir (self .subdir ))
219
246
@@ -236,7 +263,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
236
263
ts = []
237
264
wavelengths = []
238
265
for subdir in subdirs :
239
- mint , maxt = disambiguate_timestamp (subdir [: - 8 ], Landsat .date_format )
266
+ mint , maxt = disambiguate_timestamp (subdir [- 8 : ], Landsat .date_format )
240
267
directory = os .path .join (root , subdir )
241
268
filename = os .path .join (directory , 'all_bands.tif' )
242
269
with rasterio .open (filename ) as f :
@@ -338,7 +365,7 @@ def plot(
338
365
fig , axes = plt .subplots (
339
366
ncols = self .seasons , squeeze = False , figsize = (4 * self .seasons , 4 )
340
367
)
341
- num_bands = self .metadata [self .split ]['num_bands' ]
368
+ num_bands = len ( self .metadata [self .split ]['all_bands' ])
342
369
rgb_bands = self .metadata [self .split ]['rgb_bands' ]
343
370
344
371
for i in range (self .seasons ):
0 commit comments