Skip to content

Commit 9ab555b

Browse files
Fix training and experiments
1 parent 8eeb5b2 commit 9ab555b

File tree

10 files changed, with 37 additions and 43 deletions.

experiments/test.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ def parse_args():
3333

3434

3535
if __name__ == "__main__":
36-
models = ['vanilla', 'classification', 'proxi_dist', 'combined']
36+
# models = ['classification', 'proxi_dist', 'vanilla', 'combined']
37+
models = ['classification', 'vanilla']
3738
for i in range(len(models)):
3839
args = parse_args()
3940
model = MADVAE(args)

experiments/test_black.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,9 @@
55
from torch.utils import data
66
from torchvision import datasets, transforms
77
import matplotlib.pyplot as plt
8+
from test.test_models import *
89
sys.path.insert(0, os.path.abspath('..'))
910
from MAD_VAE import *
10-
from test.test_models import *
11-
from test.plotting import *
1211
from utils.dataset import *
1312
from utils.adversarial import *
1413
from utils.classifier import *
@@ -35,7 +34,7 @@ def parse_args():
3534

3635
if __name__ == "__main__":
3736
models = ['vanilla', 'classification', 'proxi_dist', 'combined']
38-
for i in range(4):
37+
for i in range(len(models)):
3938

4039
args = parse_args()
4140
model = MADVAE(args)
@@ -84,23 +83,23 @@ def parse_args():
8483
image = image.cuda()
8584
label = label.cuda()
8685

87-
output, adv_out = add_adv(classifier, image, label, 'fgsm')
86+
output, adv_out = add_adv(classifier, image, label, 'fgsm', default=True)
8887
output_class = classifier(output)
8988
adv_output_class = classifier(adv_out)
9089
def_out, _, _, _ = model(adv_out)
91-
adv_out_class = classifier(def_out)
90+
cleaned_class = classifier(def_out)
9291

9392
true_class = torch.argmax(output_class, 1)
94-
output_class = torch.argmax(adv_output_class, 1)
95-
adversarial_class = torch.argmax(adv_out_class, 1)
93+
adv_class = torch.argmax(adv_output_class, 1)
94+
adv_clean_class = torch.argmax(cleaned_class, 1)
9695

9796
print(f'attack method fgsm')
9897
print(f'actual class {true_class}')
99-
print(f'actual advclass {output_class}')
100-
print(f'adversarial class {adversarial_class}')
98+
print(f'actual advclass {adv_class}')
99+
print(f'adversarial class {adv_clean_class}')
101100

102-
true += torch.sum(torch.eq(true_class, adversarial_class))
103-
true_adv += torch.sum(torch.eq(true_class, output_class))
101+
true += torch.sum(torch.eq(true_class, adv_clean_class))
102+
true_adv += torch.sum(torch.eq(true_class, adv_class))
104103

105104
print(int(true) / total)
106105
print(int(true_adv) / total)
pretrained_model/classification/params.pt

578 Bytes
Binary file not shown.

pretrained_model/combined/params.pt

-840 Bytes
Binary file not shown.

pretrained_model/proxi_dist/params.pt

473 Bytes
Binary file not shown.

pretrained_model/vanilla/params.pt

578 Bytes
Binary file not shown.

train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ def train(args, dataloader, model, optimizer, step, epoch):
111111
distribution = Normal(dsm, dss)
112112

113113
# calculate losses
114-
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, epoch/100)
114+
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, 0.1)
115115
loss = r_loss
116116
loss.backward()
117117

train_classification.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def train(args, dataloader, model, classifier, optimizer, step, epoch):
116116
distribution = Normal(dsm, dss)
117117

118118
# calculate losses
119-
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, epoch/100)
119+
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, 0.1)
120120
c_loss = classification_loss(output, label, classifier)
121121
loss = r_loss + args.closs_weight * c_loss
122122
loss.backward()

train_cluster.py

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,16 @@ def parse_args():
1818
desc = "MAD-VAE for adversarial defense"
1919
parser = argparse.ArgumentParser(description=desc)
2020
parser.add_argument('--batch_size', type=int, default=512, help='Training batch size')
21-
parser.add_argument('--epochs', type=int, default=10, help='Training epoch numbers')
21+
parser.add_argument('--epochs', type=int, default=5, help='Training epoch numbers')
2222
parser.add_argument('--h_dim', type=int, default=4096, help='Hidden dimensions')
2323
parser.add_argument('--z_dim', type=int, default=128, help='Latent dimensions for images')
2424
parser.add_argument('--image_channels', type=int, default=1, help='Image channels')
2525
parser.add_argument('--image_size', type=int, default=28, help='Image size (default to be squared images)')
2626
parser.add_argument('--num_classes', type=int, default=10, help='Number of image classes')
2727
parser.add_argument('--log_dir', type=str, default='pd_logs', help='Logs directory')
2828
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for the Adam optimizer')
29-
parser.add_argument('--closs_weight', type=float, default=0.1, help='Weight for classification loss functions')
3029
parser.add_argument('--ploss_weight', type=float, default=0.01, help='Weight for proximity loss functions')
31-
parser.add_argument('--dloss_weight', type=float, default=0.00001, help='Weight for distance loss functions')
30+
parser.add_argument('--dloss_weight', type=float, default=0.0001, help='Weight for distance loss functions')
3231
parser.add_argument('--data_root', type=str, default='data', help='Data directory')
3332
parser.add_argument('--model_dir', type=str, default='pretrained_model', help='Pretrained model directory')
3433
parser.add_argument('--use_gpu', type=bool, default=True, help='If use GPU for training')
@@ -78,18 +77,18 @@ def main():
7877
writer1.add_image("reconstruct data", outputs[i][0], step)
7978

8079
# print out loss
81-
print("batch {}'s img_recon loss: {:.5f}, recon loss: {:.5f}, kl loss: {:.5f}"\
80+
print("batch {}'s img_recon loss: {:.5f}, recon loss: {:.5f}, kl loss: {:.5f}, pd_loss: {:.5f}"\
8281
.format(step, np.sum(img_losses)/len(img_losses), np.sum(recon_losses)/len(recon_losses),\
83-
np.sum(kl_losses)/len(kl_losses)))
82+
np.sum(kl_losses)/len(kl_losses), np.sum(pd_losses)/len(pd_losses)))
8483

8584
# step scheduler
8685
scheduler.step()
8786
scheduler1.step()
8887
scheduler2.step()
8988

9089
# save model parameters
91-
if epoch % 5 == 0:
92-
torch.save(model.state_dict(), '{}/proxi_dist/params_{}.pt'.format(args.model_dir, epoch))
90+
# if epoch % 5 == 0:
91+
# torch.save(model.state_dict(), '{}/proxi_dist/params_{}.pt'.format(args.model_dir, epoch))
9392

9493
torch.save(model.state_dict(), '{}/proxi_dist/params.pt'.format(args.model_dir))
9594

@@ -123,10 +122,12 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o
123122
distribution = Normal(dsm, dss)
124123

125124
# calculate losses
126-
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, epoch/100)
125+
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, 0.1)
127126
p_loss = proximity(z, label)
128127
d_loss = distance(z, label)
129-
loss = r_loss + + args.ploss_weight * p_loss - args.dloss_weight * d_loss
128+
129+
pd_loss = args.ploss_weight * p_loss - args.dloss_weight * d_loss
130+
loss = r_loss + pd_loss
130131
loss.backward()
131132

132133
# clip for gradient
@@ -136,18 +137,14 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o
136137

137138
# step optimizer
138139
optimizer.step()
139-
for param in proximity.parameters():
140-
param.grad.data *= (1. / args.ploss_weight)
141140
optimizer1.step()
142-
for param in distance.parameters():
143-
param.grad.data *= (1. / args.dloss_weight)
144141
optimizer2.step()
145142

146143
# record results
147144
recon_losses.append(loss.cpu().item())
148145
img_losses.append(img_recon.cpu().item())
149146
kl_losses.append(kld.cpu().item())
150-
pd_losses.append(p_loss.cpu().item() - d_loss.cpu().item())
147+
pd_losses.append(pd_loss)
151148
outputs.append(output.cpu())
152149
datas.append(data.cpu())
153150
adv_datas.append(adv_data.cpu())
@@ -191,10 +188,10 @@ def init_models(args):
191188
# construct optimizer
192189
optimizer = optim.Adam(model.parameters(), lr=args.lr)
193190
scheduler = MinExponentialLR(optimizer, gamma=0.998, minimum=1e-5)
194-
optimizer1 = optim.Adam(proximity.parameters(), lr=args.lr*50)
195-
scheduler1 = MinExponentialLR(optimizer1, gamma=0.998, minimum=1e-5)
196-
optimizer2 = optim.Adam(distance.parameters(), lr=args.lr/100)
197-
scheduler2 = MinExponentialLR(optimizer2, gamma=0.998, minimum=1e-5)
191+
optimizer1 = optim.SGD(proximity.parameters(), lr=args.lr*500)
192+
scheduler1 = MinExponentialLR(optimizer1, gamma=0.1, minimum=1e-5)
193+
optimizer2 = optim.SGD(distance.parameters(), lr=args.lr/10)
194+
scheduler2 = MinExponentialLR(optimizer2, gamma=0.1, minimum=1e-5)
198195

199196
return model, proximity, distance, classifier, optimizer, scheduler,\
200197
optimizer1, scheduler1, optimizer2, scheduler2

train_combined.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -126,11 +126,12 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o
126126
distribution = Normal(dsm, dss)
127127

128128
# calculate losses
129-
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, epoch/100)
129+
r_loss, img_recon, kld = recon_loss_function(output, data, distribution, step, 0.1)
130130
c_loss = classification_loss(output, label, classifier)
131131
p_loss = proximity(z, label)
132132
d_loss = distance(z, label)
133-
loss = r_loss + args.closs_weight * c_loss + args.ploss_weight * p_loss - args.dloss_weight * d_loss
133+
pd_loss = args.ploss_weight * p_loss - args.dloss_weight * d_loss
134+
loss = r_loss + args.closs_weight * c_loss + pd_loss
134135
loss.backward()
135136

136137
# clip for gradient
@@ -140,19 +141,15 @@ def train(args, dataloader, model, classifier, proximity, distance, optimizer, o
140141

141142
# step optimizer
142143
optimizer.step()
143-
for param in proximity.parameters():
144-
param.grad.data *= (1. / args.ploss_weight)
145144
optimizer1.step()
146-
for param in distance.parameters():
147-
param.grad.data *= (1. / args.dloss_weight)
148145
optimizer2.step()
149146

150147
# record results
151148
recon_losses.append(loss.cpu().item())
152149
img_losses.append(img_recon.cpu().item())
153150
kl_losses.append(kld.cpu().item())
154151
c_losses.append(c_loss.cpu().item())
155-
pd_losses.append(p_loss.cpu().item() - d_loss.cpu().item())
152+
pd_losses.append(pd_loss)
156153
outputs.append(output.cpu())
157154
datas.append(data.cpu())
158155
adv_datas.append(adv_data.cpu())
@@ -196,10 +193,10 @@ def init_models(args):
196193
# construct optimizer
197194
optimizer = optim.Adam(model.parameters(), lr=args.lr)
198195
scheduler = MinExponentialLR(optimizer, gamma=0.998, minimum=1e-5)
199-
optimizer1 = optim.Adam(proximity.parameters(), lr=args.lr*50)
200-
scheduler1 = MinExponentialLR(optimizer1, gamma=0.998, minimum=1e-5)
201-
optimizer2 = optim.Adam(distance.parameters(), lr=args.lr/100)
202-
scheduler2 = MinExponentialLR(optimizer2, gamma=0.998, minimum=1e-5)
196+
optimizer1 = optim.SGD(proximity.parameters(), lr=args.lr*500)
197+
scheduler1 = MinExponentialLR(optimizer1, gamma=0.1, minimum=1e-5)
198+
optimizer2 = optim.SGD(distance.parameters(), lr=args.lr/10)
199+
scheduler2 = MinExponentialLR(optimizer2, gamma=0.1, minimum=1e-5)
203200

204201
return model, proximity, distance, classifier, optimizer, scheduler,\
205202
optimizer1, scheduler1, optimizer2, scheduler2

0 commit comments

Comments
 (0)