from torchmetrics import MetricCollection
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from .base import BaseTask
+from torchgeo.trainers.base import BaseTask
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
-from ..datasets import RGBBandsMissingError, unbind_samples
+from torchgeo.datasets import RGBBandsMissingError, unbind_samples
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+import numpy as np


class InstanceSegmentationTask(BaseTask):
    """Instance Segmentation."""
@@ -66,7 +68,7 @@ def configure_models(self) -> None:

        if model == 'mask_rcnn':
            # Load the Mask R-CNN model with a ResNet50 backbone
-            self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
+            self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, rpn_nms_thresh=0.5, box_nms_thresh=0.3)

            # Update the classification head to predict `num_classes`
            in_features = self.model.roi_heads.box_predictor.cls_score.in_features
@@ -75,9 +77,10 @@ def configure_models(self) -> None:

            # Update the mask head for instance segmentation
            in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
-            self.model.roi_heads.mask_predictor = nn.ConvTranspose2d(
-                in_features_mask, num_classes, kernel_size=2, stride=2
-            )
+
+            hidden_layer = 256
+            self.model.roi_heads.mask_predictor = MaskRCNNPredictor(
+                in_features_mask, hidden_layer, num_classes)

        else:
            raise ValueError(
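Side note on the two hunks above: they follow the standard torchvision fine-tuning recipe of swapping the box and mask predictors for new heads sized to `num_classes`, and they pass stricter NMS thresholds through the builder's keyword arguments. A minimal standalone sketch of the same head replacement (the `num_classes = 2` value is only an example and includes the background class):

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

num_classes = 2  # example only: background + 1 foreground class

model = maskrcnn_resnet50_fpn(
    weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
    rpn_nms_thresh=0.5,
    box_nms_thresh=0.3,
)

# Replace the box classification/regression head
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Replace the mask head, keeping a 256-channel hidden layer
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 256, num_classes)

# In eval mode the model returns one dict of boxes/labels/scores/masks per image
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 256, 256)])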
@@ -114,9 +117,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        loss_dict = self.model(images, targets)
        loss = sum(loss for loss in loss_dict.values())

-        print('\nTRAINING LOSS\n')
-        print(loss_dict, '\n\n')
-        print(loss)
+        print(f"\nTRAINING STEP LOSS: {loss.item()}")

        self.log('train_loss', loss, batch_size=len(images))
        return loss
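For reference, in training mode torchvision's Mask R-CNN returns a dict of scalar losses rather than predictions, and the step above simply sums them into the single value that Lightning optimises and logs. A toy sketch with fabricated values (the key names are the ones torchvision normally emits, not something defined in this PR):

import torch

loss_dict = {
    'loss_classifier': torch.tensor(0.70),   # box classification
    'loss_box_reg': torch.tensor(0.40),      # box regression
    'loss_mask': torch.tensor(0.90),         # mask head
    'loss_objectness': torch.tensor(0.10),   # RPN objectness
    'loss_rpn_box_reg': torch.tensor(0.05),  # RPN box regression
}
loss = sum(loss for loss in loss_dict.values())
print(loss)  # tensor(2.1500), the value logged as 'train_loss'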
@@ -134,20 +135,21 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
        batch_size = images.shape[0]

        outputs = self.model(images)
-        loss_dict = self.model(images, targets)  # list of dictionaries
-        total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)
+        loss_dict_list = self.model(images, targets)  # list of dictionaries
+        total_loss = sum(
+            sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
+            for loss_dict in loss_dict_list
+        )

        for target in targets:
            target["masks"] = (target["masks"] > 0).to(torch.uint8)
            target["boxes"] = target["boxes"].to(torch.float32)
            target["labels"] = target["labels"].to(torch.int64)
-
-        # Post-process the outputs to ensure masks are in the correct format
+
        for output in outputs:
            if "masks" in output:
                output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)
-
-        # Sum the losses
+
        self.log('val_loss', total_loss, batch_size=batch_size)

        metrics = self.val_metrics(outputs, targets)
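The `val_metrics` collection itself is not part of this diff; assuming it wraps torchmetrics' `MeanAveragePrecision` (which needs `pycocotools` for `iou_type='segm'`), the per-image dicts built above — float boxes, int64 labels, scores, and thresholded binary masks — line up with that metric's expected input. A hypothetical sketch of that contract (the documented mask dtype is bool; the step's 0/1 uint8 masks carry the same information):

import torch
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(iou_type='segm')  # assumption: mask-based mAP, not shown in this PR

mask = torch.zeros(1, 128, 128, dtype=torch.bool)
mask[:, 16:64, 16:64] = True

preds = [{
    'boxes': torch.tensor([[16.0, 16.0, 64.0, 64.0]]),
    'scores': torch.tensor([0.9]),
    'labels': torch.tensor([1]),
    'masks': mask,
}]
targets = [{
    'boxes': torch.tensor([[16.0, 16.0, 64.0, 64.0]]),
    'labels': torch.tensor([1]),
    'masks': mask,
}]

metric.update(preds, targets)
print(metric.compute()['map'])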
@@ -197,16 +199,20 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
        batch_size = images.shape[0]

        outputs = self.model(images)
-        loss_dict = self.model(images, targets)  # Compute all losses
-        total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)
+        loss_dict_list = self.model(images, targets)  # Compute all losses, list of dictionaries (one per batch element)
+        total_loss = sum(
+            sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
+            for loss_dict in loss_dict_list
+        )

        for target in targets:
            target["masks"] = target["masks"].to(torch.uint8)
            target["boxes"] = target["boxes"].to(torch.float32)
            target["labels"] = target["labels"].to(torch.int64)

        for output in outputs:
-            output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)
+            if "masks" in output:
+                output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)

        self.log('test_loss', total_loss, batch_size=batch_size)

@@ -219,33 +225,22 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
            value = value.to(torch.float32).mean()
            scalar_metrics[key] = value

-        self.log_dict(scalar_metrics, batch_size=batch_size)
-
-        print('\nTESTING LOSS\n')
-        print(loss_dict, '\n\n')
-        print(total_loss)
+        self.log_dict(scalar_metrics, batch_size=batch_size)

    def predict_step(self, batch: Any, batch_idx: int) -> Any:
-        """Perform inference on a batch of images.
-
-        Args:
-            batch: A batch of images.
-
-        Returns:
-            Predicted masks and bounding boxes for the batch.
-        """
+        """Perform inference on a batch of images."""
        self.model.eval()
-        images = batch['image']
-        outputs = self.model(images)
-
-        for output in outputs:
-            output["masks"] = (output["masks"] > 0.5).to(torch.uint8)
-        return outputs
-
-
-
-
+        images = batch['image']

+        with torch.no_grad():
+            outputs = self.model(images)

+        for output in outputs:
+            keep = output["scores"] > 0.05
+            output["boxes"] = output["boxes"][keep]
+            output["labels"] = output["labels"][keep]
+            output["scores"] = output["scores"][keep]
+            output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)[keep]

+        return outputs

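End-to-end usage is not shown in the diff, but assuming the task constructor exposes `model` and `num_classes` hyperparameters (as `configure_models` suggests) and that `BaseTask.__init__` builds the model, the new `predict_step` filtering could be exercised roughly like this (a sketch, not the PR's own test code):

import torch
from torchgeo.trainers import InstanceSegmentationTask  # assumed export path

task = InstanceSegmentationTask(model='mask_rcnn', num_classes=2)  # hypothetical kwargs
batch = {'image': torch.rand(2, 3, 256, 256)}

predictions = task.predict_step(batch, batch_idx=0)
for pred in predictions:
    # Only detections with score > 0.05 survive; masks are binary uint8 of shape (N, H, W)
    print(pred['boxes'].shape, pred['scores'], pred['masks'].shape)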