@@ -18,17 +18,16 @@ def parse_args():
     desc = "MAD-VAE for adversarial defense"
     parser = argparse.ArgumentParser(description=desc)
     parser.add_argument('--batch_size', type=int, default=512, help='Training batch size')
-    parser.add_argument('--epochs', type=int, default=10, help='Training epoch numbers')
+    parser.add_argument('--epochs', type=int, default=5, help='Training epoch numbers')
     parser.add_argument('--h_dim', type=int, default=4096, help='Hidden dimensions')
     parser.add_argument('--z_dim', type=int, default=128, help='Latent dimensions for images')
     parser.add_argument('--image_channels', type=int, default=1, help='Image channels')
     parser.add_argument('--image_size', type=int, default=28, help='Image size (default to be squared images)')
     parser.add_argument('--num_classes', type=int, default=10, help='Number of image classes')
     parser.add_argument('--log_dir', type=str, default='pd_logs', help='Logs directory')
     parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for the Adam optimizer')
-    parser.add_argument('--closs_weight', type=float, default=0.1, help='Weight for classification loss functions')
     parser.add_argument('--ploss_weight', type=float, default=0.01, help='Weight for proximity loss functions')
-    parser.add_argument('--dloss_weight', type=float, default=0.00001, help='Weight for distance loss functions')
+    parser.add_argument('--dloss_weight', type=float, default=0.0001, help='Weight for distance loss functions')
     parser.add_argument('--data_root', type=str, default='data', help='Data directory')
     parser.add_argument('--model_dir', type=str, default='pretrained_model', help='Pretrained model directory')
     parser.add_argument('--use_gpu', type=bool, default=True, help='If use GPU for training')
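Net effect of this hunk: training runs for half as many epochs, the classification-loss weight disappears entirely, and the distance-loss weight rises tenfold. A quick, hypothetical sanity check of the new defaults (the training script's filename is not shown in this diff, so the import is an assumption):

```python
from train import parse_args  # module name assumed; any argparse entry point behaves the same

args = parse_args()
assert args.epochs == 5                    # was 10
assert args.dloss_weight == 1e-4           # was 1e-5
assert not hasattr(args, 'closs_weight')   # flag removed in this commit
# CLI override restoring the old epoch count and distance weight:
#   python train.py --epochs 10 --dloss_weight 0.00001
```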
@@ -78,18 +77,18 @@ def main():
                 writer1.add_image("reconstruct data", outputs[i][0], step)

             # print out loss
-            print("batch {}'s img_recon loss: {:.5f}, recon loss: {:.5f}, kl loss: {:.5f}"\
+            print("batch {}'s img_recon loss: {:.5f}, recon loss: {:.5f}, kl loss: {:.5f}, pd_loss: {:.5f}"\
                 .format(step, np.sum(img_losses)/len(img_losses), np.sum(recon_losses)/len(recon_losses),\
-                np.sum(kl_losses)/len(kl_losses)))
+                np.sum(kl_losses)/len(kl_losses), np.sum(pd_losses)/len(pd_losses)))

         # step scheduler
         scheduler.step()
         scheduler1.step()
         scheduler2.step()

         # save model parameters
-        if epoch % 5 == 0:
-            torch.save(model.state_dict(), '{}/proxi_dist/params_{}.pt'.format(args.model_dir, epoch))
+        # if epoch % 5 == 0:
+        #     torch.save(model.state_dict(), '{}/proxi_dist/params_{}.pt'.format(args.model_dir, epoch))

     torch.save(model.state_dict(), '{}/proxi_dist/params.pt'.format(args.model_dir))
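The periodic per-epoch checkpoints are disabled here rather than deleted, leaving only the final `params.pt` save. If they are wanted back, a small helper keeps both save paths in one place; this is a sketch, not part of the commit, and the `every=5` interval mirrors the commented-out `epoch % 5` check:

```python
import os
import torch

def save_checkpoint(model, model_dir, epoch=None, every=5):
    """Save final weights, or a periodic snapshot when an epoch number is given."""
    os.makedirs(os.path.join(model_dir, 'proxi_dist'), exist_ok=True)
    if epoch is None:
        torch.save(model.state_dict(), '{}/proxi_dist/params.pt'.format(model_dir))
    elif epoch % every == 0:
        torch.save(model.state_dict(), '{}/proxi_dist/params_{}.pt'.format(model_dir, epoch))
```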
@@ -123,10 +122,12 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o
         distribution = Normal(dsm, dss)

         # calculate losses
-        r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, epoch/100)
+        r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, 0.1)
         p_loss = proximity(z, label)
         d_loss = distance(z, label)
-        loss = r_loss + args.ploss_weight * p_loss - args.dloss_weight * d_loss
+
+        pd_loss = args.ploss_weight * p_loss - args.dloss_weight * d_loss
+        loss = r_loss + pd_loss
         loss.backward()

         # clip for gradient
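Splitting out `pd_loss` makes the sign structure explicit: the proximity term (weighted 0.01) pulls same-class latent codes toward their class center and is minimized, while the distance term (weighted 0.0001) measures separation between classes and enters with a minus sign, so the optimizer is pushed to increase it. The real `proximity` and `distance` modules live elsewhere in the repo; the toy stand-ins below only illustrate that shape:

```python
import torch

def toy_proximity(z, label):
    # mean squared distance of each code to its class mean (minimized)
    per_class = [((z[label == c] - z[label == c].mean(0)) ** 2).sum(1).mean()
                 for c in label.unique()]
    return torch.stack(per_class).mean()

def toy_distance(z, label):
    # mean pairwise distance between class means (maximized via the minus sign)
    means = torch.stack([z[label == c].mean(0) for c in label.unique()])
    return torch.pdist(means).mean()

z = torch.randn(32, 128, requires_grad=True)   # batch of latent codes
label = torch.randint(0, 10, (32,))
pd_loss = 0.01 * toy_proximity(z, label) - 0.0001 * toy_distance(z, label)
pd_loss.backward()   # same sign convention as loss = r_loss + pd_loss above
```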
@@ -136,18 +137,14 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o

         # step optimizer
         optimizer.step()
-        for param in proximity.parameters():
-            param.grad.data *= (1. / args.ploss_weight)
         optimizer1.step()
-        for param in distance.parameters():
-            param.grad.data *= (1. / args.dloss_weight)
         optimizer2.step()

         # record results
         recon_losses.append(loss.cpu().item())
         img_losses.append(img_recon.cpu().item())
         kl_losses.append(kld.cpu().item())
-        pd_losses.append(p_loss.cpu().item() - d_loss.cpu().item())
+        pd_losses.append(pd_loss.cpu().item())
         outputs.append(output.cpu())
         datas.append(data.cpu())
         adv_datas.append(adv_data.cpu())
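The deleted loops are a known trick from center-loss-style training: because `p_loss` and `d_loss` enter the objective pre-multiplied by small weights, the gradients reaching their parameters carry the same factor, and scaling by `1/weight` restores the natural magnitude. This commit drops the rescaling and instead bakes the factor into the auxiliary learning rates (see the next hunk). For plain SGD the two are algebraically equivalent; a sketch with a hypothetical `centers` parameter:

```python
import torch

w, lr = 0.01, 0.001                     # ploss_weight and base lr
centers = torch.randn(10, 128, requires_grad=True)
loss = w * (centers ** 2).sum()         # stand-in for args.ploss_weight * p_loss
loss.backward()                         # centers.grad == w * (unweighted gradient)

# Old scheme: rescale the gradient, then step with lr:
#     centers.grad *= 1. / w;  centers -= lr * centers.grad
# New scheme: leave the gradient alone, enlarge the lr:
with torch.no_grad():
    centers -= (lr / w) * centers.grad  # identical update for vanilla SGD
```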
@@ -191,10 +188,10 @@ def init_models(args):
     # construct optimizer
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
     scheduler = MinExponentialLR(optimizer, gamma=0.998, minimum=1e-5)
-    optimizer1 = optim.Adam(proximity.parameters(), lr=args.lr * 50)
-    scheduler1 = MinExponentialLR(optimizer1, gamma=0.998, minimum=1e-5)
-    optimizer2 = optim.Adam(distance.parameters(), lr=args.lr / 100)
-    scheduler2 = MinExponentialLR(optimizer2, gamma=0.998, minimum=1e-5)
+    optimizer1 = optim.SGD(proximity.parameters(), lr=args.lr * 500)
+    scheduler1 = MinExponentialLR(optimizer1, gamma=0.1, minimum=1e-5)
+    optimizer2 = optim.SGD(distance.parameters(), lr=args.lr / 10)
+    scheduler2 = MinExponentialLR(optimizer2, gamma=0.1, minimum=1e-5)

     return model, proximity, distance, classifier, optimizer, scheduler,\
         optimizer1, scheduler1, optimizer2, scheduler2
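`MinExponentialLR` is a project-local scheduler, not a stock PyTorch one; from its call signature it decays the learning rate by `gamma` each step but clamps it at `minimum`. A plausible sketch, assuming it extends `torch.optim.lr_scheduler.ExponentialLR`:

```python
from torch.optim.lr_scheduler import ExponentialLR

class MinExponentialLR(ExponentialLR):
    """Exponential LR decay clamped below; a sketch of the interface used above."""
    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        self.min = minimum  # set before super().__init__, which triggers a first step
        super().__init__(optimizer, gamma, last_epoch=last_epoch)

    def get_lr(self):
        return [max(base_lr * self.gamma ** self.last_epoch, self.min)
                for base_lr in self.base_lrs]
```

With `gamma=0.1`, the auxiliary learning rates fall by 10x per epoch and reach the `1e-5` floor within a few epochs, so after the initial burst the floor effectively sets the rate; the main optimizer keeps the gentler `gamma=0.998` decay.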