Skip to content

Commit 6c74332

Browse files
authored
Add files via upload
1 parent 8228a46 commit 6c74332

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+5232
-0
lines changed

RK2-s/data/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from importlib import import_module
2+
from dataloader import MSDataLoader
3+
from torch.utils.data import ConcatDataset
4+
5+
# This is a simple wrapper function for ConcatDataset
6+
class MyConcatDataset(ConcatDataset):
7+
def __init__(self, datasets):
8+
super(MyConcatDataset, self).__init__(datasets)
9+
self.train = datasets[0].train
10+
11+
def set_scale(self, idx_scale):
12+
for d in self.datasets:
13+
if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
14+
15+
class Data:
16+
def __init__(self, args):
17+
self.loader_train = None
18+
if not args.test_only:
19+
datasets = []
20+
for d in args.data_train:
21+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
22+
m = import_module('data.' + module_name.lower())
23+
datasets.append(getattr(m, module_name)(args, name=d))
24+
25+
self.loader_train = MSDataLoader(
26+
args,
27+
MyConcatDataset(datasets),
28+
batch_size=args.batch_size,
29+
shuffle=True,
30+
pin_memory=not args.cpu
31+
)
32+
33+
self.loader_test = []
34+
for d in args.data_test:
35+
if d in ['Set5', 'Set14', 'B100', 'Urban100']:
36+
m = import_module('data.benchmark')
37+
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
38+
else:
39+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
40+
m = import_module('data.' + module_name.lower())
41+
testset = getattr(m, module_name)(args, train=False, name=d)
42+
43+
self.loader_test.append(MSDataLoader(
44+
args,
45+
testset,
46+
batch_size=1,
47+
shuffle=False,
48+
pin_memory=not args.cpu
49+
))
50+

RK2-s/data/benchmark.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import os
2+
3+
from data import common
4+
from data import srdata
5+
6+
import numpy as np
7+
8+
import torch
9+
import torch.utils.data as data
10+
11+
class Benchmark(srdata.SRData):
12+
def __init__(self, args, name='', train=True, benchmark=True):
13+
super(Benchmark, self).__init__(
14+
args, name=name, train=train, benchmark=True
15+
)
16+
17+
def _set_filesystem(self, dir_data):
18+
self.apath = os.path.join(dir_data, 'benchmark', self.name)
19+
self.dir_hr = os.path.join(self.apath, 'HR')
20+
if self.input_large:
21+
self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22+
else:
23+
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24+
self.ext = ('', '.png')
25+

RK2-s/data/common.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import random
2+
3+
import numpy as np
4+
import skimage.color as sc
5+
6+
import torch
7+
8+
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
9+
ih, iw = args[0].shape[:2]
10+
11+
if not input_large:
12+
p = scale if multi else 1
13+
tp = p * patch_size
14+
ip = tp // scale
15+
else:
16+
tp = patch_size
17+
ip = patch_size
18+
19+
ix = random.randrange(0, iw - ip + 1)
20+
iy = random.randrange(0, ih - ip + 1)
21+
22+
if not input_large:
23+
tx, ty = scale * ix, scale * iy
24+
else:
25+
tx, ty = ix, iy
26+
27+
ret = [
28+
args[0][iy:iy + ip, ix:ix + ip, :],
29+
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
30+
]
31+
32+
return ret
33+
34+
def set_channel(*args, n_channels=3):
35+
def _set_channel(img):
36+
if img.ndim == 2:
37+
img = np.expand_dims(img, axis=2)
38+
39+
c = img.shape[2]
40+
if n_channels == 1 and c == 3:
41+
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
42+
elif n_channels == 3 and c == 1:
43+
img = np.concatenate([img] * n_channels, 2)
44+
45+
return img
46+
47+
return [_set_channel(a) for a in args]
48+
49+
def np2Tensor(*args, rgb_range=255):
50+
def _np2Tensor(img):
51+
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
52+
tensor = torch.from_numpy(np_transpose).float()
53+
tensor.mul_(rgb_range / 255)
54+
55+
return tensor
56+
57+
return [_np2Tensor(a) for a in args]
58+
59+
def augment(*args, hflip=True, rot=True):
60+
hflip = hflip and random.random() < 0.5
61+
vflip = rot and random.random() < 0.5
62+
rot90 = rot and random.random() < 0.5
63+
64+
def _augment(img):
65+
if hflip: img = img[:, ::-1, :]
66+
if vflip: img = img[::-1, :, :]
67+
if rot90: img = img.transpose(1, 0, 2)
68+
69+
return img
70+
71+
return [_augment(a) for a in args]
72+

RK2-s/data/demo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
3+
from data import common
4+
5+
import numpy as np
6+
import imageio
7+
8+
import torch
9+
import torch.utils.data as data
10+
11+
class Demo(data.Dataset):
12+
def __init__(self, args, name='Demo', train=False, benchmark=False):
13+
self.args = args
14+
self.name = name
15+
self.scale = args.scale
16+
self.idx_scale = 0
17+
self.train = False
18+
self.benchmark = benchmark
19+
20+
self.filelist = []
21+
for f in os.listdir(args.dir_demo):
22+
if f.find('.png') >= 0 or f.find('.jp') >= 0:
23+
self.filelist.append(os.path.join(args.dir_demo, f))
24+
self.filelist.sort()
25+
26+
def __getitem__(self, idx):
27+
filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28+
lr = imageio.imread(self.filelist[idx])
29+
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30+
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31+
32+
return lr_t, -1, filename
33+
34+
def __len__(self):
35+
return len(self.filelist)
36+
37+
def set_scale(self, idx_scale):
38+
self.idx_scale = idx_scale
39+

RK2-s/data/div2k.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
from data import srdata
3+
4+
class DIV2K(srdata.SRData):
5+
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6+
data_range = [r.split('-') for r in args.data_range.split('/')]
7+
if train:
8+
data_range = data_range[0]
9+
else:
10+
if args.test_only and len(data_range) == 1:
11+
data_range = data_range[0]
12+
else:
13+
data_range = data_range[1]
14+
15+
self.begin, self.end = list(map(lambda x: int(x), data_range))
16+
super(DIV2K, self).__init__(
17+
args, name=name, train=train, benchmark=benchmark
18+
)
19+
20+
def _scan(self):
21+
names_hr, names_lr = super(DIV2K, self)._scan()
22+
names_hr = names_hr[self.begin - 1:self.end]
23+
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24+
25+
return names_hr, names_lr
26+
27+
def _set_filesystem(self, dir_data):
28+
super(DIV2K, self)._set_filesystem(dir_data)
29+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30+
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31+
if self.input_large: self.dir_lr += 'L'
32+

RK2-s/data/div2kjpeg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
from data import srdata
3+
from data import div2k
4+
5+
class DIV2KJPEG(div2k.DIV2K):
6+
def __init__(self, args, name='', train=True, benchmark=False):
7+
self.q_factor = int(name.replace('DIV2K-Q', ''))
8+
super(DIV2KJPEG, self).__init__(
9+
args, name=name, train=train, benchmark=benchmark
10+
)
11+
12+
def _set_filesystem(self, dir_data):
13+
self.apath = os.path.join(dir_data, 'DIV2K')
14+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15+
self.dir_lr = os.path.join(
16+
self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17+
)
18+
if self.input_large: self.dir_lr += 'L'
19+
self.ext = ('.png', '.jpg')
20+

RK2-s/data/sr291.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from data import srdata
2+
3+
class SR291(srdata.SRData):
4+
def __init__(self, args, name='SR291', train=True, benchmark=False):
5+
super(SR291, self).__init__(args, name=name)
6+

0 commit comments

Comments
 (0)