Skip to content

Commit

Permalink
init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cavalleria committed May 20, 2020
1 parent e90344c commit 9eaf9e3
Show file tree
Hide file tree
Showing 22 changed files with 1,391 additions and 499 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,4 @@ datasets/
.vscode/
.idea/
__MACOSX/
.history/
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# PyTorch Template Project
PyTorch deep learning project made easy.
# Human Segmentation in Pytorch

<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

* [PyTorch Template Project](#pytorch-template-project)
* [Human Segmentation in Pytorch](#pytorch-template-project)
* [Requirements](#requirements)
* [Features](#features)
* [Folder Structure](#folder-structure)
Expand Down Expand Up @@ -39,9 +38,9 @@ PyTorch deep learning project made easy.
* tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization))

## Features
* Clear folder structure which is suitable for many deep learning projects.
* `.json` config file support for convenient parameter tuning.
* Customizable command line options for more convenient parameter tuning.
* A clear and easy to navigate structure.
* A `.json` config file with a lot of possibilities for parameter tuning.
* Supports various models, losses, Lr schedulers, data augmentations.
* Checkpoint saving and resuming.
* Abstract base classes for faster development:
* `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
Expand All @@ -50,7 +49,7 @@ PyTorch deep learning project made easy.

## Folder Structure
```
pytorch-template/
humanseg.pytorch/
├── train.py - main script to start training
├── test.py - evaluation of trained model
Expand Down
9 changes: 9 additions & 0 deletions base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *


#------------------------------------------------------------------------------
# Bag of Inferences
#------------------------------------------------------------------------------
from base.base_inference import (
BaseInference,
VideoInference
)
30 changes: 15 additions & 15 deletions base/base_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


#------------------------------------------------------------------------------
# BaseDataLoader
#------------------------------------------------------------------------------
class BaseDataLoader(DataLoader):
"""
Base class for all data loaders
"""
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
self.validation_split = validation_split
self.shuffle = shuffle

self.batch_idx = 0
self.n_samples = len(dataset)

Expand All @@ -23,37 +26,34 @@ def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers,
'shuffle': self.shuffle,
'collate_fn': collate_fn,
'num_workers': num_workers
}
super().__init__(sampler=self.sampler, **self.init_kwargs)
}
super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)


def _split_sampler(self, split):
if split == 0.0:
return None, None

idx_full = np.arange(self.n_samples)

np.random.seed(0)
np.random.seed(0)
np.random.shuffle(idx_full)

if isinstance(split, int):
assert split > 0
assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
len_valid = split
else:
len_valid = int(self.n_samples * split)
len_valid = int(self.n_samples * split)

valid_idx = idx_full[0:len_valid]
train_idx = np.delete(idx_full, np.arange(0, len_valid))

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# turn off shuffle option which is mutually exclusive with sampler
self.shuffle = False
self.n_samples = len(train_idx)

return train_sampler, valid_sampler


def split_validation(self):
if self.valid_sampler is None:
return None
Expand Down
173 changes: 173 additions & 0 deletions base/base_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import cv2, torch
import numpy as np
from time import time
from torch.nn import functional as F


#------------------------------------------------------------------------------
# BaseInference
#------------------------------------------------------------------------------
class BaseInference(object):
def __init__(self, model, color_f=[255,0,0], color_b=[0,0,255], kernel_sz=25, sigma=0, background_path=None):
self.model = model
self.color_f = color_f
self.color_b = color_b
self.kernel_sz = kernel_sz
self.sigma = sigma
self.background_path = background_path
if background_path is not None:
self.background = cv2.imread(background_path)[...,::-1]
self.background = self.background.astype(np.float32)


def load_image(self):
raise NotImplementedError


def preprocess(self, image, *args):
raise NotImplementedError


def predict(self, X):
raise NotImplementedError


def draw_matting(self, image, mask):
"""
image (np.uint8) shape (H,W,3)
mask (np.float32) range from 0 to 1, shape (H,W)
"""
mask = 255*(1.0-mask)
mask = np.expand_dims(mask, axis=2)
mask = np.tile(mask, (1,1,3))
mask = mask.astype(np.uint8)
image_alpha = cv2.add(image, mask)
return image_alpha


def draw_transperency(self, image, mask):
"""
image (np.uint8) shape (H,W,3)
mask (np.float32) range from 0 to 1, shape (H,W)
"""
mask = mask.round()
alpha = np.zeros_like(image, dtype=np.uint8)
alpha[mask==1, :] = self.color_f
alpha[mask==0, :] = self.color_b
image_alpha = cv2.add(image, alpha)
return image_alpha


def draw_background(self, image, mask):
"""
image (np.uint8) shape (H,W,3)
mask (np.float32) range from 0 to 1, shape (H,W)
"""
image = image.astype(np.float32)
mask_filtered = cv2.GaussianBlur(mask, (self.kernel_sz, self.kernel_sz), self.sigma)
mask_filtered = np.expand_dims(mask_filtered, axis=2)
mask_filtered = np.tile(mask_filtered, (1,1,3))

image_alpha = image*mask_filtered + self.background*(1-mask_filtered)
return image_alpha.astype(np.uint8)


#------------------------------------------------------------------------------
# VideoInference
#------------------------------------------------------------------------------
class VideoInference(BaseInference):
def __init__(self, model, video_path, input_size, use_cuda=True, draw_mode='matting',
color_f=[255,0,0], color_b=[0,0,255], kernel_sz=25, sigma=0, background_path=None):

# Initialize
super(VideoInference, self).__init__(model, color_f, color_b, kernel_sz, sigma, background_path)
self.input_size = input_size
self.use_cuda = use_cuda
self.draw_mode = draw_mode
if draw_mode=='matting':
self.draw_func = self.draw_matting
elif draw_mode=='transperency':
self.draw_func = self.draw_transperency
elif draw_mode=='background':
self.draw_func = self.draw_background
else:
raise NotImplementedError

# Preprocess
self.mean = np.array([0.485,0.456,0.406])[None,None,:]
self.std = np.array([0.229,0.224,0.225])[None,None,:]

# Read video
self.video_path = video_path
self.cap = cv2.VideoCapture(video_path)
_, frame = self.cap.read()
self.H, self.W = frame.shape[:2]


def load_image(self):
_, frame = self.cap.read()
image = frame[...,::-1]
return image


def preprocess(self, image):
image = cv2.resize(image, (self.input_size,self.input_size), interpolation=cv2.INTER_LINEAR)
image = image.astype(np.float32) / 255.0
image = (image - self.mean) / self.std
X = np.transpose(image, axes=(2, 0, 1))
X = np.expand_dims(X, axis=0)
X = torch.tensor(X, dtype=torch.float32)
return X


def predict(self, X):
with torch.no_grad():
if self.use_cuda:
mask = self.model(X.cuda())
mask = F.interpolate(mask, size=(self.H, self.W), mode='bilinear', align_corners=True)
mask = F.softmax(mask, dim=1)
mask = mask[0,1,...].cpu().numpy()
else:
mask = self.model(X)
mask = F.interpolate(mask, size=(self.H, self.W), mode='bilinear', align_corners=True)
mask = F.softmax(mask, dim=1)
mask = mask[0,1,...].numpy()
return mask


def run(self):
while(True):
# Read frame from camera
start_time = time()
image = self.load_image()
read_cam_time = time()

# Preprocess
X = self.preprocess(image)
preproc_time = time()

# Predict
mask = self.predict(X)
predict_time = time()

# Draw result
image_alpha = self.draw_func(image, mask)
draw_time = time()

# Wait for interupt
cv2.imshow('webcam', image_alpha[..., ::-1])
if cv2.waitKey(1) & 0xFF == ord('q'):
break

# Print runtime
read = read_cam_time-start_time
preproc = preproc_time-read_cam_time
pred = predict_time-preproc_time
draw = draw_time-predict_time
total = read + preproc + pred + draw
fps = 1 / total
print("read: %.3f [s]; preproc: %.3f [s]; pred: %.3f [s]; draw: %.3f [s]; total: %.3f [s]; fps: %.2f [Hz]" %
(read, preproc, pred, draw, total, fps))
Loading

0 comments on commit 9eaf9e3

Please sign in to comment.