train.py
import time

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn

import utils.aug as aug
from utils import process
from modules.clustering import cluster_memory


def train(args, model, adj, diff, features, local_memory_embeddings, sparse, device, optimizer):
    model.train()
    cross_entropy = nn.CrossEntropyLoss(ignore_index=-100)

    # Build cluster assignments from the memory bank of embeddings
    assignments = cluster_memory(args, model, local_memory_embeddings, features.shape[-2])

    # View 1: edge dropout on the original adjacency; view 2: the diffusion matrix
    aug_adj1 = aug.dropout_adj(adj, p=args.drop_edge)
    aug_adj2 = sp.csr_matrix(diff)

    # Feature augmentations for the two views
    features = features.squeeze(0)
    aug_features1 = aug.aug_feature_dropout(features, args.drop_feat1)
    aug_features2 = aug.aug_feature_dropout(features, args.drop_feat2)

    # Add self-loops and normalize both adjacencies
    t0 = time.time()
    aug_adj1 = process.normalize_adj(aug_adj1 + sp.eye(aug_adj1.shape[0]))
    aug_adj2 = process.normalize_adj(aug_adj2 + sp.eye(aug_adj2.shape[0]))
    t1 = time.time()

    if sparse:
        adj_1 = process.sparse_mx_to_torch_sparse_tensor(aug_adj1).to(device)
        adj_2 = process.sparse_mx_to_torch_sparse_tensor(aug_adj2).to(device)
    else:
        aug_adj1 = (aug_adj1 + sp.eye(aug_adj1.shape[0])).todense()
        aug_adj2 = (aug_adj2 + sp.eye(aug_adj2.shape[0])).todense()
        adj_1 = torch.FloatTensor(aug_adj1[np.newaxis]).to(device)
        adj_2 = torch.FloatTensor(aug_adj2[np.newaxis]).to(device)

    aug_features1 = aug_features1.to(device)
    aug_features2 = aug_features2.to(device)

    # Get embeddings and prototype scores for both views
    embed_1, prototypes_1 = model(aug_features1, adj_1, sparse)
    embed_2, prototypes_2 = model(aug_features2, adj_2, sparse)
    embed_1 = embed_1.detach()
    embed_2 = embed_2.detach()

    # Swapped-prediction loss: prototype scores of both views are trained to
    # predict the cluster assignments, averaged over the prototype heads
    loss = 0
    for h in range(len(args.nmb_prototypes)):
        scores_1 = prototypes_1[h] / args.temperature
        scores_2 = prototypes_2[h] / args.temperature
        scores = torch.cat((scores_1, scores_2))
        targets = assignments[h].repeat(sum(args.nmb_crops)).to(device, non_blocking=True)
        loss += cross_entropy(scores, targets)
    loss /= len(args.nmb_prototypes)

    # ============ backward and optim step ============
    optimizer.zero_grad()
    loss.backward()
    # Cancel gradients of the prototype weights so they stay fixed this step
    for name, p in model.named_parameters():
        if "prototypes" in name:
            p.grad = None
    optimizer.step()

    # Update the memory bank with the detached embeddings of the two views
    embed_1 = embed_1.unsqueeze(dim=0)
    embed_2 = embed_2.unsqueeze(dim=0)
    local_memory_embeddings = torch.cat((embed_1, embed_2), dim=0)

    return loss, local_memory_embeddings
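
For context, a minimal sketch of how train might be driven from an outer loop. Only the call signature of train and the shape of its return values come from the file above; the optimizer choice and the args fields lr, weight_decay and epochs are assumptions added for illustration, as is the expectation that the caller builds adj, diff, features and the initial memory bank (e.g. via a warm-up forward pass).

# Hypothetical usage sketch -- not part of train.py.
# args.lr, args.weight_decay and args.epochs are assumed fields; adj, diff,
# features and the initial local_memory_embeddings are built by the caller.
import torch

def run_training(args, model, adj, diff, features, local_memory_embeddings,
                 sparse, device):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    for epoch in range(args.epochs):
        # Each step returns the updated [2, n_nodes, dim] memory bank, which is
        # fed back in so cluster_memory sees the latest embeddings next epoch.
        loss, local_memory_embeddings = train(args, model, adj, diff, features,
                                              local_memory_embeddings, sparse,
                                              device, optimizer)
        print(f"epoch {epoch:03d} | loss {loss.item():.4f}")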