Skip to content

Commit 464e45d

Browse files
Trainers: add Instance Segmentation Task (microsoft#2513)
* Add files via upload
* Add files via upload
* Update instancesegmentation.py
* Update and rename instancesegmentation.py to instance_segmentation.py
* Update test_instancesegmentation.py
* Update instance_segmentation.py
* Update __init__.py
* Update instance_segmentation.py
* Update instance_segmentation.py
* Add files via upload
* Update test_instancesegmentation.py
* Update and rename test_instancesegmentation.py to test_trainer_instancesegmentation.py
* Update instance_segmentation.py
* Add files via upload
* Creato con Colab
* Creato con Colab
* Creato con Colab
* Update instance_segmentation.py
* Delete test_trainer.ipynb
* Delete test_trainer_instancesegmentation.py
* Update and rename test_instancesegmentation.py to test_instance_segmentation.py
* Update instance_segmentation.py
* Update test_instance_segmentation.py
* Update instance_segmentation.py
* Update instance_segmentation.py
* Update instance_segmentation.py run ruff
* Ruff
* dos2unix
* Add support for MSI, weights
* Update tests
* timm and torchvision are not compatible
* Finalize trainer code, simpler
* Update VHR10 tests
* Uniformity
* Fix most tests
* 100% coverage
* Fix datasets tests
* Fix weight tests
* Fix MSI support
* Fix parameter replacement
* Fix minimum tests
* Fix minimum tests
* Add all unpacked data
* Fix tests
* Undo FTW changes
* Undo FTW changes
* Undo FTW changes
* Remove dead code
* Remove dead code, match detection style
* Try newer torchmetrics
* Try newer torchmetrics
* Try newer torchmetrics
* More metrics
* Fix mypy
* Fix and test weights=True, num_classes!=91

---------

Co-authored-by: Adam J. Stewart <[email protected]>
1 parent 5eb2a5e commit 464e45d

File tree

12 files changed

+422
-35
lines changed

12 files changed

+422
-35
lines changed

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ dependencies = [
6767
"rasterio>=1.3,!=1.4.0,!=1.4.1,!=1.4.2",
6868
# rtree 1+ required for Python 3.10 wheels
6969
"rtree>=1",
70-
# segmentation-models-pytorch 0.2+ required for smp.losses module
71-
"segmentation-models-pytorch>=0.2",
70+
# segmentation-models-pytorch 0.3.3+ required for timm 0.8+ support
71+
"segmentation-models-pytorch>=0.3.3",
7272
# shapely 1.8+ required for Python 3.10 wheels
7373
"shapely>=1.8",
74-
# timm 0.4.12 required by segmentation-models-pytorch
75-
"timm>=0.4.12",
74+
# timm 0.8+ required for timm.models.adapt_input_conv, 0.9.2 required by SMP
75+
"timm>=0.9.2",
7676
# torch 1.13+ required by torchvision
7777
"torch>=1.13",
78-
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
79-
"torchmetrics>=0.10",
78+
# torchmetrics 1.2+ required for average argument in mAP metric
79+
"torchmetrics>=1.2",
8080
# torchvision 0.14+ required for torchvision.models.swin_v2_b
8181
"torchvision>=0.14",
8282
# typing-extensions 4.5+ required for typing_extensions.deprecated

requirements/min-reqs.old

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ pillow==8.4.0
1414
pyproj==3.3.0
1515
rasterio==1.3.0.post1
1616
rtree==1.0.0
17-
segmentation-models-pytorch==0.2.0
17+
segmentation-models-pytorch==0.3.3
1818
shapely==1.8.0
19-
timm==0.4.12
19+
timm==0.9.2
2020
torch==1.13.0
21-
torchmetrics==0.10.0
21+
torchmetrics==1.2.0
2222
torchvision==0.14.0
2323
typing-extensions==4.5.0
2424

tests/conf/vhr10_ins_seg.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
model:
2+
class_path: InstanceSegmentationTask
3+
init_args:
4+
model: 'mask-rcnn'
5+
backbone: 'resnet50'
6+
num_classes: 11
7+
data:
8+
class_path: VHR10DataModule
9+
init_args:
10+
batch_size: 1
11+
num_workers: 0
12+
patch_size: 4
13+
dict_kwargs:
14+
root: 'tests/data/vhr10'

tests/conf/vhr10.yaml renamed to tests/conf/vhr10_obj_det.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@ model:
22
class_path: ObjectDetectionTask
33
init_args:
44
model: 'faster-rcnn'
5-
backbone: 'resnet50'
5+
backbone: 'resnet18'
66
num_classes: 11
7-
lr: 2.5e-5
8-
patience: 10
97
data:
108
class_path: VHR10DataModule
119
init_args:

tests/datasets/test_vhr10.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def test_plot(self, dataset: VHR10) -> None:
8282
scores = [0.7, 0.3, 0.7]
8383
for i in range(3):
8484
x = dataset[i]
85-
x['prediction_labels'] = x['label']
85+
x['prediction_label'] = x['label']
8686
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
87-
x['prediction_scores'] = torch.Tensor([scores[i]])
87+
x['prediction_score'] = torch.Tensor([scores[i]])
8888
if 'mask' in x:
89-
x['prediction_masks'] = x['mask']
89+
x['prediction_mask'] = x['mask']
9090
dataset.plot(x, show_feats='masks')
9191
plt.close()

tests/trainers/test_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None:
6767

6868

6969
class TestObjectDetectionTask:
70-
@pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10'])
70+
@pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10_obj_det'])
7171
@pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
7272
def test_trainer(
7373
self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
tests/trainers/test_instance_segmentation.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
from typing import Any
6+
7+
import pytest
8+
from lightning.pytorch import Trainer
9+
from pytest import MonkeyPatch
10+
11+
from torchgeo.datamodules import MisconfigurationException, VHR10DataModule
12+
from torchgeo.datasets import VHR10, RGBBandsMissingError
13+
from torchgeo.main import main
14+
from torchgeo.trainers import InstanceSegmentationTask
15+
16+
# mAP metric requires pycocotools to be installed
17+
pytest.importorskip('pycocotools')
18+
19+
20+
class PredictInstanceSegmentationDataModule(VHR10DataModule):
21+
def setup(self, stage: str) -> None:
22+
self.predict_dataset = VHR10(**self.kwargs)
23+
24+
25+
def plot(*args: Any, **kwargs: Any) -> None:
26+
return None
27+
28+
29+
def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
30+
raise RGBBandsMissingError()
31+
32+
33+
class TestInstanceSegmentationTask:
34+
@pytest.mark.parametrize('name', ['vhr10_ins_seg'])
35+
def test_trainer(
36+
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
37+
) -> None:
38+
config = os.path.join('tests', 'conf', name + '.yaml')
39+
40+
args = [
41+
'--config',
42+
config,
43+
'--trainer.accelerator',
44+
'cpu',
45+
'--trainer.fast_dev_run',
46+
str(fast_dev_run),
47+
'--trainer.max_epochs',
48+
'1',
49+
'--trainer.log_every_n_steps',
50+
'1',
51+
]
52+
53+
main(['fit', *args])
54+
try:
55+
main(['test', *args])
56+
except MisconfigurationException:
57+
pass
58+
try:
59+
main(['predict', *args])
60+
except MisconfigurationException:
61+
pass
62+
63+
def test_invalid_model(self) -> None:
64+
match = 'Invalid model type'
65+
with pytest.raises(ValueError, match=match):
66+
InstanceSegmentationTask(model='invalid_model')
67+
68+
def test_invalid_backbone(self) -> None:
69+
match = 'Invalid backbone type'
70+
with pytest.raises(ValueError, match=match):
71+
InstanceSegmentationTask(backbone='invalid_backbone')
72+
73+
def test_weights(self) -> None:
74+
InstanceSegmentationTask(weights=True, num_classes=3)
75+
InstanceSegmentationTask(weights=True, num_classes=91)
76+
77+
def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
78+
monkeypatch.setattr(VHR10DataModule, 'plot', plot)
79+
datamodule = VHR10DataModule(
80+
root='tests/data/vhr10', batch_size=1, num_workers=0
81+
)
82+
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
83+
trainer = Trainer(
84+
accelerator='cpu',
85+
fast_dev_run=fast_dev_run,
86+
log_every_n_steps=1,
87+
max_epochs=1,
88+
)
89+
trainer.validate(model=model, datamodule=datamodule)
90+
91+
def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
92+
monkeypatch.setattr(VHR10DataModule, 'plot', plot_missing_bands)
93+
datamodule = VHR10DataModule(
94+
root='tests/data/vhr10', batch_size=1, num_workers=0
95+
)
96+
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
97+
trainer = Trainer(
98+
accelerator='cpu',
99+
fast_dev_run=fast_dev_run,
100+
log_every_n_steps=1,
101+
max_epochs=1,
102+
)
103+
trainer.validate(model=model, datamodule=datamodule)
104+
105+
def test_predict(self, fast_dev_run: bool) -> None:
106+
datamodule = PredictInstanceSegmentationDataModule(
107+
root='tests/data/vhr10', batch_size=1, num_workers=0
108+
)
109+
model = InstanceSegmentationTask(num_classes=11)
110+
trainer = Trainer(
111+
accelerator='cpu',
112+
fast_dev_run=fast_dev_run,
113+
log_every_n_steps=1,
114+
max_epochs=1,
115+
)
116+
trainer.predict(model=model, datamodule=datamodule)
117+
118+
def test_freeze_backbone(self) -> None:
119+
task = InstanceSegmentationTask(freeze_backbone=True)
120+
for param in task.model.backbone.parameters():
121+
assert param.requires_grad is False
122+
123+
for head in ['rpn', 'roi_heads']:
124+
for param in getattr(task.model, head).parameters():
125+
assert param.requires_grad is True

torchgeo/datasets/vhr10.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
250250
sample = self.coco_convert(sample)
251251
sample['class'] = sample['label']['labels']
252252
sample['bbox_xyxy'] = sample['label']['boxes']
253-
sample['mask'] = sample['label']['masks'].float()
253+
sample['mask'] = sample['label']['masks']
254254
sample['label'] = sample.pop('class')
255255

256256
if self.transforms is not None:
@@ -408,21 +408,21 @@ def plot(
408408
n_gt = len(boxes)
409409

410410
ncols = 1
411-
show_predictions = 'prediction_labels' in sample
411+
show_predictions = 'prediction_label' in sample
412412

413413
if show_predictions:
414414
show_pred_boxes = False
415415
show_pred_masks = False
416-
prediction_labels = sample['prediction_labels'].numpy()
417-
prediction_scores = sample['prediction_scores'].numpy()
416+
prediction_label = sample['prediction_label'].numpy()
417+
prediction_score = sample['prediction_score'].numpy()
418418
if 'prediction_bbox_xyxy' in sample:
419419
prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy()
420420
show_pred_boxes = True
421-
if 'prediction_masks' in sample:
422-
prediction_masks = sample['prediction_masks'].numpy()
421+
if 'prediction_mask' in sample:
422+
prediction_mask = sample['prediction_mask'].numpy()
423423
show_pred_masks = True
424424

425-
n_pred = len(prediction_labels)
425+
n_pred = len(prediction_label)
426426
ncols += 1
427427

428428
# Display image
@@ -475,11 +475,11 @@ def plot(
475475
axs[0, 1].imshow(image)
476476
axs[0, 1].axis('off')
477477
for i in range(n_pred):
478-
score = prediction_scores[i]
478+
score = prediction_score[i]
479479
if score < 0.5:
480480
continue
481481

482-
class_num = prediction_labels[i]
482+
class_num = prediction_label[i]
483483
color = cm(class_num / len(self.categories))
484484

485485
if show_pred_boxes:
@@ -511,7 +511,7 @@ def plot(
511511

512512
# Add masks
513513
if show_pred_masks:
514-
mask = prediction_masks[i]
514+
mask = prediction_mask[i]
515515
contours = skimage.measure.find_contours(mask, 0.5)
516516
for verts in contours:
517517
verts = np.fliplr(verts)

torchgeo/models/fcsiam.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,7 @@ def __init__(
7575
)
7676
encoder_out_channels = [c * 2 for c in self.encoder.out_channels[1:]]
7777
encoder_out_channels.insert(0, self.encoder.out_channels[0])
78-
try:
79-
# smp 0.3+
80-
UnetDecoder = smp.decoders.unet.decoder.UnetDecoder
81-
except AttributeError:
82-
# smp 0.2
83-
UnetDecoder = smp.unet.decoder.UnetDecoder
84-
self.decoder = UnetDecoder(
78+
self.decoder = smp.decoders.unet.decoder.UnetDecoder(
8579
encoder_channels=encoder_out_channels,
8680
decoder_channels=decoder_channels,
8781
n_blocks=encoder_depth,

torchgeo/trainers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .byol import BYOLTask
88
from .classification import ClassificationTask, MultiLabelClassificationTask
99
from .detection import ObjectDetectionTask
10+
from .instance_segmentation import InstanceSegmentationTask
1011
from .iobench import IOBenchTask
1112
from .moco import MoCoTask
1213
from .regression import PixelwiseRegressionTask, RegressionTask
@@ -18,6 +19,7 @@
1819
'BaseTask',
1920
'ClassificationTask',
2021
'IOBenchTask',
22+
'InstanceSegmentationTask',
2123
'MoCoTask',
2224
'MultiLabelClassificationTask',
2325
'ObjectDetectionTask',

0 commit comments

Comments
 (0)