Trainers: add Instance Segmentation Task #2513

Status: Open
Wants to merge 18 commits into base: main
Showing changes from 14 commits
123 changes: 123 additions & 0 deletions test_trainer_instancesegmentation.py
Review comment: This file can be deleted

@@ -0,0 +1,123 @@
import torch

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:1:1: D100 Missing docstring in public module
import lightning.pytorch as pl
from torch.utils.data import DataLoader, Subset
from torchgeo.datasets import VHR10
from torchvision.transforms.functional import to_pil_image
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchgeo.trainers import InstanceSegmentationTask

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:1:1: I001 Import block is un-sorted or un-formatted

def collate_fn(batch):

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:11:5: ANN201 Missing return type annotation for public function `collate_fn`

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:11:16: ANN001 Missing type annotation for function argument `batch`
"""Custom collate function for DataLoader."""
max_height = max(sample['image'].shape[1] for sample in batch)
max_width = max(sample['image'].shape[2] for sample in batch)

images = torch.stack([
F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))
for sample in batch
])

targets = [
{
"labels": sample["labels"].to(torch.int64),
"boxes": sample["boxes"].to(torch.float32),
"masks": F.pad(
sample["masks"],
(0, max_width - sample["masks"].shape[2], 0, max_height - sample["masks"].shape[1]),
).to(torch.uint8),
}
for sample in batch
]

return {"image": images, "target": targets}

def visualize_predictions(image, predictions, targets):

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:35:5: ANN201 Missing return type annotation for public function `visualize_predictions`

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:35:27: ANN001 Missing type annotation for function argument `image`

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:35:34: ANN001 Missing type annotation for function argument `predictions`

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:35:47: ANN001 Missing type annotation for function argument `targets`
"""Visualize predictions and ground truth."""
image = to_pil_image(image)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image)

# Predictions
for box, label in zip(predictions['boxes'], predictions['labels']):
x1, y1, x2, y2 = box
rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
ax.add_patch(rect)
ax.text(x1, y1, f"Pred: {label.item()}", color='red', fontsize=12)

# Ground truth
for box, label in zip(targets['boxes'], targets['labels']):
x1, y1, x2, y2 = box
rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')
ax.add_patch(rect)
ax.text(x1, y1, f"GT: {label.item()}", color='blue', fontsize=12)

plt.show()

def plot_losses(train_losses, val_losses):

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:58:5: ANN201 Missing return type annotation for public function `plot_losses`

CI check failure (GitHub Actions / ruff): test_trainer_instancesegmentation.py:58:17: ANN001 Missing type annotation for function argument `train_losses`
"""Plot training and validation losses over epochs."""
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid()
plt.show()

# Initialize VHR-10 dataset
train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
val_dataset = VHR10(root="data", split="positive", transforms=None)

# Subset for quick experimentation (adjust N as needed)
N = 100
train_subset = Subset(train_dataset, list(range(N)))
val_subset = Subset(val_dataset, list(range(N)))


if __name__ == '__main__':
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)

train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)

# Trainer setup
trainer = pl.Trainer(
max_epochs=5,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1
)

task = InstanceSegmentationTask(
model="mask_rcnn",
backbone="resnet50",
weights="imagenet", # Pretrained on ImageNet
num_classes=11, # VHR-10 has 10 classes + 1 background
lr=1e-3,
freeze_backbone=False
)

print('\nSTART TRAINING\n')
# trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
train_losses, val_losses = [], []
for epoch in range(5):
trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
train_loss = task.trainer.callback_metrics.get("train_loss")
val_loss = task.trainer.callback_metrics.get("val_loss")
if train_loss is not None:
train_losses.append(train_loss.item())
if val_loss is not None:
val_losses.append(val_loss.item())

plot_losses(train_losses, val_losses)

#trainer.test(task, dataloaders=val_loader)

# Inference and Visualization
sample = train_dataset[1]
image = sample['image'].unsqueeze(0)
predictions = task.predict_step({"image": image}, batch_idx=0)
visualize_predictions(image[0], predictions[0], sample)
257 changes: 257 additions & 0 deletions tests/trainers/test_instancesegmentation.py
Review comment (adamjstewart marked this conversation as resolved): Let's rename to test_instance_segmentation.py to match the other filename

@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
from torchgeo.datasets import LandCoverAI, RGBBandsMissingError
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import InstanceSegmentationTask


class SegmentationTestModel(Module):
def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))


def create_model(**kwargs: Any) -> Module:
return SegmentationTestModel(**kwargs)


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()


class TestSemanticSegmentationTask:
@pytest.mark.parametrize(
'name',
[
'agrifieldnet',
Review comment: This is a list of files whose configuration is present in tests/conf. This current list is from the semantic segmentation tests. We need to create a new list (and new tests/conf/*.yaml files) for instance segmentation datasets. We can start with something like VHR-10 and work from there (a minimal sketch follows the test_trainer method below).

'cabuar',
'chabud',
'chesapeake_cvpr_5',
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'ftw',
'geonrw',
'gid15',
'inria',
'l7irish',
'l8biome',
'landcoverai',
'landcoverai100',
'loveda',
'naipchesapeake',
'potsdam2d',
'sen12ms_all',
'sen12ms_s1',
'sen12ms_s2_all',
'sen12ms_s2_reduced',
'sentinel2_cdl',
'sentinel2_eurocrops',
'sentinel2_nccm',
'sentinel2_south_america_soybean',
'southafricacroptype',
'spacenet1',
'spacenet6',
'ssl4eo_l_benchmark_cdl',
'ssl4eo_l_benchmark_nlcd',
'vaihingen2d',
],
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
match name:
case 'chabud' | 'cabuar':
pytest.importorskip('h5py', minversion='3.6')
case 'ftw':
pytest.importorskip('pyarrow')
case 'landcoverai':
sha256 = (
'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
)
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
Review comment on lines +89 to +98: Can be removed if you aren't using any of these datasets.


config = os.path.join('tests', 'conf', name + '.yaml')

monkeypatch.setattr(smp, 'Unet', create_model)
monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model)
Review comment on lines +102 to +103: We aren't using these models (or even smp) for this task; this can be removed.


args = [
'--config',
config,
'--trainer.accelerator',
'cpu',
'--trainer.fast_dev_run',
str(fast_dev_run),
'--trainer.max_epochs',
'1',
'--trainer.log_every_n_steps',
'1',
]

main(['fit', *args])
try:
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict', *args])
except MisconfigurationException:
pass
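
Picking up the review comment about the dataset list above: a minimal sketch of how the parametrize list on test_trainer could be reduced for instance segmentation. It assumes a new tests/conf/vhr10_ins_seg.yaml configuration (file name hypothetical) is added for the VHR-10 dataset and reuses the pytest, MonkeyPatch, and os imports already present in this module:

@pytest.mark.parametrize(
    'name',
    [
        'vhr10_ins_seg',  # hypothetical new config file under tests/conf/
    ],
)
def test_trainer(
    self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
    # each parametrized name maps to a tests/conf/<name>.yaml config
    config = os.path.join('tests', 'conf', name + '.yaml')
    ...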

@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
weights.meta['model'], in_chans=weights.meta['in_chans']
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_weight_file(self, checkpoint: str) -> None:
InstanceSegmentationTask(backbone='resnet18', weights=checkpoint, num_classes=6)

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=mocked_weights.meta['model'],
weights=mocked_weights,
in_channels=mocked_weights.meta['in_chans'],
)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=mocked_weights.meta['model'],
weights=str(mocked_weights),
in_channels=mocked_weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=weights.meta['model'],
weights=weights,
in_channels=weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=weights.meta['model'],
weights=str(weights),
in_channels=weights.meta['in_chans'],
)

def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(model='invalid_model')

def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(SEN12MSDataModule, 'plot', plot)
datamodule = SEN12MSDataModule(
Review comment: This datamodule isn't compatible with instance segmentation (a possible alternative is sketched after this test).

root='tests/data/sen12ms', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(
backbone='resnet18', in_channels=15, num_classes=6
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)
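
Regarding the comment above that SEN12MSDataModule isn't compatible: a hedged sketch of how this test might instead use an instance-segmentation-friendly datamodule. It assumes torchgeo's VHR10DataModule accepts root/batch_size/num_workers and that test fixtures exist under tests/data/vhr10 (both are assumptions, not confirmed by this PR):

from torchgeo.datamodules import VHR10DataModule  # assumed to exist and be compatible

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
    monkeypatch.setattr(VHR10DataModule, 'plot', plot)
    # root is an assumed test-data path; batch_size/num_workers mirror the original test
    datamodule = VHR10DataModule(root='tests/data/vhr10', batch_size=1, num_workers=0)
    model = InstanceSegmentationTask(backbone='resnet50', num_classes=11)
    trainer = Trainer(
        accelerator='cpu',
        fast_dev_run=fast_dev_run,
        log_every_n_steps=1,
        max_epochs=1,
    )
    trainer.validate(model=model, datamodule=datamodule)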

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands)
datamodule = SEN12MSDataModule(
root='tests/data/sen12ms', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(
backbone='resnet18', in_channels=15, num_classes=6
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

@pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+'])
@pytest.mark.parametrize(
'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0']
)
def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
model = InstanceSegmentationTask(
model=model_name, backbone=backbone, freeze_backbone=True
)
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
assert all([param.requires_grad for param in model.model.decoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)

# @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+'])
Review comment: Can remove commented-out code

# def test_freeze_decoder(self, model_name: str) -> None:
# model = InstanceSegmentationTask(model=model_name, freeze_decoder=True)
# assert all(
# [param.requires_grad is False for param in model.model.decoder.parameters()]
# )
# assert all([param.requires_grad for param in model.model.encoder.parameters()])
# assert all(
# [
# param.requires_grad
# for param in model.model.segmentation_head.parameters()
# ]
# )
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
@@ -12,11 +12,13 @@
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask
from .simclr import SimCLRTask
from .instance_segmentation import InstanceSegmentationTask

__all__ = (
'BYOLTask',
'BaseTask',
'ClassificationTask',
    'InstanceSegmentationTask',
'IOBenchTask',
'MoCoTask',
'MultiLabelClassificationTask',