@@ -162,6 +162,64 @@ def setup(self, stage: str) -> None:
162
162
self .test_dataset , self .patch_size , self .patch_size
163
163
)
164
164
165
def preview_data(self, max_samples: int = 100, map=None):
    """
    Plot the configured dataset splits and their samplers on a map.

    Intended as a sanity check before training: every split that has both
    a dataset and a sampler is handed to
    ``eotorch.plot.plot_samplers_on_map`` under the label "Train",
    "Validation" or "Test".

    If the train sampler has not been created yet, ``setup(stage="fit")``
    is called first. If a test dataset exists but has no sampler,
    ``setup(stage="test")`` is called as well.

    Args:
        max_samples: Upper bound on the number of samples drawn per
            dataset; forwarded unchanged to the plotting helper.
        map: Existing map to draw onto, forwarded unchanged to the
            plotting helper. With None the helper is expected to create
            a new map.

    Returns:
        Whatever ``plot_samplers_on_map`` returns (documented by the
        original author as a ``folium.Map``).

    Raises:
        ValueError: If no split has both a dataset and a sampler.
    """
    # Imported lazily so the plotting dependency is only needed when a
    # preview is actually requested.
    from eotorch.plot import plot_samplers_on_map

    # Make sure the samplers exist before trying to draw them.
    if self.train_sampler is None:
        print("Setting up samplers for visualization...")
        self.setup(stage="fit")
    if self.test_dataset is not None and self.test_sampler is None:
        self.setup(stage="test")

    # Gather every split that is fully configured, in a fixed order.
    datasets, samplers, names = [], [], []
    for label, prefix in (("Train", "train"), ("Validation", "val"), ("Test", "test")):
        split_dataset = getattr(self, f"{prefix}_dataset")
        if split_dataset is None:
            continue
        split_sampler = getattr(self, f"{prefix}_sampler")
        if split_sampler is None:
            continue
        datasets.append(split_dataset)
        samplers.append(split_sampler)
        names.append(label)

    if not datasets:
        raise ValueError(
            "No datasets with samplers available. "
            "Make sure to initialize the data module properly."
        )

    # Delegate the actual drawing to the plotting helper.
    return plot_samplers_on_map(
        datasets=datasets,
        samplers=samplers,
        map=map,
        max_samples=max_samples,
        dataset_names=names,
    )
222
+
165
223
def _dataloader_factory (self , split : str ) -> DataLoader [dict [str , Tensor ]]:
166
224
"""
167
225
Same as GeoDataModule._dataloader_factory but allows for customization of dataloader behavior.
0 commit comments