|
| 1 | +import random |
| 2 | +import numpy as np |
| 3 | +import cv2 |
| 4 | +import lmdb |
| 5 | +import torch |
| 6 | +import torch.utils.data as data |
| 7 | +import data.util as util |
| 8 | + |
| 9 | + |
| 10 | +class LQGTDataset(data.Dataset): |
| 11 | + """ |
| 12 | + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. |
| 13 | + If only GT images are provided, generate LQ images on-the-fly. |
| 14 | + """ |
| 15 | + |
| 16 | + def __init__(self, opt): |
| 17 | + super(LQGTDataset, self).__init__() |
| 18 | + self.opt = opt |
| 19 | + self.data_type = self.opt['data_type'] |
| 20 | + self.paths_LQ, self.paths_GT = None, None |
| 21 | + self.sizes_LQ, self.sizes_GT = None, None |
| 22 | + self.LQ_env, self.GT_env = None, None # environments for lmdb |
| 23 | + |
| 24 | + self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) |
| 25 | + self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) |
| 26 | + |
| 27 | + # self.paths_GT=[self.paths_GT[400000],self.paths_GT[800000],self.paths_GT[1200000]] |
| 28 | + # self.paths_LQ=[self.paths_LQ[400000],self.paths_LQ[800000],self.paths_LQ[1200000]] |
| 29 | + |
| 30 | + |
| 31 | + assert self.paths_GT, 'Error: GT path is empty.' |
| 32 | + if self.paths_LQ and self.paths_GT: |
| 33 | + assert len(self.paths_LQ) == len( |
| 34 | + self.paths_GT |
| 35 | + ), 'GT and LQ datasets have different number of images - {}, {}.'.format( |
| 36 | + len(self.paths_LQ), len(self.paths_GT)) |
| 37 | + self.random_scale_list = [1] |
| 38 | + |
| 39 | + def _init_lmdb(self): |
| 40 | + # https://github.com/chainer/chainermn/issues/129 |
| 41 | + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, |
| 42 | + meminit=False) |
| 43 | + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, |
| 44 | + meminit=False) |
| 45 | + |
| 46 | + def __getitem__(self, index): |
| 47 | + if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): |
| 48 | + self._init_lmdb() |
| 49 | + GT_path, LQ_path = None, None |
| 50 | + scale = self.opt['scale'] |
| 51 | + GT_size = self.opt['GT_size'] |
| 52 | + |
| 53 | + # get GT image |
| 54 | + GT_path = self.paths_GT[index] |
| 55 | + resolution = [int(s) for s in self.sizes_GT[index].split('_') |
| 56 | + ] if self.data_type == 'lmdb' else None |
| 57 | + img_GT = util.read_img(self.GT_env, GT_path, resolution) |
| 58 | + if self.opt['phase'] != 'train': # modcrop in the validation / test phase |
| 59 | + img_GT = util.modcrop(img_GT, scale) |
| 60 | + if self.opt['color']: # change color space if necessary |
| 61 | + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] |
| 62 | + |
| 63 | + # get LQ image |
| 64 | + if self.paths_LQ: |
| 65 | + LQ_path = self.paths_LQ[index] |
| 66 | + resolution = [int(s) for s in self.sizes_LQ[index].split('_') |
| 67 | + ] if self.data_type == 'lmdb' else None |
| 68 | + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) |
| 69 | + else: # down-sampling on-the-fly |
| 70 | + # randomly scale during training |
| 71 | + if self.opt['phase'] == 'train': |
| 72 | + random_scale = random.choice(self.random_scale_list) |
| 73 | + H_s, W_s, _ = img_GT.shape |
| 74 | + |
| 75 | + def _mod(n, random_scale, scale, thres): |
| 76 | + rlt = int(n * random_scale) |
| 77 | + rlt = (rlt // scale) * scale |
| 78 | + return thres if rlt < thres else rlt |
| 79 | + |
| 80 | + H_s = _mod(H_s, random_scale, scale, GT_size) |
| 81 | + W_s = _mod(W_s, random_scale, scale, GT_size) |
| 82 | + img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) |
| 83 | + if img_GT.ndim == 2: |
| 84 | + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) |
| 85 | + |
| 86 | + H, W, _ = img_GT.shape |
| 87 | + # using matlab imresize |
| 88 | + img_LQ = util.imresize_np(img_GT, 1 / scale, True) |
| 89 | + if img_LQ.ndim == 2: |
| 90 | + img_LQ = np.expand_dims(img_LQ, axis=2) |
| 91 | + |
| 92 | + if self.opt['phase'] == 'train': |
| 93 | + # if the image size is too small |
| 94 | + H, W, _ = img_GT.shape |
| 95 | + if H < GT_size or W < GT_size: |
| 96 | + img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) |
| 97 | + # using matlab imresize |
| 98 | + img_LQ = util.imresize_np(img_GT, 1 / scale, True) |
| 99 | + if img_LQ.ndim == 2: |
| 100 | + img_LQ = np.expand_dims(img_LQ, axis=2) |
| 101 | + |
| 102 | + H, W, C = img_LQ.shape |
| 103 | + LQ_size = GT_size // scale |
| 104 | + |
| 105 | + # randomly crop |
| 106 | + rnd_h = random.randint(0, max(0, H - LQ_size)) |
| 107 | + rnd_w = random.randint(0, max(0, W - LQ_size)) |
| 108 | + img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] |
| 109 | + rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) |
| 110 | + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] |
| 111 | + |
| 112 | + # augmentation - flip, rotate |
| 113 | + img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], |
| 114 | + self.opt['use_rot']) |
| 115 | + |
| 116 | + if self.opt['color']: # change color space if necessary |
| 117 | + img_LQ = util.channel_convert(C, self.opt['color'], |
| 118 | + [img_LQ])[0] # TODO during val no definition |
| 119 | + |
| 120 | + # BGR to RGB, HWC to CHW, numpy to tensor |
| 121 | + if img_GT.shape[2] == 3: |
| 122 | + img_GT = img_GT[:, :, [2, 1, 0]] |
| 123 | + img_LQ = img_LQ[:, :, [2, 1, 0]] |
| 124 | + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() |
| 125 | + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() |
| 126 | + |
| 127 | + if LQ_path is None: |
| 128 | + LQ_path = GT_path |
| 129 | + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} |
| 130 | + |
| 131 | + def __len__(self): |
| 132 | + return len(self.paths_GT) |
0 commit comments