updated project name to FRVSR-GAN and credits
Aman Chadha authored and Aman Chadha committed Dec 2, 2019
1 parent 0aa3db3 commit 15e0fdc
Showing 14 changed files with 51 additions and 44 deletions.
4 changes: 2 additions & 2 deletions AFRVSRModels.py → FRVSRGAN_Models.py
@@ -1,6 +1,6 @@
"""
This file contains implementation of FRVSR (FNet and SRNet) from https://arxiv.org/abs/1801.04590
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -137,7 +137,7 @@ def __init__(self, batch_size, lr_height, lr_width):
self.height = lr_height
self.batch_size = batch_size
self.fnet = FNet()
self.todepth = SpaceToDepth(FRVSR.SRFactor)
self.todepth = SpaceToDepth(FRVSRGAN.SRFactor)
self.srnet = SRNet(FRVSR.SRFactor * FRVSR.SRFactor * 3 + 3) # 3 is channel number

# make sure to call this before every batch train.
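For readers following the rename: in FRVSRGAN_Models.py, SRNet's input width of FRVSR.SRFactor * FRVSR.SRFactor * 3 + 3 comes from space-to-depth packing of the warped previous SR estimate concatenated with the current 3-channel LR frame. A minimal sketch of that packing, assuming SRFactor is 4 and using torch.nn.functional.pixel_unshuffle as a stand-in for the repository's SpaceToDepth module:

```python
# Hypothetical sketch (not the repository's exact SpaceToDepth) showing why SRNet's
# input is SRFactor * SRFactor * 3 + 3 channels for a 4x model.
import torch
import torch.nn.functional as F

SR_FACTOR = 4                                   # assumed value of FRVSR.SRFactor
lr = torch.randn(1, 3, 32, 32)                  # current low-resolution frame
prev_sr_warped = torch.randn(1, 3, 128, 128)    # previous SR estimate, warped by FNet's flow

# Space-to-depth: fold each 4x4 spatial block of the HR estimate into channels,
# returning to LR resolution with 4 * 4 * 3 = 48 channels.
packed = F.pixel_unshuffle(prev_sr_warped, SR_FACTOR)       # (1, 48, 32, 32)

# Concatenate with the 3-channel LR frame: 51 channels in total, matching
# SRNet(FRVSR.SRFactor * FRVSR.SRFactor * 3 + 3).
srnet_input = torch.cat([packed, lr], dim=1)                # (1, 51, 32, 32)
print(srnet_input.shape)
```

Folding each 4x4 block into channels keeps the recurrent input at LR resolution, so SRNet can fuse both sources with ordinary convolutions.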
20 changes: 10 additions & 10 deletions AFRVSRTest.py → FRVSRGAN_Test.py
@@ -1,6 +1,6 @@
"""
This file does a quick check of a trained FRVSR model on a single low resolution video source and upscales it to 4x.
[email protected]
This file does a quick check of a trained FRVSR-GAN model on a single low resolution video source and upscales it to 4x.
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -12,7 +12,7 @@
import torch.nn.functional as func
import matplotlib.pyplot as plt
import DatasetLoader
import AFRVSRModels
import FRVSRGAN_Models
from skimage import img_as_ubyte
from skimage.util import img_as_float32

@@ -85,10 +85,10 @@ def psnr(img1, img2):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Test Single Video')
# Use FR-SRGAN
parser.add_argument('--model', default='./epochs/netG_epoch_4_7.pth', type=str, help='AFRVSR Model')
parser.add_argument('--model', default='./epochs/netG_epoch_4_7.pth', type=str, help='FRVSRGAN Model')

# Use FRVSR
# parser.add_argument('--model', default='./models/FRVSR.4', type=str, help='AFRVSR Model')
# parser.add_argument('--model', default='./models/FRVSR.4', type=str, help='FRVSRGAN Model')

opt = parser.parse_args()

@@ -97,7 +97,7 @@ def psnr(img1, img2):

with torch.no_grad():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AFRVSRModels.FRVSR(0, 0, 0)
model = FRVSRGAN_Models.FRVSR(0, 0, 0)
model.to(device)

# for cpu
@@ -126,10 +126,10 @@ def psnr(img1, img2):
hr_video_size = (lr_width * UPSCALE_FACTOR, lr_height * UPSCALE_FACTOR)
lr_video_size = (lr_width, lr_height)

output_sr_name = 'AFRVSROut_' + str(UPSCALE_FACTOR) + f'_{idx}_' + 'Random_Sample.mp4'
output_gt_name = 'AFRVSROut_' + 'GroundTruth' + f'_{idx}_' + 'Random_Sample.mp4'
output_lr_name = 'AFRVSROut_' + 'LowRes' + '_' + 'Random_Sample.mp4'
output_aw_name = 'AFRVSROut_' + 'IntermediateWarp' + '_' + 'Random_Sample.mp4'
output_sr_name = 'FRVSRGAN_Out_' + str(UPSCALE_FACTOR) + f'_{idx}_' + 'Random_Sample.mp4'
output_gt_name = 'FRVSRGAN_Out_' + 'GroundTruth' + f'_{idx}_' + 'Random_Sample.mp4'
output_lr_name = 'FRVSRGAN_Out_' + 'LowRes' + '_' + 'Random_Sample.mp4'
output_aw_name = 'FRVSRGAN_Out_' + 'IntermediateWarp' + '_' + 'Random_Sample.mp4'

fourcc = cv2.VideoWriter_fourcc(*'MP4V')
hr_video_writer = cv2.VideoWriter(output_sr_name, fourcc, fps, hr_video_size)
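FRVSRGAN_Test.py writes the super-resolved, ground-truth, low-resolution, and intermediate-warp streams through cv2.VideoWriter as shown above. A hedged sketch of that output path, assuming the network emits float frames in [0, 1] with channels-first layout (an assumption, not something the diff states):

```python
# Hedged sketch of the cv2.VideoWriter output path in FRVSRGAN_Test.py.
# Frame layout and value range (float32 CHW in [0, 1]) are assumptions.
import cv2
import torch
from skimage import img_as_ubyte

UPSCALE_FACTOR = 4
fps, lr_width, lr_height = 30, 112, 64                      # placeholder video properties
hr_video_size = (lr_width * UPSCALE_FACTOR, lr_height * UPSCALE_FACTOR)

fourcc = cv2.VideoWriter_fourcc(*'MP4V')
writer = cv2.VideoWriter('FRVSRGAN_Out_sketch.mp4', fourcc, fps, hr_video_size)

for _ in range(10):                                         # stand-in for the per-frame SR loop
    sr = torch.rand(3, lr_height * UPSCALE_FACTOR, lr_width * UPSCALE_FACTOR)
    frame = img_as_ubyte(sr.permute(1, 2, 0).clamp(0, 1).numpy())   # HWC uint8 in RGB
    writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))            # OpenCV wants BGR frames
writer.release()
```

Note that OpenCV expects BGR uint8 frames whose dimensions match the (width, height) tuple passed to VideoWriter.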
13 changes: 10 additions & 3 deletions AFRVSRTrain.py → FRVSRGAN_Train.py
@@ -1,3 +1,10 @@
"""
This file trains an FRVSR-GAN model with an upscaling factor of 4x.
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""

import argparse
from math import log10
import gc
@@ -7,8 +7,8 @@
from tqdm import tqdm
import DatasetLoader
import logger
from AFRVSRModels import FRVSR
from AFRVSRModels import GeneratorLoss
from FRVSRGAN_Models import FRVSR
from FRVSRGAN_Models import GeneratorLoss
from SRGAN.model import Discriminator
import SRGAN.pytorch_ssim as pts

@@ -216,7 +223,7 @@ def saveModelParams(epoch, runningResults, validationResults={}):
data_frame = pd.DataFrame(data={'DLoss': results['DLoss'], 'GLoss': results['GLoss'], 'DScore': results['DScore'],
'GScore': results['GScore']},#, 'PSNR': results['PSNR'], 'SSIM': results['SSIM']},
index=range(1, epoch + 1))
data_frame.to_csv(out_path + 'AFRVSR_' + str(UPSCALE_FACTOR) + '_Train_Results.csv', index_label='Epoch')
data_frame.to_csv(out_path + 'FRVSRGAN__' + str(UPSCALE_FACTOR) + '_Train_Results.csv', index_label='Epoch')

def main():
""" Lets begin the training process! """
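saveModelParams in FRVSRGAN_Train.py accumulates per-epoch discriminator and generator statistics and rewrites a CSV named after the upscaling factor, as the hunk above shows. A small sketch of that logging pattern; the results container and the running-averages argument are assumptions inferred from the DataFrame columns, not copied from the file:

```python
# Hedged sketch of the per-epoch CSV logging in FRVSRGAN_Train.py.
# The structure of `results` and `running` is assumed from the columns above.
import os
import pandas as pd

UPSCALE_FACTOR = 4
out_path = 'statistics/'
results = {'DLoss': [], 'GLoss': [], 'DScore': [], 'GScore': []}

def save_epoch_stats(epoch, running):
    # Append this epoch's running averages, then rewrite the whole CSV so the
    # file on disk always reflects training progress so far.
    for key in results:
        results[key].append(running[key])
    os.makedirs(out_path, exist_ok=True)
    data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
    data_frame.to_csv(out_path + 'FRVSRGAN_' + str(UPSCALE_FACTOR) + '_Train_Results.csv',
                      index_label='Epoch')

# Example after the first epoch:
save_epoch_stats(1, {'DLoss': 0.69, 'GLoss': 0.05, 'DScore': 0.52, 'GScore': 0.48})
```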
14 changes: 7 additions & 7 deletions README.md
@@ -1,4 +1,4 @@
# FRVSR-GAN: A Novel Approach to Video Super-Resolution using Frame Recurrence and Generative Adversarial Networks
# FRVSRGAN: A Novel Approach to Video Super-Resolution using Frame Recurrence and Generative Adversarial Networks

Project for Stanford CS230: Deep Learning

@@ -20,12 +20,12 @@ To load,

## Overview

Recently, learning-based models have enhanced the performance of Single-Image Super-Resolution (SISR). However, applying SISR successively to each video frame leads to lack of temporal consistency. On the other hand, VSR models based on convolutional neural networks outperform traditional approaches in terms of image quality metrics such as Peak Signal to Noise Ratio (PSNR) and Structural SIMilarity (SSIM). While optimizing mean squared reconstruction error during training improves PSNR and SSIM, these metrics may not capture fine details in the image leading to misrepresentation of perceptual quality. We propose an Adaptive Frame Recurrent Video Super Resolution (AFRVSR) scheme that seeks to improve temporal consistency by utilizing information multiple similar adjacent frames (both future LR frames and previous SR estimates), in addition to the current frame. Further, to improve the “naturality” associated with the reconstructed image while eliminating artifacts seen with traditional algorithms, we combine the output of the AFRVSR algorithm with a Super-Resolution Generative Adversarial Network (SRGAN). The proposed idea thus not only considers spatial information in the current frame but also temporal information in the adjacent frames thereby offering superior reconstruction fidelity. Once our implementation is complete, we plan to show results on publicly available datasets that demonstrate that the proposed algorithms surpass current state-of-the-art performance in both accuracy and efficiency.
Recently, learning-based models have enhanced the performance of Single-Image Super-Resolution (SISR). However, applying SISR successively to each video frame leads to a lack of temporal consistency. On the other hand, VSR models based on convolutional neural networks outperform traditional approaches in terms of image quality metrics such as Peak Signal to Noise Ratio (PSNR) and Structural SIMilarity (SSIM). While optimizing mean squared reconstruction error during training improves PSNR and SSIM, these metrics may not capture fine details in the image, leading to misrepresentation of perceptual quality. We propose a Frame-Recurrent Video Super-Resolution GAN (FRVSR-GAN) scheme that seeks to improve temporal consistency by utilizing information from multiple similar adjacent frames (both future LR frames and previous SR estimates), in addition to the current frame. Further, to improve the “naturality” associated with the reconstructed image while eliminating artifacts seen with traditional algorithms, we combine the output of the FRVSR-GAN algorithm with a Super-Resolution Generative Adversarial Network (SRGAN). The proposed idea thus not only considers spatial information in the current frame but also temporal information in the adjacent frames, thereby offering superior reconstruction fidelity. Once our implementation is complete, we plan to show results on publicly available datasets that demonstrate that the proposed algorithms surpass current state-of-the-art performance in both accuracy and efficiency.
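The sketch below illustrates the frame-recurrent loop described in this paragraph: estimate optical flow between consecutive LR frames, warp the previous SR estimate with that flow, and fuse it with the current LR frame. It is an illustration only; estimate_flow and reconstruct are placeholders standing in for FNet and SRNet, not functions from this repository.

```python
# Illustrative, hypothetical sketch of frame-recurrent video super-resolution.
# `estimate_flow` and `reconstruct` are placeholders for FNet and SRNet.
import torch
import torch.nn.functional as F

def flow_warp(img, flow):
    """Backward-warp img (N, C, H, W) using a dense flow field (N, 2, H, W) in pixels."""
    _, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float().unsqueeze(0)        # (1, 2, H, W), (x, y) order
    coords = base + flow
    # Normalize to [-1, 1] as grid_sample expects, with the grid shaped (N, H, W, 2).
    gx = 2.0 * coords[:, 0] / (w - 1) - 1.0
    gy = 2.0 * coords[:, 1] / (h - 1) - 1.0
    grid = torch.stack((gx, gy), dim=-1)
    return F.grid_sample(img, grid, align_corners=True)

def super_resolve_video(lr_frames, estimate_flow, reconstruct, scale=4):
    """lr_frames: list of (1, 3, H, W) tensors in temporal order."""
    sr_frames = []
    _, _, h, w = lr_frames[0].shape
    prev_sr = torch.zeros(1, 3, h * scale, w * scale)    # black frame bootstraps the recurrence

    for t, lr in enumerate(lr_frames):
        flow = torch.zeros(1, 2, h, w) if t == 0 else estimate_flow(lr_frames[t - 1], lr)
        # Upsample the LR flow to HR resolution so the previous SR estimate can be warped,
        # carrying detail forward instead of re-hallucinating it every frame.
        hr_flow = F.interpolate(flow, scale_factor=scale, mode='bilinear',
                                align_corners=False) * scale
        prev_warped = flow_warp(prev_sr, hr_flow)
        sr = reconstruct(lr, prev_warped)                # SRNet-style fusion of both inputs
        sr_frames.append(sr)
        prev_sr = sr.detach()
    return sr_frames

# Quick run with dummy stand-ins for FNet and SRNet:
frames = [torch.rand(1, 3, 16, 16) for _ in range(3)]
zero_flow = lambda prev, cur: torch.zeros(1, 2, 16, 16)
bicubic_sr = lambda lr, prev: F.interpolate(lr, scale_factor=4, mode='bicubic', align_corners=False)
print(super_resolve_video(frames, zero_flow, bicubic_sr)[0].shape)   # torch.Size([1, 3, 64, 64])
```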

![adjacent frame similarity](https://github.com/amanchadha/FRVSR-GAN/blob/master/images/iSeeBetter_AFS.jpg)
![adjacent frame similarity](https://github.com/amanchadha/FRVSRGAN/blob/master/images/iSeeBetter_AFS.jpg)
Figure 1: Adjacent frame similarity

![network arch](https://github.com/amanchadha/FRVSR-GAN/blob/master/images/iSeeBetter_NNArch.jpg)
![network arch](https://github.com/amanchadha/FRVSRGAN/blob/master/images/iSeeBetter_NNArch.jpg)
Figure 2: Network architecture

## Dataset
@@ -34,7 +34,7 @@ To train and evaluate our proposed model, we used the [Vimeo90K](http://data.csa

## Results

![results](https://github.com/amanchadha/FRVSR-GAN/blob/master/images/iSeeBetter_Results.jpg)
![results](https://github.com/amanchadha/FRVSRGAN/blob/master/images/iSeeBetter_Results.jpg)

## Pretrained Model
Model trained for 7 epochs included under ```epochs/```
@@ -45,13 +45,13 @@

Train the model using (takes roughly an hour per epoch on an NVIDIA Tesla V100):

```python AFRVSRTrain.py```
```python FRVSRGAN_Train.py```

### Testing

To use the pre-trained model and test on a random video from within the dataset:

```python AFRVSRTest.py```
```python FRVSRGAN_Test.py```

## Acknowledgements

2 changes: 1 addition & 1 deletion SRGAN/Dataset.py
@@ -122,7 +122,7 @@ def get_data_loaders(batch, shuffle_dataset=True, dataset_size=0, validation_spl
print(f'lr_img shape is {lr_img.shape}, hr_img shape is {hr_img.shape}')
break

# class TestFRVSR(unittest.TestCase):
# class TestFRVSRGAN(unittest.TestCase):
# def TestGetDataLoader(self):
#

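The commented-out TestFRVSRGAN class above hints at wrapping the loader shape check in a unit test. A possible sketch, assuming get_data_loaders returns (train_loader, val_loader) and that each batch yields an (lr_img, hr_img) pair as the print statement suggests; the 4x ratio assertion reflects the project's upscaling factor:

```python
# Hypothetical completion of the commented-out TestFRVSRGAN idea; the batch
# format (lr_img, hr_img) and the 4x size ratio are assumptions, not repo facts.
import unittest
import Dataset

class TestFRVSRGAN(unittest.TestCase):
    def test_get_data_loaders(self):
        train_loader, _ = Dataset.get_data_loaders(4, dataset_size=8, validation_split=0)
        lr_img, hr_img = next(iter(train_loader))
        self.assertEqual(lr_img.shape[0], hr_img.shape[0])          # same batch size
        self.assertEqual(lr_img.shape[-1] * 4, hr_img.shape[-1])    # HR width is 4x LR width

if __name__ == '__main__':
    unittest.main()
```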
2 changes: 1 addition & 1 deletion testbenches/Dataset.py
@@ -1,6 +1,6 @@
"""
This file contains implementation of dataset classes.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
6 changes: 3 additions & 3 deletions testbenches/Test_iSeeBetter_FRVSR.py
@@ -1,6 +1,6 @@
"""
This file tests FRVSR on a single low resolution video source and upscales it to 4x.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -13,7 +13,7 @@
from torch.autograd import Variable
from torchvision.transforms import ToTensor
from tqdm import tqdm
import AFRVSRModels
import FRVSRGANModels
import checkTrain

if __name__ == "__main__":
@@ -28,7 +28,7 @@
MODEL_NAME = opt.model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AFRVSRModels.FRVSR(0, 0, 0)
model = FRVSRGANModels.FRVSR(0, 0, 0)

model.to(device)

6 changes: 3 additions & 3 deletions testbenches/Test_iSeeBetter_SRNet.py
@@ -1,6 +1,6 @@
"""
This file tests the SRNet model within FRVSR on a single low resolution video source and upscales it to 4x.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -14,7 +14,7 @@
from tqdm import tqdm
import Dataset
import checkTrain
import AFRVSRModels
import FRVSRGAN_Models

if __name__ == "__main__":
with torch.no_grad():
@@ -30,7 +30,7 @@
print(MODEL_NAME)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#model = FRVSR.FRVSR(0, 0, 0)
model = AFRVSRModels.SRNet(3) # testing the SRNet only
model = FRVSRGAN_Models.SRNet(3) # testing the SRNet only

model.to(device)

10 changes: 5 additions & 5 deletions testbenches/Train_iSeeBetter_FRVSR.py
@@ -1,6 +1,6 @@
"""
This file trains FRVSR on a single low resolution video source and upscales it to 4x.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -16,12 +16,12 @@
from SRGAN import pytorch_ssim

torch.backends.cudnn.benchmark = True
import AFRVSRModels
import Dataset_OnlyHR
import FRVSRGAN_Models
import DatasetLoader


def load_model(model_name, batch_size, width, height):
model = AFRVSRModels.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width)
model = FRVSRGANModels.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width)
if model_name != '':
model_path = f'./models/{model_name}'
checkpoint = torch.load(model_path, map_location='cpu')
@@ -50,7 +50,7 @@ def run():
num_val_batches = len(val_loader)

flow_criterion = nn.MSELoss().to(device)
content_criterion = AFRVSRModels.Loss().to(device)
content_criterion = FRVSRGANModels.Loss().to(device)

ssim_loss = pytorch_ssim.SSIM(window_size=11).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
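load_model above constructs the FRVSR network and, when a model name is supplied, restores weights from ./models/. A hedged sketch of that restore path; the 'model_state_dict' key is an assumption, since the hunk is cut off before the state is applied:

```python
# Hedged reconstruction of the load_model() pattern shown above; the checkpoint
# layout (a dict with 'model_state_dict') is assumed, not read from the repo.
import torch
import FRVSRGAN_Models

def load_model(model_name, batch_size, width, height):
    model = FRVSRGAN_Models.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width)
    if model_name != '':
        model_path = f'./models/{model_name}'
        checkpoint = torch.load(model_path, map_location='cpu')     # CPU-safe load
        state = checkpoint.get('model_state_dict', checkpoint)      # tolerate a bare state dict
        model.load_state_dict(state)
    return model
```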
6 changes: 3 additions & 3 deletions testbenches/Train_iSeeBetter_SRNet.py
@@ -9,14 +9,14 @@
torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import AFRVSRModels
import FRVSRGAN_Models
import Dataset
import pytorch_ssim
from skimage.measure import compare_ssim as ssim


def load_model(model_name, batch_size, width, height):
model = AFRVSRModels.SRNet(in_dim=3)
model = FRVSRGAN_Models.SRNet(in_dim=3)
if model_name != '':
model_path = f'./models/{model_name}'
print("successfully loaded the model")
@@ -36,7 +36,7 @@ def run():
model = load_model('', batch_size, width, height)
model = model.to(device)

torch.save(model.state_dict(), "models/AFRVSRTest")
torch.save(model.state_dict(), "models/FRVSRGAN_Test")

train_loader, val_loader = Dataset.get_data_loaders(batch_size, dataset_size=7000, validation_split=0)
num_train_batches = len(train_loader)
6 changes: 3 additions & 3 deletions testbenches/checkTrain_SRNet.py
@@ -1,6 +1,6 @@
"""
This file does a quick check of the SRNet model within FRVSR on a single low resolution video source and upscales it to 4x.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
@@ -13,7 +13,7 @@
import torch.nn.functional as func
import matplotlib.pyplot as plt
import Dataset
import AFRVSRModels
import FRVSRGANModels
from skimage import img_as_ubyte
from skimage.util import img_as_float32

@@ -92,7 +92,7 @@ def psnr(img1, img2):
MODEL_NAME = opt.model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AFRVSRModels.SRNet(3)
model = FRVSRGANModels.SRNet(3)
model.to(device)

# for cpu
2 changes: 1 addition & 1 deletion utils/ReadyDataSet.py
@@ -7,7 +7,7 @@
# https://macpaw.com/how-to/remove-ds-store-files-on-mac
# find . -name '.DS_Store' -type f -delete
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
2 changes: 1 addition & 1 deletion utils/Upscale4x1.py
@@ -1,6 +1,6 @@
"""
This file creates a 4x1 upscaled video.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
2 changes: 1 addition & 1 deletion utils/Vid4_Video.py
@@ -1,6 +1,6 @@
"""
This file contains implementation of dataset classes.
[email protected]
Aman Chadha | [email protected]
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018
"""
