Skip to content

Commit

Permalink
code refactoring across train, test interfaces; included testbenches;…
Browse files Browse the repository at this point in the history
… added trained models
  • Loading branch information
Aman Chadha authored and Aman Chadha committed Dec 1, 2019
1 parent 67dbb06 commit 7144b28
Show file tree
Hide file tree
Showing 26 changed files with 184 additions and 47 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ Data/*
outputframes/*

# Files
Train_iSeeBetter_nf.py
run.sh
DL_AMI.txt
epochs/FRVSRwithDisc.zip

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
17 changes: 4 additions & 13 deletions FRVSR_Models.py → AFRVSRModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,32 +156,24 @@ def init_hidden(self, device):

# x is a 4-d tensor of shape N×C×H×W
def forward(self, input):
def trunc(tensor):
# tensor = tensor.clone()
tensor[tensor < 0] = 0
tensor[tensor > 1] = 1
return tensor

# Apply FNet
# print(f'input.shape is {input.shape}, lastImg shape is {self.lastLrImg.shape}')
preflow = torch.cat((input, self.lastLrImg), dim=1)
flow = self.fnet(preflow)
# flow += self.lr_identity
relative_place = flow + self.lr_identity
# debug info goes here
self.EstLrImg = func.grid_sample(self.lastLrImg, relative_place.permute(0, 2, 3, 1))
# self.EstLrImg = trunc(self.EstLrImg)
# print(self.EstLrImg)
relative_placeNCHW = func.interpolate(relative_place, scale_factor=4, mode="bilinear")
# relative_placeNCHW = torch.unsqueeze(self.hr_identity, dim=0)
relative_placeNWHC = relative_placeNCHW.permute(0, 2, 3, 1) # shift c to last, as grid_sample function needs it.
relative_placeNWHC = relative_placeNCHW.permute(0, 2, 3, 1) # switch to channels-last notation for grid_sample()
afterWarp = func.grid_sample(self.EstHrImg, relative_placeNWHC)
self.afterWarp = afterWarp # for debugging, should be removed later.
depthImg = self.todepth(afterWarp)

# Apply SRNet
srInput = torch.cat((input, depthImg), dim=1)
estImg = self.srnet(srInput)
self.lastLrImg = input
self.EstHrImg = estImg
#self.EstHrImg = trunc(self.EstHrImg)
self.EstHrImg.retain_grad()
return self.EstHrImg, self.EstLrImg

Expand All @@ -194,7 +186,6 @@ def set_param(self, **kwargs):
if key == 'width':
self.width = val


class Loss(nn.Module):
def __init__(self):
super(Loss, self).__init__()
Expand Down
9 changes: 4 additions & 5 deletions checkTrain.py → AFRVSRTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.nn.functional as func
import matplotlib.pyplot as plt
import Dataset_OnlyHR
import FRVSR_Models
import AFRVSRModels
from skimage import img_as_ubyte
from skimage.util import img_as_float32

Expand Down Expand Up @@ -84,16 +84,15 @@ def psnr(img1, img2):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Test Single Video')
#parser.add_argument('--model', default='./models/FRVSR.4', type=str, help='generator model epoch name')
parser.add_argument('--model', default='./epochs/netG_epoch_4_4.pth', type=str, help='generator model epoch name')
parser.add_argument('--model', default='./epochs/netG_epoch_4_7.pth', type=str, help='AFRVSR Model')
opt = parser.parse_args()

UPSCALE_FACTOR = 4
MODEL_NAME = opt.model

with torch.no_grad():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FRVSR_Models.FRVSR(0, 0, 0)
model = AFRVSRModels.FRVSR(0, 0, 0)
model.to(device)

# for cpu
Expand All @@ -102,7 +101,7 @@ def psnr(img1, img2):
model.load_state_dict(checkpoint)
model.eval()

train_loader, val_loader = Dataset_OnlyHR.get_data_loaders(1, dataset_size=0, validation_split=1, shuffle_dataset=True)
train_loader, val_loader = Dataset_OnlyHR.get_data_loaders(batch=1, fixedIndices=0, dataset_size=0, validation_split=1, shuffle_dataset=True)

tot_psnr = 0
for idx, (lr_example, hr_example) in enumerate(val_loader, 1):
Expand Down
6 changes: 3 additions & 3 deletions FRSRGAN_Train.py → AFRVSRTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from tqdm import tqdm
import Dataset_OnlyHR
import logger
from FRVSR_Models import FRVSR
from FRVSR_Models import GeneratorLoss
from AFRVSRModels import FRVSR
from AFRVSRModels import GeneratorLoss
from SRGAN.model import Discriminator
import SRGAN.pytorch_ssim as pts

Expand Down Expand Up @@ -216,7 +216,7 @@ def saveModelParams(epoch, runningResults, validationResults={}):
data_frame = pd.DataFrame(data={'DLoss': results['DLoss'], 'GLoss': results['GLoss'], 'DScore': results['DScore'],
'GScore': results['GScore']},#, 'PSNR': results['PSNR'], 'SSIM': results['SSIM']},
index=range(1, epoch + 1))
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
data_frame.to_csv(out_path + 'AFRVSR_' + str(UPSCALE_FACTOR) + '_Train_Results.csv', index_label='Epoch')

def main():
""" Lets begin the training process! """
Expand Down
Binary file renamed models/FRVSRTest → epochs/netG_epoch_4_1.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_2.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_3.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_4.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_5.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_6.pth
Binary file not shown.
Binary file added epochs/netG_epoch_4_7.pth
Binary file not shown.
Binary file modified outputframes/idx_checktrain_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified outputframes/idx_checktrain_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pytorch
numpy
scikit-image
cv2
opencv-python
tqdm
7 changes: 5 additions & 2 deletions run.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# iSeeBetter: A Novel Approach to Video Super-Resolution using Adaptive Frame Recurrence and Generative Adversarial Networks
# [email protected]

# generate a low res random sample and apply FRVSR
python3 checkTrain.py

python3 Test_iSeeBetter.py --video out_srf_original_random_sample.mp4
# test
python3 Test_iSeeBetter.py --video FRSRVOut_LowRes_Random_Sample.mp4

# apply SRGAN
cd SRGAN
python3 test_video.py --video ../out_srf_4_out_srf_original_random_sample.mp4
python3 test_video.py --video ../FRSRVOut_LowRes_Random_Sample.mp4
5 changes: 5 additions & 0 deletions statistics/AFRVSR_4_Train_Results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Epoch,DLoss,GLoss,DScore,GScore
1,0.9984568357467651,0.12309867143630981,0.5031180381774902,0.5015749335289001
2,0.9857131838798523,0.12294773012399673,0.5101743340492249,0.49588751792907715
3,0.9748905897140503,0.1227954626083374,0.5166448354721069,0.49153539538383484
4,0.9652048349380493,0.1226423978805542,0.5226783156394958,0.48788315057754517
7 changes: 3 additions & 4 deletions Test_iSeeBetter.py → testbenches/Test_iSeeBetter_FRVSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.autograd import Variable
from torchvision.transforms import ToTensor
from tqdm import tqdm
import FRVSR_Models
import AFRVSRModels
import checkTrain

if __name__ == "__main__":
Expand All @@ -28,7 +28,7 @@
MODEL_NAME = opt.model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FRVSR_Models.FRVSR(0, 0, 0)
model = AFRVSRModels.FRVSR(0, 0, 0)

model.to(device)

Expand All @@ -44,13 +44,12 @@
lr_width = int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH))
lr_height = int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
model.set_param(batch_size=1, width=lr_width, height=lr_height)
#import pdb; pdb.set_trace()
model.init_hidden(device)

sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR),
int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)

output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.mp4'
output_sr_name = 'test_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.mp4'
sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', '4', 'V'), fps,
sr_video_size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,21 @@
"""

import argparse

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from torchvision.transforms import ToTensor
from tqdm import tqdm
import Dataset
import checkTrain
import FRVSR_Models
import AFRVSRModels

if __name__ == "__main__":
with torch.no_grad():
parser = argparse.ArgumentParser(description='Test Single Video')
parser.add_argument('--video', type=str, help='test low resolution video name')
parser.add_argument('--model', default='./models/FRVSR.X', type=str, help='generator model epoch name')
parser.add_argument('--model', default='./models/LR-5_SRN.25', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = 4
Expand All @@ -33,7 +30,7 @@
print(MODEL_NAME)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#model = FRVSR.FRVSR(0, 0, 0)
model = FRVSR_Models.SRNet(3) # testing the SRNet only
model = AFRVSRModels.SRNet(3) # testing the SRNet only

model.to(device)

Expand Down
9 changes: 3 additions & 6 deletions Train_iSeeBetter.py → testbenches/Train_iSeeBetter_FRVSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import gc
import sys
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
Expand All @@ -17,14 +16,12 @@
from SRGAN import pytorch_ssim

torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import FRVSR_Models
import AFRVSRModels
import Dataset_OnlyHR


def load_model(model_name, batch_size, width, height):
model = FRVSR_Models.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width)
model = AFRVSRModels.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width)
if model_name != '':
model_path = f'./models/{model_name}'
checkpoint = torch.load(model_path, map_location='cpu')
Expand Down Expand Up @@ -53,7 +50,7 @@ def run():
num_val_batches = len(val_loader)

flow_criterion = nn.MSELoss().to(device)
content_criterion = FRVSR_Models.Loss().to(device)
content_criterion = AFRVSRModels.Loss().to(device)

ssim_loss = pytorch_ssim.SSIM(window_size=11).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
Expand Down
144 changes: 144 additions & 0 deletions testbenches/Train_iSeeBetter_SRNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import gc
import sys
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler

torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import AFRVSRModels
import Dataset
import pytorch_ssim
from skimage.measure import compare_ssim as ssim


def load_model(model_name, batch_size, width, height):
    """Build an SRNet and optionally restore its weights from ``./models``.

    Args:
        model_name: File name of a checkpoint inside ``./models``. An empty
            string skips loading and returns a freshly initialised network.
        batch_size: Unused by this function; kept for interface compatibility.
        width: Unused by this function; kept for interface compatibility.
        height: Unused by this function; kept for interface compatibility.

    Returns:
        The ``AFRVSRModels.SRNet`` instance constructed with ``in_dim=3``.
    """
    model = AFRVSRModels.SRNet(in_dim=3)
    if model_name != '':
        model_path = f'./models/{model_name}'
        # map_location='cpu' remaps the stored tensors to the CPU on load, so
        # a checkpoint written on a GPU machine can be read without CUDA.
        checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint)
        # Report success only once the weights are actually in the model;
        # printing earlier would claim success even if loading then failed.
        print("successfully loaded the model")
    return model

def run(num_epochs=100, output_period=10, batch_size=8):
    """Train the SRNet on its own with a mean-squared-error content loss.

    The initial weights are written to ``models/AFRVSRTest`` before training,
    and a checkpoint ``models/LR-5_SRN.<epoch>`` is written after every epoch.
    No validation pass is run (the loaders are requested with
    ``validation_split=0`` and the validation loader is not used).

    Args:
        num_epochs: Number of passes over the training loader.
        output_period: Print the averaged running loss every this many batches.
        batch_size: Batch size handed to ``Dataset.get_data_loaders``.
    """
    # Function-scope import: only needed to prepare the checkpoint directory.
    import os

    # Passed through to load_model, which currently ignores them.
    width, height = 112, 64

    # Set up the device for running.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = load_model('', batch_size, width, height)
    model = model.to(device)

    # torch.save does not create missing parent directories, so make sure
    # the checkpoint directory exists before the first save.
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/AFRVSRTest")

    train_loader, _ = Dataset.get_data_loaders(batch_size, dataset_size=7000, validation_split=0)
    num_train_batches = len(train_loader)

    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        for param_group in optimizer.param_groups:
            print('Current learning rate: ' + str(param_group['lr']))
        model.train()

        for batch_num, (lr_imgs, hr_imgs) in enumerate(train_loader, 1):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            optimizer.zero_grad()

            # NOTE(review): the original comment describes lr_imgs as
            # "7 * 4 * 3 * H * W", so iterating the first axis presumably
            # yields one frame-batch at a time -- confirm against Dataset.
            batch_content_loss = 0
            for lr_img, hr_img in zip(lr_imgs, hr_imgs):
                hr_est = model(lr_img)
                # Plain MSE between the estimate and the HR target.
                batch_content_loss += torch.mean((hr_est - hr_img) ** 2)

            # The content losses are summed and back-propagated once per batch.
            loss = batch_content_loss
            loss.backward()

            print(f'content_loss {batch_content_loss}')

            optimizer.step()
            running_loss += loss.item()

            if batch_num % output_period == 0:
                print('[%d:%.2f] loss: %.3f' % (
                    epoch, batch_num * 1.0 / num_train_batches,
                    running_loss / output_period
                ))
                running_loss = 0.0
                gc.collect()

        gc.collect()
        # Save after every epoch.
        torch.save(model.state_dict(), "models/LR-5_SRN.%d" % epoch)

        gc.collect()


if __name__ == "__main__":
    # Script entry point: bracket the training run with start/end messages so
    # the console log shows where training begins and where it finishes.
    print('Starting training')
    run()
    print('Training terminated')

Loading

0 comments on commit 7144b28

Please sign in to comment.