Skip to content

Commit 1028953

Browse files
committed
aoi split and data preview
1 parent 01f7472 commit 1028953

File tree

5 files changed

+769
-25
lines changed

5 files changed

+769
-25
lines changed

eotorch/data/datamodules.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,64 @@ def setup(self, stage: str) -> None:
162162
self.test_dataset, self.patch_size, self.patch_size
163163
)
164164

165+
def preview_data(self, max_samples: int = 100, map=None):
166+
"""
167+
Visualize the dataset splits and their samplers on an interactive map.
168+
169+
This method allows users to verify their data splits before training by
170+
displaying the datasets and how they are sampled using the configured samplers.
171+
172+
Args:
173+
max_samples: Maximum number of samples to display per dataset.
174+
Set to a reasonable value to keep the map responsive.
175+
map: Optional existing folium map to add the visualization to.
176+
If None, a new map will be created.
177+
178+
Returns:
179+
folium.Map: The interactive map with dataset boundaries and samplers visualized.
180+
"""
181+
from eotorch.plot import plot_samplers_on_map
182+
183+
# Ensure samplers are set up
184+
if self.train_sampler is None:
185+
print("Setting up samplers for visualization...")
186+
self.setup(stage="fit")
187+
if self.test_dataset is not None and self.test_sampler is None:
188+
self.setup(stage="test")
189+
190+
# Collect datasets and samplers
191+
datasets, samplers, names = [], [], []
192+
if self.train_dataset is not None and self.train_sampler is not None:
193+
# if hasattr(self, "train_dataset") and hasattr(self, "train_sampler"):
194+
datasets.append(self.train_dataset)
195+
samplers.append(self.train_sampler)
196+
names.append("Train")
197+
198+
if self.val_dataset is not None and self.val_sampler is not None:
199+
datasets.append(self.val_dataset)
200+
samplers.append(self.val_sampler)
201+
names.append("Validation")
202+
203+
if self.test_dataset is not None and self.test_sampler is not None:
204+
datasets.append(self.test_dataset)
205+
samplers.append(self.test_sampler)
206+
names.append("Test")
207+
208+
if not datasets:
209+
raise ValueError(
210+
"No datasets with samplers available. "
211+
"Make sure to initialize the data module properly."
212+
)
213+
214+
# Visualize the datasets and samplers
215+
return plot_samplers_on_map(
216+
datasets=datasets,
217+
samplers=samplers,
218+
map=map,
219+
max_samples=max_samples,
220+
dataset_names=names,
221+
)
222+
165223
def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
166224
"""
167225
Same as GeoDataModule._dataloader_factory but allows for customization of dataloader behavior.

0 commit comments

Comments
 (0)