
Commit 9b1d26f

Author: Dodo

Commit message: works?
1 parent 748f4ba commit 9b1d26f

File tree

4 files changed: +439 -31 lines changed


old approch/unlearn base only fc.py

Lines changed: 362 additions & 0 deletions
@@ -0,0 +1,362 @@
import torch
import torchvision.datasets as datasets
import torch.nn as nn
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader


'''
Compute and return the gradients of the weights of a specific layer for a specific class and input image
'''
def compute_gradients(model, layer, inputs, target_class):
    model.eval()
    inputs.requires_grad = True
    # Zero the gradients of the weights
    model.zero_grad()

    outputs = model(inputs)
    target_class_tensor = torch.tensor([target_class])  # Convert the target class to a tensor
    outputs = outputs.gather(1, target_class_tensor.view(-1, 1)).squeeze()  # Get the output of the target class
    outputs.backward()  # Compute the gradients

    # Get the gradients of the weights of the given layer
    weight_gradients = layer.weight.grad.squeeze()

    return weight_gradients

'''
Visualize the gradients of the weights
'''
def show_gradient_map(weight_gradients):
    plt.imshow(weight_gradients.detach().numpy(), cmap=plt.cm.gray)
    plt.axis('off')
    plt.show()

def calculate_map_and_threshold(weight_gradients, k=2000):
    grad_map = weight_gradients.clone().detach()

    # the k-th largest gradient value becomes the threshold
    threshold = torch.topk(grad_map.flatten(), k)[0][-1]

    # zero out everything below the threshold, keeping only the k largest gradients
    grad_map[grad_map < threshold] = 0

    return grad_map, threshold

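# Illustrative sanity check (hypothetical values, names _demo_* are only for this
# sketch): with k=2 the two largest gradients survive and the rest are zeroed.
_demo_map, _demo_thr = calculate_map_and_threshold(torch.tensor([[1., 5.], [3., 2.]]), k=2)
assert _demo_thr == 3.0
assert _demo_map.tolist() == [[0.0, 5.0], [3.0, 0.0]]
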
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 4, 3, 1)
        self.conv2 = nn.Conv2d(4, 4, 3, 1)
        self.fc1 = nn.Linear(12*12*4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        out = self.fc2(x)
        return out

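# Shape trace for an MNIST batch of size N through the network above:
# (N, 1, 28, 28) -conv1-> (N, 4, 26, 26) -conv2-> (N, 4, 24, 24)
# -max_pool2d(2)-> (N, 4, 12, 12) -flatten-> (N, 576) -fc1-> (N, 32) -fc2-> (N, 10)
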
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# targets
FORGET_TARGET = 6
SUB_TARGET = 9

# hyperparameters
learning_rate = 5e-3
batch_size = 16
epochs_forget = 2
epochs_relearn = 10

################################# Dataset preprocessing part #################################

# Load and preprocess the datasets.

'''
# this will contain only the train data about the new class
train_only_to_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

train_mask = train_only_to_learn.targets == SUB_TARGET
train_only_to_learn.data = train_only_to_learn.data[train_mask]
train_only_to_learn.targets = train_only_to_learn.targets[train_mask]
train_only_to_learn.targets[train_only_to_learn.targets == SUB_TARGET] = FORGET_TARGET
train_only_to_learn_dataloader = DataLoader(train_only_to_learn, batch_size=batch_size)
'''

# this will contain only the train data about the forgotten class
train_only_forgotten_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_mask = train_only_forgotten_data.targets == FORGET_TARGET
train_only_forgotten_data.data = train_only_forgotten_data.data[train_mask]
train_only_forgotten_data.targets = train_only_forgotten_data.targets[train_mask]
train_only_forgotten_dataloader = DataLoader(train_only_forgotten_data, batch_size=batch_size)


# this will contain only the test data about the new class
test_only_to_learn = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

test_mask = test_only_to_learn.targets == SUB_TARGET
test_only_to_learn.data = test_only_to_learn.data[test_mask]
test_only_to_learn.targets = test_only_to_learn.targets[test_mask]
test_only_to_learn_dataloader = DataLoader(test_only_to_learn, batch_size=batch_size)

# this will contain only the test data about the forgotten class
test_only_forgotten_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
test_mask = test_only_forgotten_data.targets == FORGET_TARGET
test_only_forgotten_data.data = test_only_forgotten_data.data[test_mask]
test_only_forgotten_data.targets = test_only_forgotten_data.targets[test_mask]
test_only_forgotten_dataloader = DataLoader(test_only_forgotten_data, batch_size=batch_size)


# this will contain the training data with the forgotten class removed
training_to_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_mask = training_to_learn.targets != FORGET_TARGET
training_to_learn.data = training_to_learn.data[train_mask]
training_to_learn.targets = training_to_learn.targets[train_mask]
training_to_learn_dataloader = DataLoader(training_to_learn, batch_size=batch_size)


# this will contain the test data with the forgotten class removed
test_to_learn = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
test_mask = test_to_learn.targets != FORGET_TARGET
test_to_learn.data = test_to_learn.data[test_mask]
test_to_learn.targets = test_to_learn.targets[test_mask]
test_to_learn_dataloader = DataLoader(test_to_learn, batch_size=batch_size)

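# Illustrative sanity check (a sketch, not required by the pipeline): after the
# masking above, each single-class dataset should contain exactly one label.
assert set(train_only_forgotten_data.targets.tolist()) == {FORGET_TARGET}
assert set(test_only_to_learn.targets.tolist()) == {SUB_TARGET}
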
'''
training_two_target_learn = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
# element-wise OR of the two class masks
train_mask = (training_two_target_learn.targets == FORGET_TARGET) | (training_two_target_learn.targets == SUB_TARGET)
training_two_target_learn.data = training_two_target_learn.data[train_mask]
training_two_target_learn.targets = training_two_target_learn.targets[train_mask]
training_two_target_learn_dataloader = DataLoader(training_two_target_learn, batch_size=batch_size)
'''

################################# Gradient computation part #################################

def log_softmax(x):
    return x - torch.logsumexp(x, dim=1, keepdim=True)

def CustomCrossEntropyLoss(outputs, targets):
    epsilon = 1e-6
    num_examples = targets.shape[0]
    batch_size = outputs.shape[0]
    # work on a copy so the caller's logits are not modified in place
    outputs = outputs.clone()
    # negate the logit rows of forget-class samples, so minimizing the loss pushes
    # the model away from predicting FORGET_TARGET instead of toward it
    outputs[targets == FORGET_TARGET] = -outputs[targets == FORGET_TARGET]
    outputs = log_softmax(outputs + epsilon)

    # take only the loss of the target class
    outputs = outputs[range(batch_size), targets]

    return -torch.sum(outputs) / num_examples


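# Illustrative sanity check (a sketch with made-up values, names _logits/_labels are
# only for this check): on a batch with no FORGET_TARGET labels the custom loss
# matches nn.CrossEntropyLoss, since the constant epsilon added before log_softmax
# cancels out inside the softmax.
_logits = torch.randn(4, 10)
_labels = torch.tensor([0, 1, 2, 3])  # none of these equals FORGET_TARGET (6)
assert torch.allclose(CustomCrossEntropyLoss(_logits, _labels),
                      nn.CrossEntropyLoss()(_logits, _labels), atol=1e-5)
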
# Load the model
model = CNN()
model.load_state_dict(torch.load("modelNo9.pth", map_location=torch.device(device)))


# create the gradient holders (squeezed to match the shape returned by compute_gradients)
grads_conv1 = torch.zeros(model.conv1.weight.shape).squeeze()
grads_conv2 = torch.zeros(model.conv2.weight.shape)
grads_fc1 = torch.zeros(model.fc1.weight.shape).squeeze()

print(model.conv1.weight.shape)
print(model.conv2.weight.shape)
# Compute the gradients of the weights of all layers for the target class,
# accumulating the absolute gradients over all images of the forgotten class
for img, _ in test_only_forgotten_data:
    img = img.unsqueeze(0)
    grads_conv1 += compute_gradients(model, model.conv1, img, FORGET_TARGET).abs()
    grads_conv2 += compute_gradients(model, model.conv2, img, FORGET_TARGET).abs()
    grads_fc1 += compute_gradients(model, model.fc1, img, FORGET_TARGET).abs()


# keep about the top 10% of the highest gradients (k chosen per layer)
conv1_map, _ = calculate_map_and_threshold(grads_conv1, 4)
conv1_map = conv1_map.unsqueeze(1)  # restore the (out, in, kH, kW) shape of conv1's weights
conv2_map, _ = calculate_map_and_threshold(grads_conv2, 16)
fc1_map, _ = calculate_map_and_threshold(grads_fc1, 1000)

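# For reference (computed from the layer shapes above): these k values keep 4/36 of
# conv1, 16/144 of conv2 and 1000/18432 of fc1 -- roughly 11%, 11% and 5% of each
# layer's weights.
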
################################# Retraining part #################################
model = model.to(device)
# move the gradient masks to the same device as the model so the hooks can index grads
conv1_map = conv1_map.to(device)
conv2_map = conv2_map.to(device)
fc1_map = fc1_map.to(device)

# for param in model.conv1.parameters():
#     param.requires_grad = False

# for param in model.conv2.parameters():
#     param.requires_grad = False

# Define custom backward hooks to zero out gradients for specific weights
def fc1_hook(grad):
    grad_clone = grad.clone()
    grad_clone[fc1_map == 0] = 0
    return grad_clone

def conv1_hook(grad):
    grad_clone = grad.clone()
    grad_clone[conv1_map == 0] = 0
    return grad_clone

def conv2_hook(grad):
    grad_clone = grad.clone()
    grad_clone[conv2_map == 0] = 0
    return grad_clone


# Register the hooks for the specific parameters
hook1 = model.fc1.weight.register_hook(fc1_hook)
hook2 = model.conv1.weight.register_hook(conv1_hook)
hook3 = model.conv2.weight.register_hook(conv2_hook)

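# Illustrative verification (a sketch, names _xb/_yb are only for this check): after
# one backward pass, the hook should leave gradients outside the fc1 mask at zero.
_xb, _yb = next(iter(train_only_forgotten_dataloader))
model.zero_grad()
nn.CrossEntropyLoss()(model(_xb.to(device)), _yb.to(device)).backward()
assert model.fc1.weight.grad[fc1_map == 0].abs().max().item() == 0.0
model.zero_grad()
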
# train
def train(dataloader, model, loss_fn, optimizer, scheduler):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        myloss = CustomCrossEntropyLoss(pred, y)
        optimizer.zero_grad()
        # backpropagate the custom loss; the standard loss is only used for logging
        myloss.backward()

        optimizer.step()
        if batch % 400 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    scheduler.step()

# test
def test(dataloader, model, print_results=True):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    if print_results:
        print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    return 100*correct, test_loss

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# scheduler: decay the learning rate by 10x after every epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# save the starting accuracies
starting_accuracy_forgotten, _ = test(test_only_forgotten_dataloader, model, False)
starting_accuracy_new, _ = test(test_only_to_learn_dataloader, model, False)


# train and test
for t in range(epochs_forget):
    print(f"Epoch {t+1}\n-------------------------------")
    print("Forgetting...")
    train(train_only_forgotten_dataloader, model, loss_fn, optimizer, scheduler)
    print("Accuracy on dataset:")
    test(test_to_learn_dataloader, model)
    print("Accuracy on forgotten data:")
    test(test_only_forgotten_dataloader, model)
    print("Accuracy on the new data:")
    test(test_only_to_learn_dataloader, model)

# train and test
for t in range(epochs_relearn):
    print(f"Epoch {t+1}\n-------------------------------")
    print("ReLearning...")
    train(training_to_learn_dataloader, model, loss_fn, optimizer, scheduler)
    print("Accuracy on dataset:")
    test(test_to_learn_dataloader, model)
    print("Accuracy on forgotten data:")
    test(test_only_forgotten_dataloader, model)
    print("Accuracy on the new data:")
    test(test_only_to_learn_dataloader, model)

hook1.remove()
hook2.remove()
hook3.remove()
print("\n\n")
print("Final accuracy:")
test(test_to_learn_dataloader, model)
print("Final accuracy on forgotten data:")
test(test_only_forgotten_dataloader, model)
print("Final accuracy on the new data:")
test(test_only_to_learn_dataloader, model)
print("Starting accuracy on forgotten data:\n", starting_accuracy_forgotten)
print("Starting accuracy on the new data:\n", starting_accuracy_new)


# save the model
torch.save(model.state_dict(), "modelRetr.pth")
