
Commit

update
cavalleria committed May 21, 2020
1 parent 1853e4d commit 822cf00
Showing 8 changed files with 495 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -396,4 +396,4 @@ Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before c
This project is licensed under the MIT License. See LICENSE for more details
## Acknowledgements
This project is inspired by the project [Human-Segmentation-PyTorch](https://github.com/thuyngch/Human-Segmentation-PyTorch) and [pytorch_segmentation](https://github.com/yassouali/pytorch_segmentation)
This project is inspired by the projects [pytorch-template](https://github.com/victoresque/pytorch-template), [Human-Segmentation-PyTorch](https://github.com/thuyngch/Human-Segmentation-PyTorch) and [pytorch_segmentation](https://github.com/yassouali/pytorch_segmentation)
9 changes: 0 additions & 9 deletions base/base_data_loader.py
@@ -1,15 +1,8 @@
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


#------------------------------------------------------------------------------
# BaseDataLoader
#------------------------------------------------------------------------------
class BaseDataLoader(DataLoader):
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
self.validation_split = validation_split
@@ -29,7 +22,6 @@ def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers,
}
super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)


def _split_sampler(self, split):
if split == 0.0:
return None, None
@@ -53,7 +45,6 @@ def _split_sampler(self, split):

return train_sampler, valid_sampler


def split_validation(self):
if self.valid_sampler is None:
return None
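For orientation, here is a minimal, self-contained sketch of the sampler-based train/validation split that `_split_sampler` performs. The toy `TensorDataset`, the batch size and the 0.2 split ratio are illustrative assumptions, not part of this commit.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy stand-in for a real segmentation dataset (assumption).
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 2, (100,)))

split = 0.2                              # fraction of samples held out for validation
idx_full = np.arange(len(dataset))
np.random.shuffle(idx_full)
len_valid = int(len(dataset) * split)
valid_idx, train_idx = idx_full[:len_valid], idx_full[len_valid:]

# Two loaders over the same dataset with disjoint indices; shuffle stays off
# because the samplers already randomize the order.
train_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(valid_idx))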
15 changes: 0 additions & 15 deletions base/base_model.py
@@ -1,17 +1,10 @@
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import torch
import torch.nn as nn

import torchsummary
import os, warnings, sys
from utils import add_flops_counting_methods, flops_to_string


#------------------------------------------------------------------------------
# BaseModel
#------------------------------------------------------------------------------
class BaseModel(nn.Module):
def __init__(self):
super(BaseModel, self).__init__()
@@ -64,10 +57,6 @@ def load_pretrained_model(self, pretrained):
state_dict.update(model_dict)
self.load_state_dict(state_dict)


#------------------------------------------------------------------------------
# BaseBackbone
#------------------------------------------------------------------------------
class BaseBackbone(BaseModel):
def __init__(self):
super(BaseBackbone, self).__init__()
@@ -99,10 +88,6 @@ def load_pretrained_model_extended(self, pretrained):
state_dict.update(model_dict)
self.load_state_dict(state_dict)


#------------------------------------------------------------------------------
# BaseBackboneWrapper
#------------------------------------------------------------------------------
class BaseBackboneWrapper(BaseBackbone):
def __init__(self):
super(BaseBackboneWrapper, self).__init__()
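For reference, a minimal sketch of the general filtered-loading pattern that `load_pretrained_model` and `load_pretrained_model_extended` implement: keep only pretrained tensors whose name and shape match the current model, then load the merged dict. The two toy models and the matching rule are assumptions for illustration.

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 32, 3))
pretrained = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 64, 3)).state_dict()

model_dict = model.state_dict()
# Keep only pretrained tensors whose key exists in the model and whose shape matches.
matched = {k: v for k, v in pretrained.items() if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(matched)
model.load_state_dict(model_dict)
print("loaded %d / %d tensors from the pretrained weights" % (len(matched), len(model_dict)))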
17 changes: 2 additions & 15 deletions base/base_trainer.py
@@ -1,18 +1,9 @@
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
from time import time
import os, math, json, logging, datetime, torch
from utils.visualization import WriterTensorboardX


#------------------------------------------------------------------------------
# Class of BaseTrainer
#------------------------------------------------------------------------------
class BaseTrainer:
"""
Base class for all trainers
"""

def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger=None):
self.config = config

@@ -69,7 +60,6 @@ def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger
if resume:
self._resume_checkpoint(resume)


def _prepare_device(self, n_gpu_use):
"""
setup GPU device if available, move model into configured device
@@ -86,7 +76,6 @@ def _prepare_device(self, n_gpu_use):
list_ids = list(range(n_gpu_use))
return device, list_ids


def train(self):
for epoch in range(self.start_epoch, self.epochs + 1):
self.logger.info("\n----------------------------------------------------------------")
@@ -131,7 +120,6 @@ def train(self):
# Save checkpoint
self._save_checkpoint(epoch, save_best=best)


def _train_epoch(self, epoch):
"""
Training logic for an epoch
@@ -176,7 +164,6 @@ def _save_checkpoint(self, epoch, save_best=False):
else:
self.logger.info("Monitor is not improved from %f" % (self.monitor_best))


def _resume_checkpoint(self, resume_path):
"""
Resume from saved checkpoints
@@ -200,6 +187,6 @@ def _resume_checkpoint(self, resume_path):
# 'Optimizer parameters not being resumed.')
# else:
# self.optimizer.load_state_dict(checkpoint['optimizer'])

self.train_logger = checkpoint['logger']
self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch-1))
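As a reading aid, a minimal sketch of the checkpoint round trip that `_save_checkpoint` / `_resume_checkpoint` rely on. The file name and the exact set of fields (beyond 'state_dict', which the inference script below also reads) are simplified assumptions.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save: bundle what is needed to continue training from the next epoch.
checkpoint = {
    'epoch': 10,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'monitor_best': 0.873,
}
torch.save(checkpoint, 'checkpoint_epoch10.pth')

# Resume: restore weights and bookkeeping, then continue at epoch + 1.
ckpt = torch.load('checkpoint_epoch10.pth', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1
monitor_best = ckpt['monitor_best']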
3 changes: 0 additions & 3 deletions data_loader/data_loaders.py
@@ -1,6 +1,3 @@
import warnings
warnings.filterwarnings('ignore')

import os
import cv2
import numpy as np
109 changes: 109 additions & 0 deletions tools/video_infer.py
@@ -0,0 +1,109 @@
import cv2, torch, argparse
from time import time
import numpy as np
from torch.nn import functional as F

from models import UNet
from data_loader import transforms
from utils import utils

parser = argparse.ArgumentParser(description="Arguments for the script")
parser.add_argument('--use_cuda', action='store_true', default=False, help='Use GPU acceleration')
parser.add_argument('--bg', type=str, default=None, help='Path to the background image file')
parser.add_argument('--watch', action='store_true', default=False, help='Show the result live')
parser.add_argument('--input_sz', type=int, default=320, help='Input size')
parser.add_argument('--checkpoint', type=str, default="", help='Path to the trained model file')
parser.add_argument('--video', type=str, default="", help='Path to the input video')
parser.add_argument('--output', type=str, default="", help='Path to the output video')

args = parser.parse_args()

# Video input
cap = cv2.VideoCapture(args.video)
_, frame = cap.read()
H, W = frame.shape[:2]

# Video output
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
out = cv2.VideoWriter(args.output, fourcc, 30, (W,H))
font = cv2.FONT_HERSHEY_SIMPLEX

# Background
if args.bg is not None:
    BACKGROUND = cv2.imread(args.bg)[...,::-1]
    BACKGROUND = cv2.resize(BACKGROUND, (W,H), interpolation=cv2.INTER_LINEAR)
    KERNEL_SZ = 25
    SIGMA = 0

# Alpha transparency
else:
    COLOR1 = [255, 0, 0]
    COLOR2 = [0, 0, 255]

model = UNet(
    backbone="mobilenetv2",
    num_classes=2,
    pretrained_backbone=None
)
if args.use_cuda:
    model = model.cuda()
trained_dict = torch.load(args.checkpoint, map_location="cpu")['state_dict']
model.load_state_dict(trained_dict, strict=False)
model.eval()

i = 0
while cap.isOpened():
    # Read frame from the video; stop when no frame is left.
    start_time = time()
    ret, frame = cap.read()
    if not ret:
        break
    # image = cv2.transpose(frame[...,::-1])
    image = frame[...,::-1]
    h, w = image.shape[:2]
    read_cam_time = time()

    # Predict mask
    X, pad_up, pad_left, h_new, w_new = utils.preprocessing(image, expected_size=args.input_sz, pad_value=0)
    preproc_time = time()
    with torch.no_grad():
        if args.use_cuda:
            mask = model(X.cuda())
            mask = mask[..., pad_up: pad_up+h_new, pad_left: pad_left+w_new]
            mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=True)
            mask = F.softmax(mask, dim=1)
            mask = mask[0,1,...].cpu().numpy()
        else:
            mask = model(X)
            mask = mask[..., pad_up: pad_up+h_new, pad_left: pad_left+w_new]
            mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=True)
            mask = F.softmax(mask, dim=1)
            mask = mask[0,1,...].numpy()
    predict_time = time()

    # Draw result
    if args.bg is None:
        image_alpha = utils.draw_matting(image, mask)
        # image_alpha = utils.draw_transperency(image, mask, COLOR1, COLOR2)
    else:
        image_alpha = utils.draw_fore_to_back(image, mask, BACKGROUND, kernel_sz=KERNEL_SZ, sigma=SIGMA)
    draw_time = time()

    # Print runtime
    read = read_cam_time-start_time
    preproc = preproc_time-read_cam_time
    pred = predict_time-preproc_time
    draw = draw_time-predict_time
    total = read + preproc + pred + draw
    fps = 1 / total
    print("read: %.3f [s]; preproc: %.3f [s]; pred: %.3f [s]; draw: %.3f [s]; total: %.3f [s]; fps: %.2f [Hz]" %
        (read, preproc, pred, draw, total, fps))

    # Wait for interrupt
    cv2.putText(image_alpha, "%.2f [fps]" % (fps), (10, 50), font, 1.5, (0, 255, 0), 2, cv2.LINE_AA)
    out.write(image_alpha[..., ::-1])
    if args.watch:
        cv2.imshow('webcam', image_alpha[..., ::-1])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

cap.release()
out.release()
cv2.destroyAllWindows()
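For context, a hedged single-image version of the same pipeline without the project's utils helpers: resize-and-pad the frame to the network input size, run the model, softmax, then crop and rescale the person-probability map back to the original resolution. The pad-to-square strategy, the plain /255 normalization and the file paths are assumptions for illustration, not necessarily the project's exact preprocessing.

import cv2, torch
import numpy as np
from torch.nn import functional as F
from models import UNet

input_sz = 320
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)   # path is hypothetical
h, w = image.shape[:2]

# Resize so the longer side equals input_sz, then zero-pad to a square
# (assumed to approximate what utils.preprocessing does).
scale = float(input_sz) / max(h, w)
h_new, w_new = int(round(h * scale)), int(round(w * scale))
resized = cv2.resize(image, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
pad_up, pad_left = (input_sz - h_new) // 2, (input_sz - w_new) // 2
canvas = np.zeros((input_sz, input_sz, 3), dtype=np.uint8)
canvas[pad_up:pad_up + h_new, pad_left:pad_left + w_new] = resized
X = torch.from_numpy(canvas.transpose(2, 0, 1)).unsqueeze(0).float() / 255.0   # normalization is an assumption

model = UNet(backbone="mobilenetv2", num_classes=2, pretrained_backbone=None)
model.load_state_dict(torch.load("model_best.pth", map_location="cpu")['state_dict'], strict=False)
model.eval()

with torch.no_grad():
    out = model(X)
    out = out[..., pad_up:pad_up + h_new, pad_left:pad_left + w_new]
    out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
    mask = F.softmax(out, dim=1)[0, 1].numpy()                    # person probability in [0, 1]

cv2.imwrite("mask.png", (mask * 255).astype(np.uint8))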
