Commit

Upload Segmentation Heads MaskRCNN-Head
jiaowoguanren0615 committed Jul 15, 2024
1 parent dcbbeb5 commit 59540f8
Showing 3 changed files with 77 additions and 6 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -140,4 +140,14 @@ On the second machine: python -m torch.distributed.run --nproc_per_node=1 --nnod
pages={418--434},
year={2018}
}
```

```
@inproceedings{he2017mask,
title={Mask r-cnn},
author={He, Kaiming and Gkioxari, Georgia and Doll{\'a}r, Piotr and Girshick, Ross},
booktitle={Proceedings of the IEEE international conference on computer vision},
pages={2961--2969},
year={2017}
}
```
14 changes: 8 additions & 6 deletions models/build_seg_models.py
@@ -1,17 +1,20 @@
 from models.segformer_head import SegFormerHead
 from models.upernet_head import UPerHead
+from models.maskrcnn_head import MaskRCNNSegmentationHead
 import torch.nn.functional as F
 from models.seg_model_backbones import *
 
+
 class SegMambaVisionModel(nn.Module):
-    def __init__(self, backbone, num_classes=19, use_segformer_head=False, **kwargs):
+    def __init__(self, backbone, num_classes=19, use_maskrcnn_head=False, **kwargs):
         super(SegMambaVisionModel, self).__init__()
 
         self.backbone = eval(backbone + '()')
 
-        if use_segformer_head == True:
-            self.decode_head = SegFormerHead(self.backbone.feature_dims, 256 if 'T' in backbone or 'S' in backbone else 768,
-                                             num_classes)
+        if use_maskrcnn_head == True:
+            self.decode_head = MaskRCNNSegmentationHead(self.backbone.feature_dims,
+                                                        256 if 'T' in backbone or 'S' in backbone else 768,
+                                                        num_classes)
         else:
             self.decode_head = UPerHead(self.backbone.feature_dims, 128 if 'T' in backbone or 'S' in backbone else 768,
                                         num_classes)
@@ -64,10 +67,9 @@ def SegMambaVision_L2(pretrained=False, pretrained_cfg=None, pretrained_cfg_over
     model = SegMambaVisionModel(backbone, **kwargs)
     return model
 
-
 # if __name__ == '__main__':
 # import torch
 # input_data = torch.randn(2, 3, 224, 224).cuda()
 # model = SegMambaVision_T().cuda()
 # y = model(input_data)
-# print(y.shape)
+# print(y.shape)
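
For reference, a minimal usage sketch (not part of this commit) showing the new head being selected through the builder. It assumes `SegMambaVision_T` forwards `**kwargs` to `SegMambaVisionModel` the same way `SegMambaVision_L2` does above, and it mirrors the 224x224 CUDA input from the commented-out test.

```python
# Hypothetical usage sketch: select the MaskRCNN-style decode head via the builder kwarg.
import torch
from models.build_seg_models import SegMambaVision_T  # assumed import path

model = SegMambaVision_T(use_maskrcnn_head=True).cuda()  # kwargs reach SegMambaVisionModel
model.eval()

input_data = torch.randn(2, 3, 224, 224).cuda()
with torch.no_grad():
    y = model(input_data)
print(y.shape)  # per-pixel class logits; spatial size depends on how forward() upsamples
```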
59 changes: 59 additions & 0 deletions models/maskrcnn_head.py
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskRCNNSegmentationHead(nn.Module):
    def __init__(self, in_channels_list, channel=256, num_classes=19, dropout_rate=0.3):
        super(MaskRCNNSegmentationHead, self).__init__()
        self.fpn_lateral = nn.ModuleList([
            nn.Conv2d(in_channels, channel, kernel_size=1, stride=1, padding=0)
            for in_channels in in_channels_list
        ])
        self.fpn_output = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(channel),
                nn.ReLU(inplace=True)
            )
            for _ in range(len(in_channels_list))
        ])

        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.mask_pred = nn.Conv2d(channel, num_classes, kernel_size=1, stride=1)

    def forward(self, features):
        x_list = []
        for i, feature in enumerate(features):
            # Per pyramid level: 1x1 lateral projection, then 3x3 conv + BN + ReLU
            lateral_conv = self.fpn_lateral[i](feature)
            output_conv = self.fpn_output[i](lateral_conv)
            x_list.append(output_conv)

        # Fuse all levels by resizing each to the spatial size of the last feature map and summing
        x = sum([F.interpolate(x, size=x_list[-1].shape[-2:], mode='bilinear', align_corners=False) for x in x_list])

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dropout(x)
        mask_pred = self.mask_pred(x)
        return mask_pred


# if __name__ == '__main__':
# x = torch.randn(2, 3, 224, 224)
# model = MaskRCNNSegmentationHead([160, 320, 640, 640], 128)
# x1 = torch.randn(2, 160, 28, 28)
# x2 = torch.randn(2, 320, 14, 14)
# x3 = torch.randn(2, 640, 7, 7)
# x4 = torch.randn(2, 640, 7, 7)
# y = model([x1, x2, x3, x4])
# y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)
# print(y.shape)
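
As a follow-up, a hedged sketch of how this head could be supervised during training; the loss wiring is not part of this commit, and the feature shapes simply follow the commented-out test above.

```python
# Hypothetical training-step sketch (assumption, not from this commit): the head fuses all
# pyramid levels at the size of the last feature map, so its prediction is coarse and is
# upsampled to label resolution before a standard per-pixel cross-entropy loss.
import torch
import torch.nn.functional as F
from models.maskrcnn_head import MaskRCNNSegmentationHead

head = MaskRCNNSegmentationHead([160, 320, 640, 640], channel=128, num_classes=19)

feats = [torch.randn(2, 160, 28, 28), torch.randn(2, 320, 14, 14),
         torch.randn(2, 640, 7, 7), torch.randn(2, 640, 7, 7)]
labels = torch.randint(0, 19, (2, 224, 224))  # dense class indices per pixel

logits = head(feats)                          # coarse prediction, here (2, 19, 7, 7)
logits = F.interpolate(logits, size=labels.shape[-2:], mode='bilinear', align_corners=False)
loss = F.cross_entropy(logits, labels)
loss.backward()
```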
