import torch
import torchvision.datasets as datasets
import torch.nn as nn
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
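
# This script performs selective forgetting ("unlearning") of one MNIST class:
# 1. load a pretrained CNN (modelNo9.pth),
# 2. accumulate the absolute weight gradients for the class to forget (FORGET_TARGET)
#    over that class's test images, and keep a per-layer map of the top-k weights,
# 3. register backward hooks so that only those top-k weights can be updated,
# 4. train with a sign-flipped cross-entropy on the forget class, then retrain on the
#    remaining classes, reporting accuracy on the retained, forgotten and new
#    (SUB_TARGET) data throughout.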

'''
Compute and return the gradients of the weights of a specific layer
for a specific class and input image.
'''
def compute_gradients(model, layer, inputs, target_class):
    model.eval()
    inputs.requires_grad = True
    # Zero any gradients left over from a previous call
    model.zero_grad()

    outputs = model(inputs)
    target_class_tensor = torch.tensor([target_class])  # convert the target class to a tensor
    # Keep only the logit of the target class
    outputs = outputs.gather(1, target_class_tensor.view(-1, 1)).squeeze()
    outputs.backward()  # compute the gradients

    # Return the gradients of the weights of the requested layer
    weight_gradients = layer.weight.grad.squeeze()
    return weight_gradients

'''
Visualize the gradients of the weights.
'''
def show_gradient_map(weight_gradients):
    plt.imshow(weight_gradients.detach().numpy(), cmap=plt.cm.gray)
    plt.axis('off')
    plt.show()
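
# Example usage for a 2-D gradient map such as fc1's (32, 576) weight gradients:
#   show_gradient_map(compute_gradients(model, model.fc1, img, FORGET_TARGET))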

def calculate_map_and_threshold(weight_gradients, k=2000):
    grad_map = weight_gradients.clone().detach()

    # The threshold is the value of the k-th largest gradient;
    # everything below it is zeroed out of the map
    threshold = torch.topk(grad_map.flatten(), k)[0][-1]

    grad_map[grad_map < threshold] = 0

    return grad_map, threshold

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 4, 3, 1)
        self.conv2 = nn.Conv2d(4, 4, 3, 1)
        self.fc1 = nn.Linear(12 * 12 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        out = self.fc2(x)
        return out

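# Shape check for fc1's input size: a 28x28 MNIST image passes through two 3x3
# convolutions (28 -> 26 -> 24) and a 2x2 max-pool (24 -> 12), giving 4 channels
# of 12x12 features, i.e. 12 * 12 * 4 = 576 inputs to fc1.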

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# targets: the class to forget and the class meant to take its place
FORGET_TARGET = 6
SUB_TARGET = 9

# hyperparameters
learning_rate = 5e-3

batch_size = 16
epochs_forget = 2
epochs_relearn = 10

################################# Dataset preprocessing part #################################

# Load and preprocess the datasets.

'''
#this will contain only the train data about the new class
train_only_to_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

train_mask = train_only_to_learn.targets == SUB_TARGET
train_only_to_learn.data = train_only_to_learn.data[train_mask]
train_only_to_learn.targets = train_only_to_learn.targets[train_mask]
train_only_to_learn.targets[train_only_to_learn.targets == SUB_TARGET] = FORGET_TARGET
train_only_to_learn_dataloader = DataLoader(train_only_to_learn, batch_size=batch_size)
'''

#this will contain only the train data about the forgotten class
train_only_forgotten_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_mask = train_only_forgotten_data.targets == FORGET_TARGET
train_only_forgotten_data.data = train_only_forgotten_data.data[train_mask]
train_only_forgotten_data.targets = train_only_forgotten_data.targets[train_mask]
train_only_forgotten_dataloader = DataLoader(train_only_forgotten_data, batch_size=batch_size)

#this will contain only the test data about the new class
test_only_to_learn = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

test_mask = test_only_to_learn.targets == SUB_TARGET
test_only_to_learn.data = test_only_to_learn.data[test_mask]
test_only_to_learn.targets = test_only_to_learn.targets[test_mask]
test_only_to_learn_dataloader = DataLoader(test_only_to_learn, batch_size=batch_size)

#this will contain only the test data about the forgotten class
test_only_forgotten_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
test_mask = test_only_forgotten_data.targets == FORGET_TARGET
test_only_forgotten_data.data = test_only_forgotten_data.data[test_mask]
test_only_forgotten_data.targets = test_only_forgotten_data.targets[test_mask]
test_only_forgotten_dataloader = DataLoader(test_only_forgotten_data, batch_size=batch_size)

#this will contain the training data where the forgotten class is removed
training_to_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_mask = training_to_learn.targets != FORGET_TARGET
training_to_learn.data = training_to_learn.data[train_mask]
training_to_learn.targets = training_to_learn.targets[train_mask]
training_to_learn_dataloader = DataLoader(training_to_learn, batch_size=batch_size)

#this will contain the test data with the forgotten class removed
test_to_learn = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
test_mask = test_to_learn.targets != FORGET_TARGET
test_to_learn.data = test_to_learn.data[test_mask]
test_to_learn.targets = test_to_learn.targets[test_mask]
test_to_learn_dataloader = DataLoader(test_to_learn, batch_size=batch_size)

'''
training_two_target_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
# note: `targets == (FORGET_TARGET or SUB_TARGET)` would only match FORGET_TARGET,
# since `or` returns its first truthy operand; isin() matches both classes
train_mask = torch.isin(training_two_target_learn.targets, torch.tensor([FORGET_TARGET, SUB_TARGET]))
training_two_target_learn.data = training_two_target_learn.data[train_mask]
training_two_target_learn.targets = training_two_target_learn.targets[train_mask]
training_two_target_learn_dataloader = DataLoader(training_two_target_learn, batch_size=batch_size)
'''

################################# Gradient computation part #################################

def log_softmax(x):
    return x - torch.logsumexp(x, dim=1, keepdim=True)

def CustomCrossEntropyLoss(outputs, targets):
    epsilon = 1e-6
    num_examples = targets.shape[0]
    batch_size = outputs.shape[0]
    # Flip the sign of the logits for examples of the class to forget, so that
    # minimizing the loss pushes the model away from predicting that class
    outputs[targets == FORGET_TARGET] = -outputs[targets == FORGET_TARGET]
    outputs = log_softmax(outputs + epsilon)

    # take only the log-probability of each example's target class
    outputs = outputs[range(batch_size), targets]

    return -torch.sum(outputs) / num_examples
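
# Equivalently: for a batch of N examples with logits s_i and labels y_i this returns
# -(1/N) * sum_i log_softmax(s_i)[y_i], i.e. ordinary cross-entropy, except that the
# logits of forget-class examples enter with flipped sign, so minimizing the loss
# drives their FORGET_TARGET logit down instead of up. (Adding the constant epsilon
# to all logits leaves log_softmax unchanged, so it has no numerical effect here.)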

# Load the model
model = CNN()
model.load_state_dict(torch.load("modelNo9.pth", map_location=torch.device(device)))

# create the gradient accumulators, squeezed to match what compute_gradients returns
grads_conv1 = torch.zeros(model.conv1.weight.shape).squeeze()
grads_conv2 = torch.zeros(model.conv2.weight.shape)
grads_fc1 = torch.zeros(model.fc1.weight.shape).squeeze()

print(model.conv1.weight.shape)
print(model.conv2.weight.shape)
# Accumulate the absolute gradients of the weights of all layers over every
# test image of the class to forget
for img, _ in test_only_forgotten_data:
    img = img.unsqueeze(0)
    grads_conv1 += compute_gradients(model, model.conv1, img, FORGET_TARGET).abs()
    grads_conv2 += compute_gradients(model, model.conv2, img, FORGET_TARGET).abs()
    grads_fc1 += compute_gradients(model, model.fc1, img, FORGET_TARGET).abs()

# keep only the k highest-gradient weights per layer: 4/36 for conv1 and
# 16/144 for conv2 (~11% each), 1000/18432 (~5%) for fc1
conv1_map, _ = calculate_map_and_threshold(grads_conv1, 4)
conv1_map = conv1_map.unsqueeze(1)  # restore the input-channel dim removed by squeeze()
conv2_map, _ = calculate_map_and_threshold(grads_conv2, 16)
fc1_map, _ = calculate_map_and_threshold(grads_fc1, 1000)

################################# Retraining part #################################
model = model.to(device)
#for param in model.conv1.parameters():
#    param.requires_grad = False

#for param in model.conv2.parameters():
#    param.requires_grad = False

# Custom backward hooks that zero the gradient wherever the map is zero, so that
# optimizer.step() only ever updates the selected high-gradient weights
def fc1_hook(grad):
    grad_clone = grad.clone()
    grad_clone[fc1_map == 0] = 0
    return grad_clone

def conv1_hook(grad):
    grad_clone = grad.clone()
    grad_clone[conv1_map == 0] = 0
    return grad_clone

def conv2_hook(grad):
    grad_clone = grad.clone()
    grad_clone[conv2_map == 0] = 0
    return grad_clone

# Register the hooks on the corresponding weight tensors
hook1 = model.fc1.weight.register_hook(fc1_hook)
hook2 = model.conv1.weight.register_hook(conv1_hook)
hook3 = model.conv2.weight.register_hook(conv2_hook)
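
# Note: these hooks fire on every backward pass from here on, in both the forgetting
# and the relearning phases below, and are only detached by hook.remove() at the end.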

# train: one pass over the dataloader, updating only the hook-selected weights
def train(dataloader, model, loss_fn, optimizer, scheduler):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)  # standard cross-entropy, used for logging only
        myloss = CustomCrossEntropyLoss(pred, y)
        optimizer.zero_grad()
        # only the custom loss is backpropagated
        myloss.backward()

        optimizer.step()
        if batch % 400 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    scheduler.step()

# test: evaluate accuracy and average loss over a dataloader
def test(dataloader, model, print_results=True):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    if print_results:
        print(f"Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    return 100 * correct, test_loss

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
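
# With step_size=1 and gamma=0.1, the learning rate is multiplied by 0.1 after every
# epoch (scheduler.step() is called once per train() call), so it becomes vanishingly
# small by the later relearning epochs.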

# record the starting accuracies before any forgetting
starting_accuracy_forgotten, _ = test(test_only_forgotten_dataloader, model, False)
starting_accuracy_new, _ = test(test_only_to_learn_dataloader, model, False)

# forgetting phase: train on the forget-class images with the sign-flipped loss
for t in range(epochs_forget):
    print(f"Epoch {t + 1}\n-------------------------------")
    print("Forgetting...")
    train(train_only_forgotten_dataloader, model, loss_fn, optimizer, scheduler)
    print("Accuracy on dataset:")
    test(test_to_learn_dataloader, model)
    print("Accuracy on forgotten data:")
    test(test_only_forgotten_dataloader, model)
    print("Accuracy on the new data:")
    test(test_only_to_learn_dataloader, model)

# relearning phase: retrain on all remaining classes to recover accuracy
for t in range(epochs_relearn):
    print(f"Epoch {t + 1}\n-------------------------------")
    print("ReLearning...")
    train(training_to_learn_dataloader, model, loss_fn, optimizer, scheduler)
    print("Accuracy on dataset:")
    test(test_to_learn_dataloader, model)
    print("Accuracy on forgotten data:")
    test(test_only_forgotten_dataloader, model)
    print("Accuracy on the new data:")
    test(test_only_to_learn_dataloader, model)

hook1.remove()
hook2.remove()
hook3.remove()
print("\n\n")
print("Final accuracy:")
test(test_to_learn_dataloader, model)
print("Final accuracy on forgotten data:")
test(test_only_forgotten_dataloader, model)
print("Final accuracy on the new data:")
test(test_only_to_learn_dataloader, model)
print("Starting accuracy on forgotten data:\n", starting_accuracy_forgotten)
print("Starting accuracy on the new data:\n", starting_accuracy_new)

# save the model
torch.save(model.state_dict(), "modelRetr.pth")