# process_kge.py
# Helpers for loading pretrained knowledge-graph-embedding (KGE) checkpoints.
import torch
def load_pretrain_kge(path):
    """Load pretrained entity and relation embeddings from a KGE checkpoint.

    If ``path`` contains the substring ``"complex"``, loading is delegated to
    ``load_complex_model``. Otherwise the checkpoint must be a mapping that
    provides the keys ``"ent_embeddings.weight"`` and
    ``"rel_embeddings.weight"``.

    Args:
        path: String path of a checkpoint readable by ``torch.load``.

    Returns:
        A tuple ``(ent_embs, rel_embs)`` of CPU tensors that do not require
        gradients. When the entity and relation widths (``shape[1]``) differ,
        the relation embeddings are concatenated with themselves along the
        last dimension before being returned.
    """
    if "complex" in path:
        return load_complex_model(path)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a
    # CPU-only machine; the embeddings are wanted on the CPU anyway.
    kge_model = torch.load(path, map_location="cpu")

    # as_tensor(...).detach().clone() yields an independent, gradient-free
    # copy without the copy-construct UserWarning that torch.tensor(tensor)
    # raises, and it still accepts non-tensor array-likes.
    ent_embs = torch.as_tensor(
        kge_model["ent_embeddings.weight"], device="cpu"
    ).detach().clone()
    rel_embs = torch.as_tensor(
        kge_model["rel_embeddings.weight"], device="cpu"
    ).detach().clone()

    ent_dim = ent_embs.shape[1]
    rel_dim = rel_embs.shape[1]
    print(ent_dim, rel_dim)

    if ent_dim != rel_dim:
        # Presumably targets models whose relation width is half the entity
        # width, so doubling makes the widths match -- TODO confirm.
        rel_embs = torch.cat((rel_embs, rel_embs), dim=-1)
    return ent_embs, rel_embs
def load_complex_model(path):
    """Load entity and relation embeddings from a checkpoint with re/im parts.

    The checkpoint must be a mapping that provides the keys
    ``"ent_re_embeddings.weight"``, ``"ent_im_embeddings.weight"``,
    ``"rel_re_embeddings.weight"`` and ``"rel_im_embeddings.weight"``.

    Args:
        path: Path of a checkpoint readable by ``torch.load``.

    Returns:
        A tuple ``(ent_embs, rel_embs)`` of CPU tensors that do not require
        gradients. Each one is the ``re`` tensor followed by the ``im`` tensor,
        concatenated along the last dimension.
    """
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a
    # CPU-only machine; the embeddings are wanted on the CPU anyway.
    kge_model = torch.load(path, map_location="cpu")

    def _cpu_copy(key):
        # Independent, gradient-free copy; avoids the copy-construct
        # UserWarning that torch.tensor(tensor) raises.
        return torch.as_tensor(kge_model[key], device="cpu").detach().clone()

    ent_embs = torch.cat(
        (_cpu_copy("ent_re_embeddings.weight"), _cpu_copy("ent_im_embeddings.weight")),
        dim=-1,
    )
    rel_embs = torch.cat(
        (_cpu_copy("rel_re_embeddings.weight"), _cpu_copy("rel_im_embeddings.weight")),
        dim=-1,
    )

    ent_dim = ent_embs.shape[1]
    rel_dim = rel_embs.shape[1]
    print(ent_dim, rel_dim)
    return ent_embs, rel_embs
if __name__ == "__main__":
    # Manual smoke run: the path contains "complex", so load_pretrain_kge
    # hands this checkpoint to load_complex_model.
    checkpoint_path = "data/CoDeX-S-complex.pth"
    load_pretrain_kge(checkpoint_path)