Skip to content

Commit 9664834

Browse files
Update instance_segmentation.py
1 parent 98c836a commit 9664834

File tree

1 file changed

+36
-41
lines changed

1 file changed

+36
-41
lines changed

torchgeo/trainers/instance_segmentation.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from torchmetrics import MetricCollection
1212
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
1313
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
14-
from .base import BaseTask
14+
from torchgeo.trainers.base import BaseTask
1515
import matplotlib.pyplot as plt
1616
from matplotlib.figure import Figure
17-
from ..datasets import RGBBandsMissingError, unbind_samples
17+
from torchgeo.datasets import RGBBandsMissingError, unbind_samples
18+
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
19+
import numpy as np
1820

1921
class InstanceSegmentationTask(BaseTask):
2022
"""Instance Segmentation."""
@@ -66,7 +68,7 @@ def configure_models(self) -> None:
6668

6769
if model == 'mask_rcnn':
6870
# Load the Mask R-CNN model with a ResNet50 backbone
69-
self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
71+
self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, rpn_nms_thresh=0.5, box_nms_thresh=0.3)
7072

7173
# Update the classification head to predict `num_classes`
7274
in_features = self.model.roi_heads.box_predictor.cls_score.in_features
@@ -75,9 +77,10 @@ def configure_models(self) -> None:
7577

7678
# Update the mask head for instance segmentation
7779
in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
78-
self.model.roi_heads.mask_predictor = nn.ConvTranspose2d(
79-
in_features_mask, num_classes, kernel_size=2, stride=2
80-
)
80+
81+
hidden_layer = 256
82+
self.model.roi_heads.mask_predictor = MaskRCNNPredictor(
83+
in_features_mask, hidden_layer, num_classes)
8184

8285
else:
8386
raise ValueError(
@@ -114,9 +117,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor:
114117
loss_dict = self.model(images, targets)
115118
loss = sum(loss for loss in loss_dict.values())
116119

117-
print('\nTRAINING LOSS\n')
118-
print(loss_dict, '\n\n')
119-
print(loss)
120+
print(f"\nTRAINING STEP LOSS: {loss.item()}")
120121

121122
self.log('train_loss', loss, batch_size=len(images))
122123
return loss
@@ -134,20 +135,21 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
134135
batch_size = images.shape[0]
135136

136137
outputs = self.model(images)
137-
loss_dict = self.model(images, targets) # list of dictionaries
138-
total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)
138+
loss_dict_list = self.model(images, targets) # list of dictionaries
139+
total_loss = sum(
140+
sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
141+
for loss_dict in loss_dict_list
142+
)
139143

140144
for target in targets:
141145
target["masks"] = (target["masks"] > 0).to(torch.uint8)
142146
target["boxes"] = target["boxes"].to(torch.float32)
143147
target["labels"] = target["labels"].to(torch.int64)
144-
145-
# Post-process the outputs to ensure masks are in the correct format
148+
146149
for output in outputs:
147150
if "masks" in output:
148151
output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)
149-
150-
# Sum the losses
152+
151153
self.log('val_loss', total_loss, batch_size=batch_size)
152154

153155
metrics = self.val_metrics(outputs, targets)
@@ -197,16 +199,20 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
197199
batch_size = images.shape[0]
198200

199201
outputs = self.model(images)
200-
loss_dict = self.model(images, targets) # Compute all losses
201-
total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)
202+
loss_dict_list = self.model(images, targets) # Compute all losses, list of dictonaries (one for every batch element)
203+
total_loss = sum(
204+
sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
205+
for loss_dict in loss_dict_list
206+
)
202207

203208
for target in targets:
204209
target["masks"] = target["masks"].to(torch.uint8)
205210
target["boxes"] = target["boxes"].to(torch.float32)
206211
target["labels"] = target["labels"].to(torch.int64)
207212

208213
for output in outputs:
209-
output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)
214+
if "masks" in output:
215+
output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)
210216

211217
self.log('test_loss', total_loss, batch_size=batch_size)
212218

@@ -219,33 +225,22 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
219225
value = value.to(torch.float32).mean()
220226
scalar_metrics[key] = value
221227

222-
self.log_dict(scalar_metrics, batch_size=batch_size)
223-
224-
print('\nTESTING LOSS\n')
225-
print(loss_dict, '\n\n')
226-
print(total_loss)
228+
self.log_dict(scalar_metrics, batch_size=batch_size)
227229

228230
def predict_step(self, batch: Any, batch_idx: int) -> Any:
229-
"""Perform inference on a batch of images.
230-
231-
Args:
232-
batch: A batch of images.
233-
234-
Returns:
235-
Predicted masks and bounding boxes for the batch.
236-
"""
231+
"""Perform inference on a batch of images."""
237232
self.model.eval()
238-
images = batch['image']
239-
outputs = self.model(images)
240-
241-
for output in outputs:
242-
output["masks"] = (output["masks"] > 0.5).to(torch.uint8)
243-
return outputs
244-
245-
246-
247-
233+
images = batch['image']
248234

235+
with torch.no_grad():
236+
outputs = self.model(images)
249237

238+
for output in outputs:
239+
keep = output["scores"] > 0.05
240+
output["boxes"] = output["boxes"][keep]
241+
output["labels"] = output["labels"][keep]
242+
output["scores"] = output["scores"][keep]
243+
output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)[keep]
250244

245+
return outputs
251246

0 commit comments

Comments
 (0)