Skip to content

Commit 6dc4ba5

Browse files
committed
initial implementation of raster datasets
1 parent 0ea44b6 commit 6dc4ba5

File tree

9 files changed

+311
-142
lines changed

9 files changed

+311
-142
lines changed

README.md

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,32 @@
1-
# eotorch: Template Python repository
2-
3-
This repository serves as a template for creating a Python library
4-
5-
## How do I use this?
6-
7-
1. Create a new repository in GitHub with this repo as a starting point
8-
![](images/new_repo.png)
9-
10-
2. Change all occurrences of `eotorch` to match the name of your new library
11-
12-
3. Consider if the [license](LICENSE) should be modified.
13-
14-
15-
## Additional resources
16-
17-
If you're interested in learning more about best practices for developing Python packages, check out the following resources:
18-
19-
- [Python Package Development at DHI](https://dhi.github.io/python-package-development/)
20-
- [Scientific Python Library Development Guide](https://learn.scientific-python.org/development/)
1+
## Example of how to use SegmentationRasterDataset:
2+
3+
```python
4+
from eotorch.datasets.geo import SegmentationRasterDataset
5+
from eotorch.plot.plot import plot_samples
6+
7+
class_mapping = {
8+
1: "Baresoil",
9+
2: "Buildings",
10+
3: "Coniferous Trees",
11+
4: "Deciduous Trees",
12+
5: "Grass",
13+
6: "Impervious",
14+
7: "Water",
15+
}
16+
17+
bla = SegmentationRasterDataset.create(
18+
images_dir="dev_data/sr_data",
19+
labels_dir="dev_data/labels",
20+
image_kwargs=dict(
21+
all_bands=("B02", "B03", "B04", "B08", "B11", "B12"),
22+
rgb_bands=("B04", "B03", "B02"),
23+
),
24+
label_kwargs=dict(
25+
class_mapping=class_mapping,
26+
),
27+
)
28+
plot_samples(bla, n=2, patch_size=256)
29+
```
30+
31+
![alt text](media/sample_1.png)
32+
![alt text](media/sample_2.png)

eotorch/datasets/geo.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
import matplotlib.patches as mpatches
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
from matplotlib import cm
8+
from rasterio.plot import show
9+
from torchgeo.datasets import IntersectionDataset, RasterDataset
10+
11+
12+
class PlottableRasterDataset(RasterDataset):
    """Raster dataset that can draw its own samples with matplotlib.

    When ``is_image`` is truthy the sample's ``"image"`` tensor is drawn as an
    RGB composite selected through ``rgb_bands``; otherwise the sample's
    ``"mask"`` tensor is drawn with ``colormap`` together with a legend listing
    the values present in that mask.
    """

    # Colormap passed to rasterio's ``show`` when drawing masks.
    colormap = cm.tab20
    # Mask value that is labelled "No Data" when it has no explicit label.
    nodata_value = 0
    # Optional ``{mask value: label}`` mapping used for the legend.
    class_mapping = None

    def plot(self, sample, ax=None, **kwargs):
        """Draw one sample and return the axes it was drawn on.

        Args:
            sample: Mapping holding an ``"image"`` tensor (image datasets) or a
                ``"mask"`` tensor (label datasets).
            ax: Axes to draw on. A new figure is created when omitted.
            **kwargs: Forwarded unchanged to ``rasterio.plot.show``.

        Returns:
            The matplotlib axes used for drawing.
        """
        if ax is None:
            _, ax = plt.subplots()

        if self.is_image:
            # Translate band names into positions along the band axis.
            rgb_indices = [self.all_bands.index(band) for band in self.rgb_bands]
            image = sample["image"][rgb_indices]
            return show(image.numpy(), ax=ax, adjust=True, **kwargs)

        vals = sample["mask"].numpy()
        ax = show(vals, ax=ax, cmap=self.colormap, **kwargs)
        values = np.unique(vals.ravel()).tolist()

        # Copy the mapping so adding the "No Data" label never leaks into the
        # dictionary supplied by the caller (it may be shared between plots).
        if self.class_mapping:
            labels = dict(self.class_mapping)
        else:
            labels = {value: str(value) for value in values}
        if self.nodata_value in values:
            labels.setdefault(self.nodata_value, "No Data")

        # Ask the drawn image for the colour of each value directly instead of
        # indexing a precomputed list by mask value: that only worked for
        # values 0..N-1 and for colormaps exposing a ``colors`` attribute.
        im = ax.get_images()[0]
        patches = [
            mpatches.Patch(
                color=im.cmap(im.norm(value)),
                # Fall back to the raw value for classes without a label.
                label=labels.get(value, str(value)),
            )
            for value in values
        ]
        ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        return ax
45+
46+
47+
class LabelledRasterDataset(IntersectionDataset):
    """Intersection dataset whose member datasets are plotted side by side."""

    def plot(self, sample: dict[str, Any], **kwargs):
        """Plot every member dataset's view of ``sample`` in one figure.

        Args:
            sample: Sample mapping passed to each member dataset's ``plot``.
            **kwargs: Forwarded unchanged to each member dataset's ``plot``.

        Returns:
            The matplotlib figure holding one axes per member dataset.

        Raises:
            NotImplementedError: If a member dataset is not a
                ``PlottableRasterDataset``.
        """
        # Validate before allocating the figure so a failure does not leave an
        # empty, half-drawn figure registered with pyplot.
        if not all(isinstance(ds, PlottableRasterDataset) for ds in self.datasets):
            raise NotImplementedError("Dataset must be plottable")

        # One 4x4 inch panel per dataset; this is (8, 4) for the usual pair.
        # ``squeeze=False`` keeps ``axes`` two-dimensional for any panel count.
        count = len(self.datasets)
        fig, axes = plt.subplots(1, count, figsize=(4 * count, 4), squeeze=False)

        for ax, dataset in zip(axes[0], self.datasets):
            dataset.plot(sample, ax=ax, **kwargs)

        return fig
58+
59+
60+
class SegmentationRasterDataset:
    """
    A dataset for semantic segmentation of raster data.

    Args:
        images_dir: A directory containing the image files.
        labels_dir: A directory containing the label files.
        image_glob: A glob pattern to match the image files.
        label_glob: A glob pattern to match the label files.
        image_kwargs: Attributes to set on the image dataset.
        label_kwargs: Attributes to set on the label dataset.

    This dataset is an intersection of two raster datasets: one for the images and one for the labels.
    Both those datasets are based on TorchGeo's RasterDataset (which itself is a subclass of TorchGeo's GeoDataset).
    Therefore it comes with strong capabilities for spatial and temporal indexing; however, it is also limited
    in how it handles global datasets spanning multiple coordinate reference systems.

    """

    @staticmethod
    def _build_dataset(paths, filename_glob, is_image=None, attributes=None):
        """Create a ``PlottableRasterDataset`` whose glob is active during indexing."""
        # RasterDataset builds its file index inside ``__init__``, so the glob
        # has to be in place *before* the constructor runs. Setting it on a
        # bare instance (rather than on a dynamically created subclass) keeps
        # the object an ordinary PlottableRasterDataset.
        dataset = PlottableRasterDataset.__new__(PlottableRasterDataset)
        dataset.filename_glob = filename_glob
        if is_image is not None:
            dataset.is_image = is_image
        dataset.__init__(paths=paths)

        # Remaining attributes are applied after construction, as before.
        # NOTE(review): attributes the constructor reads (for example a
        # filename regex) therefore still do not influence indexing - confirm
        # whether they should be applied up front as well.
        for name, value in (attributes or {}).items():
            setattr(dataset, name, value)
        return dataset

    @classmethod
    def create(
        cls,
        images_dir: str | Path,
        labels_dir: str | Path | None = None,
        image_glob="*.tif",
        label_glob="*.tif",
        image_kwargs: dict[str, Any] | None = None,
        label_kwargs: dict[str, Any] | None = None,
    ):
        """Build the image dataset and, when labels are given, pair it with them.

        Returns:
            A ``LabelledRasterDataset`` combining images and labels when
            ``labels_dir`` is provided, otherwise the image
            ``PlottableRasterDataset`` alone.
        """
        image_ds = cls._build_dataset(images_dir, image_glob, attributes=image_kwargs)

        if not labels_dir:
            return image_ds

        label_ds = cls._build_dataset(
            labels_dir, label_glob, is_image=False, attributes=label_kwargs
        )
        return LabelledRasterDataset(image_ds, label_ds)

eotorch/inference/inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from rasterio.io import BufferedDatasetWriter
99

1010
from eotorch.inference import inference_utils as iu
11+
from eotorch.plot import plot
1112

1213

1314
def predict_on_tif(
@@ -101,5 +102,5 @@ def predict_on_tif(
101102

102103
if show_results:
103104
print(f"Showing results for {out_file_path}")
104-
return iu.plot_predictions_pyplot(out_file_path, classes, ax=ax)
105+
return plot.plot_predictions_pyplot(out_file_path, classes, ax=ax)
105106
return out_file_path

eotorch/inference/inference_utils.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,9 @@
44
import numpy as np
55
import rasterio as rst
66
from affine import Affine
7-
from matplotlib import pyplot as plt
87
from rasterio.windows import Window, from_bounds
98

109

11-
def plot_predictions_pyplot(
    predictions_path: str | Path, classes: dict[int, str], ax: plt.Axes | None = None
):
    """
    Plot predictions using matplotlib.pyplot.

    Parameters
    ----------
    predictions_path: str or Path, path to the predictions file
    classes: dict, mapping of class indices to class names; when empty or
        None, the raw values are used as labels. The mapping is not modified.
    ax: matplotlib.axes.Axes, axes to plot on; a new figure is created when
        omitted
    """

    import matplotlib.patches as mpatches

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=1)

    with rst.open(predictions_path) as src:
        predictions = src.read(1)
        no_data_val = src.nodata

    im = ax.imshow(predictions, interpolation="none")
    values = np.unique(predictions.ravel()).tolist()

    # Work on a copy so the caller's mapping is never modified.
    labels = dict(classes) if classes else {v: str(v) for v in values}
    if no_data_val in values:
        labels.setdefault(no_data_val, "No Data")

    # Look each value's colour up directly. The previous positional colour
    # list was indexed by class value, which picked the wrong colour (or
    # raised IndexError) whenever the values were not exactly 0..N-1, and it
    # also pushed a possibly-None nodata value through the normaliser.
    patches = [
        mpatches.Patch(
            color=im.cmap(im.norm(value)),
            # Fall back to the raw value for classes without a label.
            label=labels.get(value, str(value)),
        )
        for value in values
    ]
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
53-
54-
5510
def crop_np_to_window(arr: np.ndarray, w_buffered: Window, w_unbuffered: Window):
5611
left = int(round(w_unbuffered.col_off - w_buffered.col_off))
5712
top = int(round(w_unbuffered.row_off - w_buffered.row_off))

eotorch/plot/plot.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import rasterio as rst
5+
from matplotlib import pyplot as plt
6+
from torch.utils.data import DataLoader
7+
from torchgeo.datasets import GeoDataset, stack_samples, unbind_samples
8+
from torchgeo.samplers import RandomGeoSampler
9+
10+
11+
def plot_predictions_pyplot(
    predictions_path: str | Path, classes: dict[int, str], ax: plt.Axes | None = None
):
    """
    Plot predictions using matplotlib.pyplot.

    Parameters
    ----------
    predictions_path: str or Path, path to the predictions file
    classes: dict, mapping of class indices to class names; when empty or
        None, the raw values are used as labels. The mapping is not modified.
    ax: matplotlib.axes.Axes, axes to plot on; a new figure is created when
        omitted
    """

    import matplotlib.patches as mpatches

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=1)

    with rst.open(predictions_path) as src:
        predictions = src.read(1)
        no_data_val = src.nodata

    im = ax.imshow(predictions, interpolation="none")
    values = np.unique(predictions.ravel()).tolist()

    # Work on a copy so the caller's mapping is never modified.
    labels = dict(classes) if classes else {v: str(v) for v in values}
    if no_data_val in values:
        labels.setdefault(no_data_val, "No Data")

    # Look each value's colour up directly. The previous positional colour
    # list was indexed by class value, which picked the wrong colour (or
    # raised IndexError) whenever the values were not exactly 0..N-1, and it
    # also pushed a possibly-None nodata value through the normaliser.
    patches = [
        mpatches.Patch(
            color=im.cmap(im.norm(value)),
            # Fall back to the raw value for classes without a label.
            label=labels.get(value, str(value)),
        )
        for value in values
    ]
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
53+
54+
55+
def plot_samples(dataset: GeoDataset, n: int = 3, patch_size: int = 256):
    """Plot one batch of ``n`` patches drawn from ``dataset``.

    A ``RandomGeoSampler`` with patch size ``patch_size`` feeds a
    ``DataLoader``; the first batch is split back into individual samples and
    each one is handed to ``dataset.plot`` with a numbered title.
    """
    assert hasattr(dataset, "plot"), "Dataset must have a plot method"

    loader = DataLoader(
        dataset,
        sampler=RandomGeoSampler(dataset, size=patch_size),
        collate_fn=stack_samples,
        batch_size=n,
    )
    first_batch = next(iter(loader))

    for number, sample in enumerate(unbind_samples(first_batch), start=1):
        dataset.plot(sample, title=f"Sample {number}")
        plt.show()

media/sample_1.png

324 KB
Loading

media/sample_2.png

260 KB
Loading

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ exclude = ["notebooks", "tests", "images"]
1616
name = "eotorch"
1717
version = "0.0.1"
1818
dependencies = [
19+
"alive-progress>=3.2.0",
1920
"torchgeo>=0.6.2",
2021
]
2122

0 commit comments

Comments
 (0)