Skip to content

Commit

Permalink
run unet/mobilenetv2 ok
Browse files Browse the repository at this point in the history
  • Loading branch information
cavalleria committed May 21, 2020
1 parent 822cf00 commit 7cd607f
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* [Trainer](#trainer)
* [Model](#model)
* [Loss](#loss)
* [metrics](#metrics)
* [Metrics](#metrics)
* [Additional logging](#additional-logging)
* [Validation data](#validation-data)
* [Checkpoints](#checkpoints)
Expand Down
8 changes: 8 additions & 0 deletions base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *

from base.base_inference import (
BaseInference,
VideoInference
)
10 changes: 6 additions & 4 deletions config/config_unet.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"train_loader": {
"type": "SegmentationDataLoader",
"args":{
"pairs_file": "dataset/train_mask.txt",
"prefix": "/workspace/data",
"pairs_file": "../data/train_mask.txt",
"color_channel": "RGB",
"resize": 320,
"padding_value": 0,
Expand All @@ -36,7 +37,8 @@
"valid_loader": {
"type": "SegmentationDataLoader",
"args":{
"pairs_file": "dataset/valid_mask.txt",
"prefix": "/workspace/data",
"pairs_file": "../data/valid_mask.txt",
"color_channel": "RGB",
"resize": 320,
"padding_value": 0,
Expand Down Expand Up @@ -74,7 +76,7 @@

"trainer": {
"epochs": 80,
"save_dir": "/workspace/checkpoints/",
"save_dir": "/workspace/models/",
"save_freq": null,
"verbosity": 2,
"monitor": "valid_loss",
Expand All @@ -83,6 +85,6 @@

"visualization":{
"tensorboardX": true,
"log_dir": "/workspace/checkpoints/"
"log_dir": "/workspace/models/"
}
}
10 changes: 7 additions & 3 deletions data_loader/data_loaders.py → data_loader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import torch
from torch.utils.data import Dataset, DataLoader

from dataloaders import transforms
from data_loader import transforms
# import transforms
class SegmentationDataLoader(object):
def __init__(self, pairs_file, color_channel="RGB", resize=224, padding_value=0,
def __init__(self, prefix, pairs_file, color_channel="RGB", resize=224, padding_value=0,
crop_range=[0.75, 1.0], flip_hor=0.5, rotate=0.3, angle=10, noise_std=5,
normalize=True, one_hot=False, is_training=True,
shuffle=True, batch_size=1, n_workers=1, pin_memory=True):

# Storage parameters
super(SegmentationDataLoader, self).__init__()
self.prefix = prefix
self.pairs_file = pairs_file
self.color_channel = color_channel
self.resize = resize
Expand All @@ -35,6 +36,7 @@ def __init__(self, pairs_file, color_channel="RGB", resize=224, padding_value=0,

# Dataset
self.dataset = SegmentationDataset(
prefix = self.prefix,
pairs_file=self.pairs_file,
color_channel=self.color_channel,
resize=self.resize,
Expand Down Expand Up @@ -64,7 +66,7 @@ class SegmentationDataset(Dataset):
The dataset requires label is a grayscale image with value {0,1,...,C-1},
where C is the number of classes.
"""
def __init__(self, pairs_file, color_channel="RGB", resize=512, padding_value=0,
def __init__(self, prefix, pairs_file, color_channel="RGB", resize=512, padding_value=0,
is_training=True, noise_std=5, crop_range=[0.75, 1.0], flip_hor=0.5, rotate=0.3, angle=10,
one_hot=False, normalize=True, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):

Expand All @@ -79,6 +81,8 @@ def __init__(self, pairs_file, color_channel="RGB", resize=512, padding_value=0,
error_flg = False
for line in lines:
image_file, label_file = line
image_file = os.path.join(prefix, image_file.strip())
label_file = os.path.join(prefix, label_file.strip())
if not os.path.exists(image_file):
print("%s does not exist!" % (image_file))
error_flg = True
Expand Down
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from models.unet import UNet
35 changes: 17 additions & 18 deletions models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from functools import reduce

from base.base_model import BaseModel
from nets.mobilenetv2 import MobileNetV2
from nets.resnet import ResNet

from nets import mobilenetv2
from nets import resnet

class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, block_unit):
Expand All @@ -26,23 +25,23 @@ def __init__(self, backbone="mobilenetv2", num_classes=2, pretrained_backbone=No
if backbone=='mobilenetv2':
alpha = 1.0
expansion = 6
self.backbone = MobileNetV2.MobileNetV2(alpha=alpha, expansion=expansion, num_classes=None)
self.backbone = mobilenetv2.MobileNetV2(alpha=alpha, expansion=expansion, num_classes=None)
self._run_backbone = self._run_backbone_mobilenetv2
# Stage 1
channel1 = MobileNetV2._make_divisible(int(96*alpha), 8)
block_unit = MobileNetV2.InvertedResidual(2*channel1, channel1, 1, expansion)
channel1 = mobilenetv2._make_divisible(int(96*alpha), 8)
block_unit = mobilenetv2.InvertedResidual(2*channel1, channel1, 1, expansion)
self.decoder1 = DecoderBlock(self.backbone.last_channel, channel1, block_unit)
# Stage 2
channel2 = MobileNetV2._make_divisible(int(32*alpha), 8)
block_unit = MobileNetV2.InvertedResidual(2*channel2, channel2, 1, expansion)
channel2 = mobilenetv2._make_divisible(int(32*alpha), 8)
block_unit = mobilenetv2.InvertedResidual(2*channel2, channel2, 1, expansion)
self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
# Stage 3
channel3 = MobileNetV2._make_divisible(int(24*alpha), 8)
block_unit = MobileNetV2.InvertedResidual(2*channel3, channel3, 1, expansion)
channel3 = mobilenetv2._make_divisible(int(24*alpha), 8)
block_unit = mobilenetv2.InvertedResidual(2*channel3, channel3, 1, expansion)
self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
# Stage 4
channel4 = MobileNetV2._make_divisible(int(16*alpha), 8)
block_unit = MobileNetV2.InvertedResidual(2*channel4, channel4, 1, expansion)
channel4 = mobilenetv2._make_divisible(int(16*alpha), 8)
block_unit = mobilenetv2.InvertedResidual(2*channel4, channel4, 1, expansion)
self.decoder4 = DecoderBlock(channel3, channel4, block_unit)

elif 'resnet' in backbone:
Expand All @@ -57,28 +56,28 @@ def __init__(self, backbone="mobilenetv2", num_classes=2, pretrained_backbone=No
else:
raise NotImplementedError
filters = 64
self.backbone = ResNet.get_resnet(n_layers, num_classes=None)
self.backbone = resnet.get_resnet(n_layers, num_classes=None)
self._run_backbone = self._run_backbone_resnet
block = ResNet.BasicBlock if (n_layers==18 or n_layers==34) else ResNet.Bottleneck
block = resnet.BasicBlock if (n_layers==18 or n_layers==34) else ResNet.Bottleneck
# Stage 1
last_channel = 8*filters if (n_layers==18 or n_layers==34) else 32*filters
channel1 = 4*filters if (n_layers==18 or n_layers==34) else 16*filters
downsample = nn.Sequential(ResNet.conv1x1(2*channel1, channel1), nn.BatchNorm2d(channel1))
downsample = nn.Sequential(resnet.conv1x1(2*channel1, channel1), nn.BatchNorm2d(channel1))
block_unit = block(2*channel1, int(channel1/block.expansion), 1, downsample)
self.decoder1 = DecoderBlock(last_channel, channel1, block_unit)
# Stage 2
channel2 = 2*filters if (n_layers==18 or n_layers==34) else 8*filters
downsample = nn.Sequential(ResNet.conv1x1(2*channel2, channel2), nn.BatchNorm2d(channel2))
downsample = nn.Sequential(resnet.conv1x1(2*channel2, channel2), nn.BatchNorm2d(channel2))
block_unit = block(2*channel2, int(channel2/block.expansion), 1, downsample)
self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
# Stage 3
channel3 = filters if (n_layers==18 or n_layers==34) else 4*filters
downsample = nn.Sequential(ResNet.conv1x1(2*channel3, channel3), nn.BatchNorm2d(channel3))
downsample = nn.Sequential(resnet.conv1x1(2*channel3, channel3), nn.BatchNorm2d(channel3))
block_unit = block(2*channel3, int(channel3/block.expansion), 1, downsample)
self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
# Stage 4
channel4 = filters
downsample = nn.Sequential(ResNet.conv1x1(2*channel4, channel4), nn.BatchNorm2d(channel4))
downsample = nn.Sequential(resnet.conv1x1(2*channel4, channel4), nn.BatchNorm2d(channel4))
block_unit = block(2*channel4, int(channel4/block.expansion), 1, downsample)
self.decoder4 = DecoderBlock(channel3, channel4, block_unit)

Expand Down
1 change: 1 addition & 0 deletions nets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from nets.mobilenetv2 import MobileNetV2
4 changes: 0 additions & 4 deletions nets/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,12 @@ def __init__(self, alpha=1.0, expansion=6, num_classes=1000):
def forward(self, x, feature_names=None):
# Stage1
x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)

# Stage2
x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)

# Stage3
x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)

# Stage4
x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)

# Stage5
x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)

Expand Down
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os, json, argparse,
import os, json, argparse
import torch
import models as module_arch
import evaluation.losses as module_loss
import evaluation.metrics as module_metric
import dataloaders.dataloader as module_data
import data_loader.dataloader as module_data

from utils.logger import Logger
from trainer.trainer import Trainer

def get_instance(module, name, config, *args):
print(module, config[name]['type'])
return getattr(module, config[name]['type'])(*args, **config[name]['args'])

def main(config, resume):
Expand Down
2 changes: 2 additions & 0 deletions train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export PYTHONDONTWRITEBYTECODE=False
python train.py --config config/config_unet.json --device 7
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .flops_counter import add_flops_counting_methods, flops_to_string

87 changes: 87 additions & 0 deletions utils/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import importlib
import warnings
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt

class WriterTensorboardX():
def __init__(self, writer_dir, logger, enable):
self.writer = None
if enable:
log_path = writer_dir
try:
self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_path)
except ModuleNotFoundError:
message = """TensorboardX visualization is configured to use, but currently not installed on this machine. Please install the package by 'pip install tensorboardx' command or turn off the option in the 'config.json' file."""
warnings.warn(message, UserWarning)
logger.warn(message)
self.step = 0
self.tensorboard_writer_ftns = ['add_scalar', 'add_scalars', 'add_image', 'add_audio', 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding']

def set_step(self, step):
self.step = step

def __getattr__(self, name):
"""
If visualization is configured to use:
return add_data() methods of tensorboard with additional information (step, tag) added.
Otherwise:
return blank function handle that does nothing
"""
if name in self.tensorboard_writer_ftns:
add_data = getattr(self.writer, name, None)
def wrapper(tag, data, *args, **kwargs):
if add_data is not None:
add_data('{}'.format(tag), data, self.step, *args, **kwargs)
return wrapper
else:
# default action for returning methods defined in this class, set_step() for instance.
try:
attr = object.__getattr__(name)
except AttributeError:
raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))
return attr


def plot_tensorboard(train_file, valid_file, scalar_names, set_grid=False):
# Read Tensorboard files
train_event_acc = EventAccumulator(train_file)
valid_event_acc = EventAccumulator(valid_file)
train_event_acc.Reload()
valid_event_acc.Reload()

# Get scalar values
train_scalars, valid_scalars = {}, {}
for scalar_name in scalar_names:
train_scalars[scalar_name] = train_event_acc.Scalars(scalar_name)
valid_scalars[scalar_name] = valid_event_acc.Scalars(scalar_name)

# Convert to list
n_epochs = len(train_scalars["loss"])
epochs = [train_scalars["loss"][i][1] for i in range(n_epochs)]

train_lists, valid_lists = {}, {}
for scalar_name in scalar_names:
train_lists[scalar_name] = [train_scalars[scalar_name][i][2] for i in range(n_epochs)]
valid_lists[scalar_name] = [valid_scalars[scalar_name][i][2] for i in range(n_epochs)]

# Plot
for scalar_name in scalar_names:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
if set_grid:
ax.set_xticks(epochs)

ax.plot(epochs, train_lists[scalar_name], label='train')
ax.plot(epochs, valid_lists[scalar_name], label='valid')

plt.xlabel("epochs")
plt.ylabel(scalar_name)
plt.legend(frameon=True)
plt.grid(True)
plt.show()


if __name__ == '__main__':
train_file = "checkpoints/runs/Mnist_LeNet/1125_110943/train/events.out.tfevents.1543118983.antiaegis"
valid_file = "checkpoints/runs/Mnist_LeNet/1125_110943/valid/events.out.tfevents.1543118983.antiaegis"
plot_tensorboard(train_file, valid_file, ["loss", "my_metric", "my_metric2"])

0 comments on commit 7cd607f

Please sign in to comment.