Skip to content

Commit 866437e

Browse files
committed
first commit
0 parents  commit 866437e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+7532
-0
lines changed

README.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
## CDNSR
2+
3+
This is the source code for our paper, "Classification-based Dynamic Network for Efficient Super-Resolution". A brief introduction to this work follows:
4+
5+
> Deep neural networks (DNNs) based approaches have achieved superior performance in single image super-resolution (SR). To obtain better visual quality, DNNs for SR are generally designed with massive computation overhead. To accelerate network inference under resource constraints, we propose a classification-based dynamic network for efficient super-resolution (CDNSR), which combines the classification and SR networks in a unified framework. Specifically, CDNSR decomposes a large image into a number of image-patches, and uses a classification network to categorize them into different classes based on the restoration difficulty. Each class of image-patches will be handled by the SR network that corresponds to the difficulty of this class. In particular, we design a new loss to trade off between the computational overhead and the reconstruction quality. Besides, we apply contrastive learning based knowledge distillation to guarantee the performance of SR networks and the quality of reconstructed images. Extensive experiments show that CDNSR significantly outperforms the other SR networks and backbones on image quality and computational overhead.
6+
7+
This paper has been accepted by ICASSP 2023. Because of the conference's 5-page limit, we provide the full version of the technical report in this repo.
8+
9+
## Required software
10+
11+
PyTorch
12+
13+
## Pre-train & test SR-Nets
14+
`train`
15+
```bash
16+
cd codes
17+
python train_SR_Net.py -opt options/train/train_CARN_branch1.yml
18+
python train_SR_Net.py -opt options/train/train_CARN_branch2.yml
19+
python train_SR_Net.py -opt options/train/train_CARN_branch3.yml
20+
```
21+
`test`
22+
```
23+
cd codes
24+
python test_SR_Net.py -opt options/test/test_CARN.yml
25+
```
26+
27+
## Train & test CDNSR
28+
`train`
29+
```
30+
cd codes
31+
python train_CDNSR.py -opt options/train/train_CDNSR_CARN.yml
32+
33+
```
34+
`distill`
35+
```
36+
cd codes
37+
python train_CDNSR.py -opt options/train/train_CDNSR_CARN_KD.yml
38+
```
39+
40+
`test`
41+
```
42+
cd codes
43+
python test_CDNSR.py -opt options/test/test_CDNSR_CARN.yml
44+
```
45+
46+
## Contact
47+
48+

codes/data/LQGT_classify_test.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import random
2+
import numpy as np
3+
import cv2
4+
import lmdb
5+
import torch
6+
import torch.utils.data as data
7+
import data.util as util
8+
9+
10+
class LQGTDataset(data.Dataset):
    """
    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
    If only GT images are provided, generate LQ images on-the-fly.

    Keys read from ``opt``: 'data_type', 'dataroot_GT', 'dataroot_LQ', 'scale',
    'GT_size', 'phase', 'color', 'use_flip' and 'use_rot'.
    """

    def __init__(self, opt):
        """Collect the GT / LQ image path lists described by ``opt``.

        Raises:
            AssertionError: if no GT path was found, or if both path lists
                exist but have different lengths.
        """
        super(LQGTDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
        assert self.paths_GT, 'Error: GT path is empty.'
        if self.paths_LQ and self.paths_GT:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        """Open the lmdb environments lazily (called from ``__getitem__``)."""
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False,
                                readahead=False, meminit=False)
        # Only open the LQ database when LQ images exist; in GT-only mode the
        # LQ images are generated on-the-fly and there is nothing to open.
        if self.paths_LQ:
            self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False,
                                    readahead=False, meminit=False)

    @staticmethod
    def _mod(n, random_scale, scale, thres):
        """Scale ``n``, round it down to a multiple of ``scale``, clamp to >= ``thres``."""
        rlt = int(n * random_scale)
        rlt = (rlt // scale) * scale
        return thres if rlt < thres else rlt

    @staticmethod
    def _to_tensor(img, bgr2rgb):
        """Convert an HWC numpy image to a CHW float tensor.

        Args:
            img: numpy array indexed as (H, W, C).
            bgr2rgb: when true, reverse the order of the three channels first.
        """
        if bgr2rgb:
            img = img[:, :, [2, 1, 0]]
        return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()

    def __getitem__(self, index):
        """Load the ``index``-th GT/LQ pair.

        Returns:
            dict with 'LQ' and 'GT' (CHW float tensors) and 'LQ_path' / 'GT_path'.
            'LQ_path' falls back to the GT path when LQ is generated on-the-fly.
        """
        if self.data_type == 'lmdb' and (self.GT_env is None or
                                         (self.paths_LQ and self.LQ_env is None)):
            self._init_lmdb()
        LQ_path = None
        scale = self.opt['scale']
        GT_size = self.opt['GT_size']
        is_train = self.opt['phase'] == 'train'
        is_lmdb = self.data_type == 'lmdb'

        # get GT image
        GT_path = self.paths_GT[index]
        resolution = [int(s) for s in self.sizes_GT[index].split('_')] if is_lmdb else None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if not is_train:  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        if self.paths_LQ:
            LQ_path = self.paths_LQ[index]
            resolution = [int(s) for s in self.sizes_LQ[index].split('_')] if is_lmdb else None
            img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if is_train:
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_GT.shape
                H_s = self._mod(H_s, random_scale, scale, GT_size)
                W_s = self._mod(W_s, random_scale, scale, GT_size)
                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                if img_GT.ndim == 2:
                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

            # using matlab imresize
            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)

        if is_train:
            # if the image size is too small
            H, W, _ = img_GT.shape
            if H < GT_size or W < GT_size:
                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
                # using matlab imresize
                img_LQ = util.imresize_np(img_GT, 1 / scale, True)
                if img_LQ.ndim == 2:
                    img_LQ = np.expand_dims(img_LQ, axis=2)

            H, W, _ = img_LQ.shape
            LQ_size = GT_size // scale

            # randomly crop; the GT window is the LQ window scaled up by `scale`
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
            rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
            img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]

            # augmentation - flip, rotate
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        if self.opt['color']:  # change color space if necessary
            # Read the channel count from the image itself: the previous local
            # `C` was only assigned in the training branch, so validation / test
            # with a color option raised NameError here.
            img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        # (the GT channel count decides the swap for both images, as before)
        bgr2rgb = img_GT.shape[2] == 3
        img_GT = self._to_tensor(img_GT, bgr2rgb)
        img_LQ = self._to_tensor(img_LQ, bgr2rgb)

        if LQ_path is None:
            LQ_path = GT_path
        return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        """Return the number of GT images."""
        return len(self.paths_GT)

codes/data/LQGT_dataset.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import random
2+
import numpy as np
3+
import cv2
4+
import lmdb
5+
import torch
6+
import torch.utils.data as data
7+
import data.util as util
8+
9+
10+
class LQGTDataset(data.Dataset):
    """
    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
    If only GT images are provided, generate LQ images on-the-fly.

    Keys read from ``opt``: 'data_type', 'dataroot_GT', 'dataroot_LQ', 'scale',
    'GT_size', 'phase', 'color', 'use_flip' and 'use_rot'.
    """

    def __init__(self, opt):
        """Collect the GT / LQ image path lists described by ``opt``.

        Raises:
            AssertionError: if no GT path was found, or if both path lists
                exist but have different lengths.
        """
        super(LQGTDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])

        # Debug leftover (disabled): restrict the dataset to three fixed samples.
        # self.paths_GT=[self.paths_GT[400000],self.paths_GT[800000],self.paths_GT[1200000]]
        # self.paths_LQ=[self.paths_LQ[400000],self.paths_LQ[800000],self.paths_LQ[1200000]]

        assert self.paths_GT, 'Error: GT path is empty.'
        if self.paths_LQ and self.paths_GT:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        """Open the lmdb environments lazily (called from ``__getitem__``)."""
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False,
                                readahead=False, meminit=False)
        # Only open the LQ database when LQ images exist; in GT-only mode the
        # LQ images are generated on-the-fly and there is nothing to open.
        if self.paths_LQ:
            self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False,
                                    readahead=False, meminit=False)

    @staticmethod
    def _mod(n, random_scale, scale, thres):
        """Scale ``n``, round it down to a multiple of ``scale``, clamp to >= ``thres``."""
        rlt = int(n * random_scale)
        rlt = (rlt // scale) * scale
        return thres if rlt < thres else rlt

    @staticmethod
    def _to_tensor(img, bgr2rgb):
        """Convert an HWC numpy image to a CHW float tensor.

        Args:
            img: numpy array indexed as (H, W, C).
            bgr2rgb: when true, reverse the order of the three channels first.
        """
        if bgr2rgb:
            img = img[:, :, [2, 1, 0]]
        return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()

    def __getitem__(self, index):
        """Load the ``index``-th GT/LQ pair.

        Returns:
            dict with 'LQ' and 'GT' (CHW float tensors) and 'LQ_path' / 'GT_path'.
            'LQ_path' falls back to the GT path when LQ is generated on-the-fly.
        """
        if self.data_type == 'lmdb' and (self.GT_env is None or
                                         (self.paths_LQ and self.LQ_env is None)):
            self._init_lmdb()
        LQ_path = None
        scale = self.opt['scale']
        GT_size = self.opt['GT_size']
        is_train = self.opt['phase'] == 'train'
        is_lmdb = self.data_type == 'lmdb'

        # get GT image
        GT_path = self.paths_GT[index]
        resolution = [int(s) for s in self.sizes_GT[index].split('_')] if is_lmdb else None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if not is_train:  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        if self.paths_LQ:
            LQ_path = self.paths_LQ[index]
            resolution = [int(s) for s in self.sizes_LQ[index].split('_')] if is_lmdb else None
            img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if is_train:
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_GT.shape
                H_s = self._mod(H_s, random_scale, scale, GT_size)
                W_s = self._mod(W_s, random_scale, scale, GT_size)
                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                if img_GT.ndim == 2:
                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

            # using matlab imresize
            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)

        if is_train:
            # if the image size is too small
            H, W, _ = img_GT.shape
            if H < GT_size or W < GT_size:
                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
                # using matlab imresize
                img_LQ = util.imresize_np(img_GT, 1 / scale, True)
                if img_LQ.ndim == 2:
                    img_LQ = np.expand_dims(img_LQ, axis=2)

            H, W, _ = img_LQ.shape
            LQ_size = GT_size // scale

            # randomly crop; the GT window is the LQ window scaled up by `scale`
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
            rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
            img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]

            # augmentation - flip, rotate
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        if self.opt['color']:  # change color space if necessary
            # Read the channel count from the image itself: the previous local
            # `C` was only assigned in the training branch, so validation / test
            # with a color option raised NameError here.
            img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        # (the GT channel count decides the swap for both images, as before)
        bgr2rgb = img_GT.shape[2] == 3
        img_GT = self._to_tensor(img_GT, bgr2rgb)
        img_LQ = self._to_tensor(img_LQ, bgr2rgb)

        if LQ_path is None:
            LQ_path = GT_path
        return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        """Return the number of GT images."""
        return len(self.paths_GT)

0 commit comments

Comments
 (0)