-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
updated project name to FRVSR-GAN and credits
- Loading branch information
Aman Chadha
authored and
Aman Chadha
committed
Dec 2, 2019
1 parent
0aa3db3
commit 15e0fdc
Showing
14 changed files
with
51 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file contains implementation of FRVSR (FNet and SRNet) from https://arxiv.org/abs/1801.04590 | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -137,7 +137,7 @@ def __init__(self, batch_size, lr_height, lr_width): | |
self.height = lr_height | ||
self.batch_size = batch_size | ||
self.fnet = FNet() | ||
self.todepth = SpaceToDepth(FRVSR.SRFactor) | ||
self.todepth = SpaceToDepth(FRVSRGAN.SRFactor) | ||
self.srnet = SRNet(FRVSR.SRFactor * FRVSR.SRFactor * 3 + 3) # 3 is channel number | ||
|
||
# make sure to call this before every batch train. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file does a quick check of a trained FRVSR model on a single low resolution video source and upscales it to 4x. | ||
[email protected] | ||
This file does a quick check of a trained FRVSR-GAN model on a single low resolution video source and upscales it to 4x. | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -12,7 +12,7 @@ | |
import torch.nn.functional as func | ||
import matplotlib.pyplot as plt | ||
import DatasetLoader | ||
import AFRVSRModels | ||
import FRVSRGAN_Models | ||
from skimage import img_as_ubyte | ||
from skimage.util import img_as_float32 | ||
|
||
|
@@ -85,10 +85,10 @@ def psnr(img1, img2): | |
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description='Test Single Video') | ||
# Use FR-SRGAN | ||
parser.add_argument('--model', default='./epochs/netG_epoch_4_7.pth', type=str, help='AFRVSR Model') | ||
parser.add_argument('--model', default='./epochs/netG_epoch_4_7.pth', type=str, help='FRVSRGAN Model') | ||
|
||
# Use FRVSR | ||
# parser.add_argument('--model', default='./models/FRVSR.4', type=str, help='AFRVSR Model') | ||
# parser.add_argument('--model', default='./models/FRVSR.4', type=str, help='FRVSRGAN Model') | ||
|
||
opt = parser.parse_args() | ||
|
||
|
@@ -97,7 +97,7 @@ def psnr(img1, img2): | |
|
||
with torch.no_grad(): | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
model = AFRVSRModels.FRVSR(0, 0, 0) | ||
model = FRVSRGAN_Models.FRVSR(0, 0, 0) | ||
model.to(device) | ||
|
||
# for cpu | ||
|
@@ -126,10 +126,10 @@ def psnr(img1, img2): | |
hr_video_size = (lr_width * UPSCALE_FACTOR, lr_height * UPSCALE_FACTOR) | ||
lr_video_size = (lr_width, lr_height) | ||
|
||
output_sr_name = 'AFRVSROut_' + str(UPSCALE_FACTOR) + f'_{idx}_' + 'Random_Sample.mp4' | ||
output_gt_name = 'AFRVSROut_' + 'GroundTruth' + f'_{idx}_' + 'Random_Sample.mp4' | ||
output_lr_name = 'AFRVSROut_' + 'LowRes' + '_' + 'Random_Sample.mp4' | ||
output_aw_name = 'AFRVSROut_' + 'IntermediateWarp' + '_' + 'Random_Sample.mp4' | ||
output_sr_name = 'FRVSRGAN_Out_' + str(UPSCALE_FACTOR) + f'_{idx}_' + 'Random_Sample.mp4' | ||
output_gt_name = 'FRVSRGAN_Out_' + 'GroundTruth' + f'_{idx}_' + 'Random_Sample.mp4' | ||
output_lr_name = 'FRVSRGAN_Out_' + 'LowRes' + '_' + 'Random_Sample.mp4' | ||
output_aw_name = 'FRVSRGAN_Out_' + 'IntermediateWarp' + '_' + 'Random_Sample.mp4' | ||
|
||
fourcc = cv2.VideoWriter_fourcc(*'MP4V') | ||
hr_video_writer = cv2.VideoWriter(output_sr_name, fourcc, fps, hr_video_size) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,10 @@ | ||
""" | ||
This file trains a FRVSR-GAN model on based on an upscaling factor of 4x. | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
||
import argparse | ||
from math import log10 | ||
import gc | ||
|
@@ -7,8 +14,8 @@ | |
from tqdm import tqdm | ||
import DatasetLoader | ||
import logger | ||
from AFRVSRModels import FRVSR | ||
from AFRVSRModels import GeneratorLoss | ||
from FRVSRGAN_Models import FRVSR | ||
from FRVSRGAN_Models import GeneratorLoss | ||
from SRGAN.model import Discriminator | ||
import SRGAN.pytorch_ssim as pts | ||
|
||
|
@@ -216,7 +223,7 @@ def saveModelParams(epoch, runningResults, validationResults={}): | |
data_frame = pd.DataFrame(data={'DLoss': results['DLoss'], 'GLoss': results['GLoss'], 'DScore': results['DScore'], | ||
'GScore': results['GScore']},#, 'PSNR': results['PSNR'], 'SSIM': results['SSIM']}, | ||
index=range(1, epoch + 1)) | ||
data_frame.to_csv(out_path + 'AFRVSR_' + str(UPSCALE_FACTOR) + '_Train_Results.csv', index_label='Epoch') | ||
data_frame.to_csv(out_path + 'FRVSRGAN__' + str(UPSCALE_FACTOR) + '_Train_Results.csv', index_label='Epoch') | ||
|
||
def main(): | ||
""" Lets begin the training process! """ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file contains implementation of dataset classes. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file tests FRVSR on a single low resolution video source and upscales it to 4x. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -13,7 +13,7 @@ | |
from torch.autograd import Variable | ||
from torchvision.transforms import ToTensor | ||
from tqdm import tqdm | ||
import AFRVSRModels | ||
import FRVSRGANModels | ||
import checkTrain | ||
|
||
if __name__ == "__main__": | ||
|
@@ -28,7 +28,7 @@ | |
MODEL_NAME = opt.model | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
model = AFRVSRModels.FRVSR(0, 0, 0) | ||
model = FRVSRGANModels.FRVSR(0, 0, 0) | ||
|
||
model.to(device) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file tests the SRNet model within FRVSR on a single low resolution video source and upscales it to 4x. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -14,7 +14,7 @@ | |
from tqdm import tqdm | ||
import Dataset | ||
import checkTrain | ||
import AFRVSRModels | ||
import FRVSRGAN_Models | ||
|
||
if __name__ == "__main__": | ||
with torch.no_grad(): | ||
|
@@ -30,7 +30,7 @@ | |
print(MODEL_NAME) | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
#model = FRVSR.FRVSR(0, 0, 0) | ||
model = AFRVSRModels.SRNet(3) # testing the SRNet only | ||
model = FRVSRGAN_Models.SRNet(3) # testing the SRNet only | ||
|
||
model.to(device) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file trains FRVSR on a single low resolution video source and upscales it to 4x. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -16,12 +16,12 @@ | |
from SRGAN import pytorch_ssim | ||
|
||
torch.backends.cudnn.benchmark = True | ||
import AFRVSRModels | ||
import Dataset_OnlyHR | ||
import FRVSRGAN_Models | ||
import DatasetLoader | ||
|
||
|
||
def load_model(model_name, batch_size, width, height): | ||
model = AFRVSRModels.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width) | ||
model = FRVSRGANModels.FRVSR(batch_size=batch_size, lr_height=height, lr_width=width) | ||
if model_name != '': | ||
model_path = f'./models/{model_name}' | ||
checkpoint = torch.load(model_path, map_location='cpu') | ||
|
@@ -50,7 +50,7 @@ def run(): | |
num_val_batches = len(val_loader) | ||
|
||
flow_criterion = nn.MSELoss().to(device) | ||
content_criterion = AFRVSRModels.Loss().to(device) | ||
content_criterion = FRVSRGANModels.Loss().to(device) | ||
|
||
ssim_loss = pytorch_ssim.SSIM(window_size=11).to(device) | ||
optimizer = optim.Adam(model.parameters(), lr=1e-5) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file does a quick check of the SRNet model within FRVSR on a single low resolution video source and upscales it to 4x. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
@@ -13,7 +13,7 @@ | |
import torch.nn.functional as func | ||
import matplotlib.pyplot as plt | ||
import Dataset | ||
import AFRVSRModels | ||
import FRVSRGANModels | ||
from skimage import img_as_ubyte | ||
from skimage.util import img_as_float32 | ||
|
||
|
@@ -92,7 +92,7 @@ def psnr(img1, img2): | |
MODEL_NAME = opt.model | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
model = AFRVSRModels.SRNet(3) | ||
model = FRVSRGANModels.SRNet(3) | ||
model.to(device) | ||
|
||
# for cpu | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
# https://macpaw.com/how-to/remove-ds-store-files-on-mac | ||
# find . -name '.DS_Store' -type f -delete | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file creates a 4x1 upscaled video. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
""" | ||
This file contains implementation of dataset classes. | ||
[email protected] | ||
Aman Chadha | [email protected] | ||
Adapted from FR-SRGAN, MIT 6.819 Advances in Computer Vision, Nov 2018 | ||
""" | ||
|