Skip to content

Commit 729d103

Browse files
committed
merge two codebook main code and test training script
1 parent f3dfdf8 commit 729d103

14 files changed

+61
-668
lines changed

DAEFR/models/association_stage.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def __init__(self,
1616
lossconfig,
1717
ckpt_path_HQ=None,
1818
ckpt_path_LQ=None,
19-
encoder_codebook_type=None,
2019
ignore_keys=[],
2120
image_key="lq",
2221
colorize_nlabels=None,
@@ -58,8 +57,6 @@ def __init__(self,
5857
self.comp_params_lr_scale = comp_params_lr_scale
5958
self.schedule_step = schedule_step
6059

61-
self.encoder_codebook_type = encoder_codebook_type
62-
6360

6461

6562
def init_from_ckpt_two(self, path_HQ, path_LQ, ignore_keys=list()):

DAEFR/models/daefr.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def __init__(self,
110110
lossconfig,
111111
ckpt_path_HQ=None,
112112
ckpt_path_LQ=None,
113-
encoder_codebook_type=None,
114113
ignore_keys=[],
115114
image_key="lq",
116115
colorize_nlabels=None,
@@ -145,7 +144,6 @@ def __init__(self,
145144
self.comp_params_lr_scale = comp_params_lr_scale
146145
self.schedule_step = schedule_step
147146

148-
self.encoder_codebook_type = encoder_codebook_type
149147

150148
self.cross_attention = MultiHeadAttnBlock(in_channels=256,head_size=8)
151149

@@ -523,7 +521,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
523521
def validation_step(self, batch, batch_idx):
524522
x = batch[self.image_key]
525523
gt = batch['gt']
526-
524+
527525
xrec, BCE_loss, L2_loss, info, hs,_,_,_ = self(x, gt)
528526

529527
qloss = BCE_loss + L2_loss

configs/Association_stage.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
model:
22
base_learning_rate: 4.5e-6
3+
max_epochs: 50
34
target: DAEFR.models.association_stage.DAEFRModel
45
params:
56
image_key: 'lq'
67
# HQ codebook path
78
ckpt_path_HQ: '/ssd1/yuju/DAEFR/experiments/HQ_codebook/HQ_codebook.ckpt'
89
# LQ codebook path
910
ckpt_path_LQ: '/ssd1/yuju/DAEFR/experiments/LQ_codebook/LQ_codebook.ckpt'
10-
encoder_codebook_type: 'LQHQ'
1111
schedule_step: [4000000, 8000000]
1212
ddconfig:
1313
target: DAEFR.modules.vqvae.vqvae_arch.VQVAEGAN
@@ -39,14 +39,14 @@ model:
3939
use_actnorm: False
4040

4141
data:
42-
target: main.DataModuleFromConfig
42+
target: main_for_association.DataModuleFromConfig
4343
params:
4444
batch_size: 2
4545
num_workers: 8
4646
train:
4747
target: DAEFR.data.ffhq_degradation_dataset.FFHQDegradationDataset
4848
params:
49-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512
49+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512
5050
io_backend:
5151
type: disk
5252
use_hflip: True
@@ -76,7 +76,7 @@ data:
7676
validation:
7777
target: DAEFR.data.ffhq_degradation_dataset.FFHQDegradationDataset
7878
params:
79-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512_validation
79+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512_validation
8080
io_backend:
8181
type: disk
8282
use_hflip: False

configs/DAEFR.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
model:
22
base_learning_rate: 4.5e-6
3+
max_epochs: 100
34
target: DAEFR.models.daefr.DAEFRModel
45
params:
56
image_key: 'lq'
6-
ckpt_path_HQ: '/ssd1/yuju/DAEFR/experiments/HQ_codebook_300/epoch=000129-Rec_loss=0.3460099399089813-Codebook_loss=0.012400745414197445.ckpt'
7-
ckpt_path_LQ: '/ssd1/yuju/DAEFR/experiments/2023-01-19_Dual_codebook_dis_start/checkpoints/last.ckpt'
8-
encoder_codebook_type: 'LQHQ'
7+
ckpt_path_HQ: '/ssd1/yuju/DAEFR/experiments/HQ_codebook/HQ_codebook.ckpt'
8+
ckpt_path_LQ: '/ssd1/yuju/DAEFR/experiments/Association_stage/Association_stage.ckpt'
99
special_params_lr_scale: 1
1010
comp_params_lr_scale: 10
1111
schedule_step: [4000000, 8000000]
@@ -43,14 +43,14 @@ model:
4343

4444

4545
data:
46-
target: main.DataModuleFromConfig
46+
target: main_DAEFR.DataModuleFromConfig
4747
params:
4848
batch_size: 2
4949
num_workers: 8
5050
train:
5151
target: DAEFR.data.ffhq_degradation_dataset.FFHQDegradationDataset
5252
params:
53-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512
53+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512
5454
io_backend:
5555
type: disk
5656
use_hflip: True
@@ -80,7 +80,7 @@ data:
8080
validation:
8181
target: DAEFR.data.ffhq_degradation_dataset.FFHQDegradationDataset
8282
params:
83-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512_validation
83+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512_validation
8484
io_backend:
8585
type: disk
8686
use_hflip: False

configs/HQ_codebook.yaml

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
model:
22
base_learning_rate: 4.5e-6
3+
max_epochs: 330
34
target: DAEFR.models.vqgan_origin.DAEFRModel
45
params:
56
image_key: 'gt'
@@ -35,27 +36,46 @@ model:
3536
use_actnorm: False
3637

3738
data:
38-
target: main.DataModuleFromConfig
39+
target: main_for_codebook.DataModuleFromConfig
3940
params:
4041
batch_size: 3
4142
num_workers: 8
4243
train:
4344
target: basicsr.data.ffhq_dataset.FFHQDataset
4445
params:
45-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512
46+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512
4647
io_backend:
4748
type: disk
4849
use_hflip: True
4950
mean: [0.5, 0.5, 0.5]
5051
std: [0.5, 0.5, 0.5]
5152
out_size: 512
5253
validation:
53-
target: basicsr.data.ffhq_dataset.FFHQDataset
54+
target: DAEFR.data.ffhq_degradation_dataset.FFHQDegradationDataset
5455
params:
55-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512_validation
56+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512_validation
5657
io_backend:
5758
type: disk
5859
use_hflip: False
5960
mean: [0.5, 0.5, 0.5]
6061
std: [0.5, 0.5, 0.5]
6162
out_size: 512
63+
64+
blur_kernel_size: [19,20]
65+
kernel_list: ['iso', 'aniso']
66+
kernel_prob: [0.5, 0.5]
67+
blur_sigma: [0.1, 10]
68+
downsample_range: [0.8, 8]
69+
noise_range: [0, 20]
70+
jpeg_range: [60, 100]
71+
72+
# color jitter and gray
73+
color_jitter_prob: ~
74+
color_jitter_shift: 20
75+
color_jitter_pt_prob: ~
76+
gray_prob: ~
77+
gt_gray: True
78+
79+
crop_components: False
80+
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
81+
eye_enlarge_ratio: 1.4

configs/LQ_codebook.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
model:
22
base_learning_rate: 4.5e-6
3+
max_epochs: 250
34
target: DAEFR.models.vqgan_origin.DAEFRModel
45
params:
56
image_key: 'gt'
@@ -35,14 +36,14 @@ model:
3536
use_actnorm: False
3637

3738
data:
38-
target: main.DataModuleFromConfig
39+
target: main_for_codebook.DataModuleFromConfig
3940
params:
4041
batch_size: 3
4142
num_workers: 8
4243
train:
4344
target: DAEFR.data.ffhq_degradation_dataset_LQ.FFHQDegradationDataset
4445
params:
45-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512
46+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512
4647
io_backend:
4748
type: disk
4849
use_hflip: True
@@ -72,7 +73,7 @@ data:
7273
validation:
7374
target: DAEFR.data.ffhq_degradation_dataset_LQ.FFHQDegradationDataset
7475
params:
75-
dataroot_gt: /ssd1/yuju/dataset/FFHQ/images512x512_validation
76+
dataroot_gt: /ssd2/yuju/RestoreFormer/data/FFHQ/images512x512_validation
7677
io_backend:
7778
type: disk
7879
use_hflip: False

main_DAEFR.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
457457
gpuinfo = trainer_config["gpus"]
458458
print(f"Running on GPUs {gpuinfo}")
459459
cpu = False
460-
trainer_config["max_epochs"] = 100
460+
trainer_config["max_epochs"] = config.model.max_epochs
461461
trainer_opt = argparse.Namespace(**trainer_config)
462462
lightning_config.trainer = trainer_config
463463
if opt.resume:
@@ -527,7 +527,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
527527
# add callback which sets up log directory
528528
default_callbacks_cfg = {
529529
"setup_callback": {
530-
"target": "main.SetupCallback",
530+
"target": "main_DAEFR.SetupCallback",
531531
"params": {
532532
"resume": opt.resume,
533533
"now": now,
@@ -539,15 +539,15 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
539539
}
540540
},
541541
"image_logger": {
542-
"target": "main.ImageLogger",
542+
"target": "main_DAEFR.ImageLogger",
543543
"params": {
544544
"batch_frequency": 750,
545545
"max_images": 4,
546546
"clamp": True
547547
}
548548
},
549549
"learning_rate_logger": {
550-
"target": "main.LearningRateMonitor",
550+
"target": "main_DAEFR.LearningRateMonitor",
551551
"params": {
552552
"logging_interval": "step",
553553
#"log_momentum": True

0 commit comments

Comments
 (0)