
update
cavalleria committed Jun 3, 2020
1 parent f68aaf3 commit 253ab9e
Showing 12 changed files with 767 additions and 16 deletions.
25 changes: 25 additions & 0 deletions README.md
@@ -8,6 +8,7 @@
* [Requirements](#requirements)
* [Features](#features)
* [Folder Structure](#folder-structure)
* [Benchmark](#benchmark)
* [Usage](#usage)
* [Config file format](#config-file-format)
* [Using config files](#using-config-files)
@@ -80,6 +81,30 @@
├── metric.py
└── util.py
```
## Benchmark
| Method | Backbone | Loss | Pretrain | Train Loss | Train mIoU | Valid Loss | Valid mIoU |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| UNet | MobileNetV2 | Dice Loss | no | 0.0231 | 0.9534 | 0.0242 | 0.9512 |
| UNet | ResNet-18 | Dice Loss | no | 0.0220 | 0.9600 | 0.0239 | 0.9582 |
| UNet | ResNet-18 | BCE Loss | no | 0.0334 | 0.9656 | 0.0365 | 0.9594 |
| UNet | ResNet-18 | Lovasz Loss | no | 0.0368 | 0.9593 | 0.0452 | 0.9550 |
| UNet | ResNet-50 | BCE Loss | no | 0.0340 | 0.9651 | 0.0368 | 0.9585 |
| DeepLabV3+ | ResNet-18 | CE Loss | yes | 0.0279 | 0.9707 | 0.0303 | 0.9667 |
| DeepLabV3+ | ResNet-50 | BCE Loss | yes | 0.0241 | 0.9744 | 0.0290 | 0.9696 |
| UNet | MobileNetV2 | BCE Loss | no | 0.0392 | 0.9604 | 0.0383 | 0.9576 |
| UNet | MobileNetV2 | BCE Loss | yes | 0.0278 | 0.9712 | 0.0324 | 0.9662 |
| UNet | MobileNetV2 | Lovasz Loss | yes | 0.0357 | 0.9674 | 0.0426 | 0.9656 |
| DeepLabV3+ | MobileNetV2 | BCE Loss | yes | 0.0311 | 0.9677 | 0.0313 | 0.9659 |
| DeepLabV3+ | Xception65 | BCE Loss | yes | 0.0359 | 0.9626 | 0.0424 | 0.9543 |
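
The Dice loss used in several of the runs above is the standard soft Dice on predicted probabilities. A minimal sketch of that formulation (a generic illustration, not necessarily the exact code in `evaluation/losses.py`):

```python
import torch

def soft_dice_loss(probs, targets, smooth=1.0):
    """Soft Dice loss for binary segmentation.

    probs:   (N, H, W) predicted foreground probabilities in [0, 1]
    targets: (N, H, W) binary ground-truth masks
    """
    probs = probs.contiguous().view(probs.size(0), -1)
    targets = targets.contiguous().view(targets.size(0), -1)
    inter = (probs * targets).sum(dim=1)
    dice = (2.0 * inter + smooth) / (probs.sum(dim=1) + targets.sum(dim=1) + smooth)
    return 1.0 - dice.mean()
```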

## Usage
The code in this repo is an MNIST example of the template.
2 changes: 1 addition & 1 deletion config/config_deeplab.json
@@ -63,7 +63,7 @@
}
},

"loss": "bce_loss",
"loss": "ce_loss",
"metrics": [
"miou"
],
90 changes: 90 additions & 0 deletions config/config_hrnet.json
@@ -0,0 +1,90 @@
{
"name": "HumanSeg",
"n_gpu": 1,

"arch": {
"type": "HighResolutionNet",
"args": {
"backbone": null,
"num_classes": 2,
"pretrained_backbone": "./pretrained/hrnet_w18_small_v2.pth"
}
},

"train_loader": {
"type": "SegmentationDataLoader",
"args":{
"prefix": "/workspace/data",
"pairs_file": "../data/train_mask.txt",
"color_channel": "RGB",
"resize": 320,
"padding_value": 0,
"is_training": true,
"noise_std": 3,
"crop_range": [0.90, 1.0],
"flip_hor": 0.5,
"rotate": 0.0,
"angle": 10,
"normalize": true,
"one_hot": false,
"shuffle": true,
"batch_size": 16,
"n_workers": 24,
"pin_memory": true
}
},

"valid_loader": {
"type": "SegmentationDataLoader",
"args":{
"prefix": "/workspace/data",
"pairs_file": "../data/valid_mask.txt",
"color_channel": "RGB",
"resize": 320,
"padding_value": 0,
"is_training": false,
"normalize": true,
"one_hot": false,
"shuffle": false,
"batch_size": 16,
"n_workers": 24,
"pin_memory": true
}
},

"optimizer": {
"type": "SGD",
"args":{
"lr": 1e-2,
"momentum": 0.9,
"weight_decay": 1e-8
}
},

"loss": "custom_hrnet_loss",
"metrics": [
"custom_hrnet_miou"
],

"lr_scheduler": {
"type":"StepLR",
"args":{
"step_size": 100,
"gamma": 1.0
}
},

"trainer": {
"epochs": 80,
"save_dir": "/workspace/models/",
"save_freq": 1,
"verbosity": 2,
"monitor": "valid_loss",
"monitor_mode": "min"
},

"visualization":{
"tensorboardX": true,
"log_dir": "/workspace/models/"
}
}
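
Like the existing configs, this file is meant to be consumed by the training entry point, which instantiates the model, data loaders, optimizer and scheduler from the "type"/"args" pairs. A rough sketch of how such a config could be turned into objects (the `get_instance` helper and the `dataloader` module name here are assumptions, not necessarily the repo's exact code):

```python
import json
import models      # this commit adds HighResolutionNet to models/__init__.py
import dataloader  # assumed: package exporting SegmentationDataLoader

def get_instance(module, section):
    # Look up the class named in "type" and build it with the "args" dict.
    return getattr(module, section["type"])(**section["args"])

with open("config/config_hrnet.json") as f:
    config = json.load(f)

model = get_instance(models, config["arch"])
train_loader = get_instance(dataloader, config["train_loader"])
valid_loader = get_instance(dataloader, config["valid_loader"])
```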
2 changes: 1 addition & 1 deletion config/config_unet.json
@@ -61,7 +61,7 @@
}
},

"loss": "lovasz_softmax",
"loss": "bce_loss",
"metrics": [
"miou"
],
51 changes: 42 additions & 9 deletions evaluation/losses.py
@@ -1,4 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
try:
@@ -40,14 +41,14 @@ def dice_loss_with_sigmoid(sigmoid, targets, smooth=1.0):
return dice

def bce_loss(logits, targets):
"""
logits: (torch.float32) shape (N, C, H, W) (16, 2, 160, 160)
targets: (torch.float32) shape (N, H, W), value {0,1,...,C-1} (16, 160, 160)
"""
targets = torch.unsqueeze(targets, dim=1)
targets = torch.zeros_like(logits).scatter_(dim=1, index=targets.type(torch.int64), src=torch.tensor(1.0))
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss
"""
logits: (torch.float32) shape (N, C, H, W) (16, 2, 160, 160)
targets: (torch.float32) shape (N, H, W), value {0,1,...,C-1} (16, 160, 160)
"""
targets = torch.unsqueeze(targets, dim=1)
targets = torch.zeros_like(logits).scatter_(dim=1, index=targets.type(torch.int64), src=torch.tensor(1.0))
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss

#==============================lovasz loss==================================
# adapted from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
@@ -205,4 +206,36 @@ def custom_icnet_loss(logits, targets, alpha=[0.4, 0.16]):
return loss1 + alpha[0]*loss2 + alpha[1]*loss3

else:
return ce_loss(logits, targets)
return ce_loss(logits, targets)


def custom_hrnet_loss(logits, targets):

ph, pw = logits.size(2), logits.size(3)
h, w = targets.size(1), targets.size(2)
if ph != h or pw != w:
logits = F.upsample(input=logits, size=(h, w), mode='bilinear')

loss = ce_loss(logits, targets)
return loss

# For HRNet
def custom_hrnet_loss_ohem(score, target, ignore_label=-1, thresh=0.7, min_kept=100000):
ph, pw = score.size(2), score.size(3)
h, w = target.size(1), target.size(2)
if ph != h or pw != w:
score = F.upsample(input=score, size=(h, w), mode='bilinear')
pred = F.softmax(score, dim=1)
pixel_losses = ce_loss(score, target).contiguous().view(-1)
mask = target.contiguous().view(-1) != ignore_label

tmp_target = target.clone()
tmp_target[tmp_target == ignore_label] = 0
pred = pred.gather(1, tmp_target.unsqueeze(1))
pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()
min_value = pred[min(min_kept, pred.numel() - 1)]
threshold = max(min_value, thresh)

pixel_losses = pixel_losses[mask][ind]
pixel_losses = pixel_losses[pred < threshold]
return pixel_losses.mean()
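
`custom_hrnet_loss_ohem` implements online hard example mining (OHEM): pixels are ranked by the probability the model assigns to their true class, and only the hardest pixels (roughly, those below a confidence threshold, with at least `min_kept` of them retained) contribute to the loss. A self-contained sketch of the same idea written directly against `F.cross_entropy` (a simplified illustration, not the code added in this commit):

```python
import torch
import torch.nn.functional as F

def ohem_cross_entropy(logits, targets, thresh=0.7, min_kept=100000):
    """OHEM over per-pixel cross entropy.

    logits:  (N, C, H, W) raw scores
    targets: (N, H, W) integer class indices
    """
    targets = targets.long()
    # per-pixel loss, no reduction
    pixel_losses = F.cross_entropy(logits, targets, reduction="none").view(-1)
    # probability assigned to the correct class at each pixel
    probs = F.softmax(logits, dim=1)
    correct_probs = probs.gather(1, targets.unsqueeze(1)).view(-1)
    # rank pixels from least to most confident
    sorted_probs, order = correct_probs.sort()
    # keep everything below `thresh`, but never fewer than ~min_kept pixels
    min_value = sorted_probs[min(min_kept, sorted_probs.numel() - 1)]
    threshold = max(min_value.item(), thresh)
    kept = pixel_losses[order][sorted_probs < threshold]
    return kept.mean()

# example call on random data
loss = ohem_cross_entropy(torch.randn(2, 2, 160, 160),
                          torch.randint(0, 2, (2, 160, 160)),
                          min_kept=1000)
```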
18 changes: 17 additions & 1 deletion evaluation/metrics.py
Expand Up @@ -49,4 +49,20 @@ def custom_icnet_miou(logits, targets):
targets = F.interpolate(targets, size=logits[0].shape[-2:], mode='bilinear', align_corners=True)[:,0,...]
return miou(logits[0], targets)
else:
return miou(logits, targets)
return miou(logits, targets)

def custom_hrnet_miou(logits, targets, eps=1e-6):

ph, pw = logits.size(2), logits.size(3)
h, w = targets.size(1), targets.size(2)
if ph != h or pw != w:
logits = F.upsample(input=logits, size=(h, w), mode='bilinear')
outputs = torch.argmax(logits, dim=1, keepdim=True).type(torch.int64)
targets = torch.unsqueeze(targets, dim=1).type(torch.int64)
outputs = torch.zeros_like(logits).scatter_(dim=1, index=outputs, src=torch.tensor(1.0)).type(torch.int8)
targets = torch.zeros_like(logits).scatter_(dim=1, index=targets, src=torch.tensor(1.0)).type(torch.int8)

inter = (outputs & targets).type(torch.float32).sum(dim=(2,3))
union = (outputs | targets).type(torch.float32).sum(dim=(2,3))
iou = inter / (union + eps)
return iou.mean()
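
`custom_hrnet_miou` resizes the logits to the target resolution, one-hot encodes both the argmax prediction and the target, and averages per-class intersection over union. The same computation can be written a little more compactly with `F.one_hot`; a sketch for comparison (equivalent logic, resize step omitted, not part of this commit):

```python
import torch
import torch.nn.functional as F

def mean_iou_sketch(logits, targets, eps=1e-6):
    """Mean IoU from hard predictions.

    logits:  (N, C, H, W) raw scores
    targets: (N, H, W) integer class indices
    """
    num_classes = logits.size(1)
    preds = torch.argmax(logits, dim=1)                                    # (N, H, W)
    preds_1h = F.one_hot(preds, num_classes).permute(0, 3, 1, 2).bool()    # (N, C, H, W)
    targets_1h = F.one_hot(targets.long(), num_classes).permute(0, 3, 1, 2).bool()
    inter = (preds_1h & targets_1h).float().sum(dim=(2, 3))                # per image, per class
    union = (preds_1h | targets_1h).float().sum(dim=(2, 3))
    return (inter / (union + eps)).mean()
```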
3 changes: 2 additions & 1 deletion models/__init__.py
@@ -1,2 +1,3 @@
from models.unet import UNet
from models.deeplabv3_plus import DeepLabV3Plus
from models.deeplabv3_plus import DeepLabV3Plus
from models.hrnet import HighResolutionNet
