-
Notifications
You must be signed in to change notification settings - Fork 388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trainers: add Instance Segmentation Task #2513
base: main
Are you sure you want to change the base?
Changes from 14 commits
d00c087
52daa1c
68756a7
e249883
7676ac3
0fa7b07
b4334f0
a160baa
fa8697b
f6ceed1
619760b
d9158a0
9f48f50
63aefc8
70074e7
b3de001
d70f1e3
1e68d2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import torch | ||
import lightning.pytorch as pl | ||
from torch.utils.data import DataLoader, Subset | ||
from torchgeo.datasets import VHR10 | ||
from torchvision.transforms.functional import to_pil_image | ||
from matplotlib.patches import Rectangle | ||
import matplotlib.pyplot as plt | ||
import torch.nn.functional as F | ||
from torchgeo.trainers import InstanceSegmentationTask | ||
|
||
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, object]:
    """Collate variable-sized samples into a single padded batch.

    Each image, and the masks that belong to it, is zero-padded on the right
    and bottom up to the largest height and width found in the batch so that
    the images can be stacked. Box coordinates are passed through unchanged,
    since no padding is added on the left or top.

    Args:
        batch: Samples holding an 'image' tensor indexed as (C, H, W), a
            'masks' tensor indexed as (N, H, W), and 'labels' and 'boxes'
            tensors.

    Returns:
        A dict with the stacked, padded images under 'image' and, under
        'target', one dict per sample with int64 'labels', float32 'boxes'
        and padded uint8 'masks'.
    """
    max_height = max(sample['image'].shape[1] for sample in batch)
    max_width = max(sample['image'].shape[2] for sample in batch)

    def pad_to_batch_size(tensor: torch.Tensor) -> torch.Tensor:
        # F.pad orders the padding as (left, right, top, bottom) for the
        # last two dimensions; only the right and bottom edges grow.
        pad_right = max_width - tensor.shape[2]
        pad_bottom = max_height - tensor.shape[1]
        return F.pad(tensor, (0, pad_right, 0, pad_bottom))

    images = torch.stack([pad_to_batch_size(sample['image']) for sample in batch])

    targets = [
        {
            "labels": sample["labels"].to(torch.int64),
            "boxes": sample["boxes"].to(torch.float32),
            "masks": pad_to_batch_size(sample["masks"]).to(torch.uint8),
        }
        for sample in batch
    ]

    return {"image": images, "target": targets}
|
||
def _draw_labelled_boxes(
    ax: plt.Axes,
    boxes: torch.Tensor,
    labels: torch.Tensor,
    prefix: str,
    color: str,
) -> None:
    """Draw one outlined rectangle and one text tag per (box, label) pair."""
    for box, label in zip(boxes, labels):
        # Convert to plain Python floats so matplotlib never has to handle
        # 0-d tensors (which may live on another device).
        x1, y1, x2, y2 = box.tolist()
        rect = Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor=color,
            facecolor='none',
        )
        ax.add_patch(rect)
        ax.text(x1, y1, f"{prefix}: {label.item()}", color=color, fontsize=12)


def visualize_predictions(
    image: torch.Tensor,
    predictions: dict[str, torch.Tensor],
    targets: dict[str, torch.Tensor],
) -> None:
    """Show an image with predicted (red) and ground-truth (blue) boxes.

    Args:
        image: Image tensor; it is converted with ``to_pil_image`` for display.
        predictions: Mapping with 'boxes' and 'labels' for the predictions.
            Each box is unpacked as (x1, y1, x2, y2).
        targets: Mapping with 'boxes' and 'labels' for the ground truth, in
            the same box layout.
    """
    image = to_pil_image(image)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(image)

    _draw_labelled_boxes(ax, predictions['boxes'], predictions['labels'], 'Pred', 'red')
    _draw_labelled_boxes(ax, targets['boxes'], targets['labels'], 'GT', 'blue')

    plt.show()
|
||
def plot_losses(train_losses: list[float], val_losses: list[float]) -> None:
    """Plot training and validation losses over epochs.

    Epochs are numbered from 1 on the x-axis. The two series are plotted
    independently, so they may have different lengths.

    Args:
        train_losses: One training loss value per epoch.
        val_losses: One validation loss value per epoch.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(
        range(1, len(train_losses) + 1),
        train_losses,
        label='Training Loss',
        marker='o',
    )
    plt.plot(
        range(1, len(val_losses) + 1),
        val_losses,
        label='Validation Loss',
        marker='s',
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.legend()
    plt.grid()
    plt.show()
|
||
# Build the VHR-10 datasets; only the first construction requests a download.
train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
val_dataset = VHR10(root="data", split="positive", transforms=None)

# Keep only the first N samples for quick experimentation (adjust N as needed).
# NOTE(review): both subsets take the same indices from the same split, so the
# validation samples are identical to the training samples.
N = 100
subset_indices = list(range(N))
train_subset = Subset(train_dataset, subset_indices)
val_subset = Subset(val_dataset, subset_indices)
|
||
|
||
if __name__ == '__main__':
    import multiprocessing

    multiprocessing.set_start_method('spawn', force=True)

    class LossHistory(pl.Callback):
        """Collect the epoch-level training and validation losses."""

        def __init__(self) -> None:
            super().__init__()
            self.train_losses: list[float] = []
            self.val_losses: list[float] = []

        def on_train_epoch_end(
            self, trainer: pl.Trainer, pl_module: pl.LightningModule
        ) -> None:
            """Append whichever of 'train_loss' / 'val_loss' has been logged."""
            train_loss = trainer.callback_metrics.get("train_loss")
            val_loss = trainer.callback_metrics.get("val_loss")
            if train_loss is not None:
                self.train_losses.append(train_loss.item())
            if val_loss is not None:
                self.val_losses.append(val_loss.item())

    train_loader = DataLoader(
        train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn
    )

    # Trainer setup. The callback records one loss value per epoch.
    loss_history = LossHistory()
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        callbacks=[loss_history],
    )

    task = InstanceSegmentationTask(
        model="mask_rcnn",
        backbone="resnet50",
        weights="imagenet",  # Pretrained on ImageNet
        num_classes=11,  # VHR-10 has 10 classes + 1 background
        lr=1e-3,
        freeze_backbone=False,
    )

    print('\nSTART TRAINING\n')
    # Call fit exactly once: the trainer already runs max_epochs epochs, and
    # calling fit again on a trainer that has reached max_epochs would not
    # train further, so a manual epoch loop only re-reads the last metrics.
    trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)

    plot_losses(loss_history.train_losses, loss_history.val_losses)

    # Inference and visualization on a single training sample.
    sample = train_dataset[1]
    image = sample['image'].unsqueeze(0)
    predictions = task.predict_step({"image": image}, batch_idx=0)
    visualize_predictions(image[0], predictions[0], sample)
adamjstewart marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's rename to |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import Any, cast | ||
|
||
import pytest | ||
import segmentation_models_pytorch as smp | ||
import timm | ||
import torch | ||
import torch.nn as nn | ||
from lightning.pytorch import Trainer | ||
from pytest import MonkeyPatch | ||
from torch.nn.modules import Module | ||
from torchvision.models._api import WeightsEnum | ||
|
||
from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule | ||
from torchgeo.datasets import LandCoverAI, RGBBandsMissingError | ||
from torchgeo.main import main | ||
from torchgeo.models import ResNet18_Weights | ||
from torchgeo.trainers import InstanceSegmentationTask | ||
|
||
|
||
class SegmentationTestModel(Module):
    """Tiny stand-in network for tests: a single 1x1 convolution."""

    def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
        """Create the 1x1 convolution mapping ``in_channels`` to ``classes``.

        Any additional keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, classes, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x`` and return the result."""
        output = self.conv1(x)
        return cast(torch.Tensor, output)
|
||
|
||
def create_model(**kwargs: Any) -> Module:
    """Build a SegmentationTestModel, forwarding all keyword arguments to it."""
    model: Module = SegmentationTestModel(**kwargs)
    return model
|
||
|
||
def plot(*args: Any, **kwargs: Any) -> None:
    """Accept any arguments, do nothing, and return None."""
|
||
|
||
def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
    """Accept any arguments and always raise RGBBandsMissingError."""
    error = RGBBandsMissingError()
    raise error
|
||
|
||
class TestSemanticSegmentationTask:
    """Trainer tests that construct and run ``InstanceSegmentationTask``.

    NOTE(review): the class name says "semantic" although every test below
    instantiates ``InstanceSegmentationTask``; reviewers asked for several
    parts of this class to be removed or replaced (see the notes inline).
    """

    @pytest.mark.parametrize(
        'name',
        [
            # NOTE(review): each name is joined into tests/conf/<name>.yaml
            # in the test body; a reviewer flagged this list as needing to
            # match the available configuration files -- TODO confirm.
            'agrifieldnet',
            'cabuar',
            'chabud',
            'chesapeake_cvpr_5',
            'chesapeake_cvpr_7',
            'deepglobelandcover',
            'etci2021',
            'ftw',
            'geonrw',
            'gid15',
            'inria',
            'l7irish',
            'l8biome',
            'landcoverai',
            'landcoverai100',
            'loveda',
            'naipchesapeake',
            'potsdam2d',
            'sen12ms_all',
            'sen12ms_s1',
            'sen12ms_s2_all',
            'sen12ms_s2_reduced',
            'sentinel2_cdl',
            'sentinel2_eurocrops',
            'sentinel2_nccm',
            'sentinel2_south_america_soybean',
            'southafricacroptype',
            'spacenet1',
            'spacenet6',
            'ssl4eo_l_benchmark_cdl',
            'ssl4eo_l_benchmark_nlcd',
            'vaihingen2d',
        ],
    )
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        """Run fit, test and predict through ``main`` with the config for ``name``.

        ``MisconfigurationException`` raised by the test and predict stages is
        deliberately ignored.
        """
        # NOTE(review): reviewer feedback -- these dataset-specific cases can
        # be removed if none of these datasets are used.
        match name:
            case 'chabud' | 'cabuar':
                pytest.importorskip('h5py', minversion='3.6')
            case 'ftw':
                pytest.importorskip('pyarrow')
            case 'landcoverai':
                sha256 = (
                    'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
                )
                monkeypatch.setattr(LandCoverAI, 'sha256', sha256)

        config = os.path.join('tests', 'conf', name + '.yaml')

        # NOTE(review): reviewer feedback -- these models (and smp itself) are
        # not used for this task, so these patches can be removed.
        monkeypatch.setattr(smp, 'Unet', create_model)
        monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model)

        args = [
            '--config',
            config,
            '--trainer.accelerator',
            'cpu',
            '--trainer.fast_dev_run',
            str(fast_dev_run),
            '--trainer.max_epochs',
            '1',
            '--trainer.log_every_n_steps',
            '1',
        ]

        main(['fit', *args])
        try:
            main(['test', *args])
        except MisconfigurationException:
            pass
        try:
            main(['predict', *args])
        except MisconfigurationException:
            pass

    @pytest.fixture
    def weights(self) -> WeightsEnum:
        """Return the pretrained-weights enum member used by the weight tests."""
        return ResNet18_Weights.SENTINEL2_ALL_MOCO

    @pytest.fixture
    def mocked_weights(
        self,
        tmp_path: Path,
        monkeypatch: MonkeyPatch,
        weights: WeightsEnum,
        load_state_dict_from_url: None,
    ) -> WeightsEnum:
        """Point ``weights`` at a locally saved state dict instead of its URL.

        A timm model is created from the weights' metadata, its state dict is
        saved under ``tmp_path``, and the weights' ``url`` is patched to that
        path. ``load_state_dict_from_url`` is a fixture defined outside this
        file -- presumably it patches the download helper; TODO confirm.
        """
        path = tmp_path / f'{weights}.pth'
        model = timm.create_model(
            weights.meta['model'], in_chans=weights.meta['in_chans']
        )
        torch.save(model.state_dict(), path)
        # Fall back to patching the enum member itself when patching
        # ``weights.value`` raises AttributeError.
        try:
            monkeypatch.setattr(weights.value, 'url', str(path))
        except AttributeError:
            monkeypatch.setattr(weights, 'url', str(path))
        return weights

    def test_weight_file(self, checkpoint: str) -> None:
        """Construct the task with ``weights`` given as a checkpoint path."""
        InstanceSegmentationTask(backbone='resnet18', weights=checkpoint, num_classes=6)

    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
        """Construct the task with ``weights`` given as a (mocked) enum member."""
        InstanceSegmentationTask(
            backbone=mocked_weights.meta['model'],
            weights=mocked_weights,
            in_channels=mocked_weights.meta['in_chans'],
        )

    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
        """Construct the task with ``weights`` given as the enum's string form."""
        InstanceSegmentationTask(
            backbone=mocked_weights.meta['model'],
            weights=str(mocked_weights),
            in_channels=mocked_weights.meta['in_chans'],
        )

    @pytest.mark.slow
    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
        """Same as ``test_weight_enum`` but with the unmocked weights (slow)."""
        InstanceSegmentationTask(
            backbone=weights.meta['model'],
            weights=weights,
            in_channels=weights.meta['in_chans'],
        )

    @pytest.mark.slow
    def test_weight_str_download(self, weights: WeightsEnum) -> None:
        """Same as ``test_weight_str`` but with the unmocked weights (slow)."""
        InstanceSegmentationTask(
            backbone=weights.meta['model'],
            weights=str(weights),
            in_channels=weights.meta['in_chans'],
        )

    def test_invalid_model(self) -> None:
        """Expect a ValueError with the given message for an unknown model."""
        match = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=match):
            InstanceSegmentationTask(model='invalid_model')

    def test_invalid_loss(self) -> None:
        """Expect a ValueError with the given message for an unknown loss."""
        match = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=match):
            InstanceSegmentationTask(loss='invalid_loss')

    def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
        """Run validation with the datamodule's ``plot`` replaced by a no-op."""
        monkeypatch.setattr(SEN12MSDataModule, 'plot', plot)
        # NOTE(review): reviewer feedback -- this datamodule isn't compatible
        # with instance segmentation.
        datamodule = SEN12MSDataModule(
            root='tests/data/sen12ms', batch_size=1, num_workers=0
        )
        model = InstanceSegmentationTask(
            backbone='resnet18', in_channels=15, num_classes=6
        )
        trainer = Trainer(
            accelerator='cpu',
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        )
        trainer.validate(model=model, datamodule=datamodule)

    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
        """Run validation with ``plot`` replaced by one raising RGBBandsMissingError."""
        monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands)
        datamodule = SEN12MSDataModule(
            root='tests/data/sen12ms', batch_size=1, num_workers=0
        )
        model = InstanceSegmentationTask(
            backbone='resnet18', in_channels=15, num_classes=6
        )
        trainer = Trainer(
            accelerator='cpu',
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        )
        trainer.validate(model=model, datamodule=datamodule)

    @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+'])
    @pytest.mark.parametrize(
        'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0']
    )
    def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
        """Check that ``freeze_backbone=True`` freezes only the encoder.

        Encoder parameters must not require gradients, while decoder and
        segmentation-head parameters must.
        """
        model = InstanceSegmentationTask(
            model=model_name, backbone=backbone, freeze_backbone=True
        )
        assert all(
            [param.requires_grad is False for param in model.model.encoder.parameters()]
        )
        assert all([param.requires_grad for param in model.model.decoder.parameters()])
        assert all(
            [
                param.requires_grad
                for param in model.model.segmentation_head.parameters()
            ]
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file can be deleted