-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathcaser.py
127 lines (101 loc) · 4.13 KB
/
caser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import activation_getter
class Caser(nn.Module):
    """
    Convolutional Sequence Embedding Recommendation Model (Caser)[1].

    [1] Personalized Top-N Sequential Recommendation via Convolutional
    Sequence Embedding, Jiaxi Tang and Ke Wang, WSDM '18

    Parameters
    ----------
    num_users: int,
        Number of users.
    num_items: int,
        Number of items.
    model_args: args,
        Model-related arguments, like latent dimensions.
    """

    def __init__(self, num_users, num_items, model_args):
        super(Caser, self).__init__()
        self.args = model_args

        # init args
        L = self.args.L          # length of the item sequence fed to the convs
        dims = self.args.d       # latent (embedding) dimensionality
        self.n_h = self.args.nh  # number of horizontal filters per height
        self.n_v = self.args.nv  # number of vertical filters
        self.drop_ratio = self.args.drop
        self.ac_conv = activation_getter[self.args.ac_conv]
        self.ac_fc = activation_getter[self.args.ac_fc]

        # user and item embeddings
        self.user_embeddings = nn.Embedding(num_users, dims)
        self.item_embeddings = nn.Embedding(num_items, dims)

        # vertical conv layer: each filter spans the whole sequence (L x 1)
        self.conv_v = nn.Conv2d(1, self.n_v, (L, 1))

        # horizontal conv layers: one conv per filter height 1..L, each
        # spanning the full embedding width
        lengths = [i + 1 for i in range(L)]
        self.conv_h = nn.ModuleList(
            [nn.Conv2d(1, self.n_h, (i, dims)) for i in lengths])

        # fully-connected layer input size (vertical part + horizontal part)
        self.fc1_dim_v = self.n_v * dims
        self.fc1_dim_h = self.n_h * len(lengths)
        fc1_dim_in = self.fc1_dim_v + self.fc1_dim_h
        # W1, b1 can be encoded with nn.Linear
        self.fc1 = nn.Linear(fc1_dim_in, dims)
        # W2, b2 are encoded with nn.Embedding, as we don't need to compute
        # scores for all items
        self.W2 = nn.Embedding(num_items, dims + dims)
        self.b2 = nn.Embedding(num_items, 1)

        # dropout
        self.dropout = nn.Dropout(self.drop_ratio)

        # weight initialization
        self.user_embeddings.weight.data.normal_(0, 1.0 / self.user_embeddings.embedding_dim)
        self.item_embeddings.weight.data.normal_(0, 1.0 / self.item_embeddings.embedding_dim)
        self.W2.weight.data.normal_(0, 1.0 / self.W2.embedding_dim)
        self.b2.weight.data.zero_()

        self.cache_x = None

    def forward(self, seq_var, user_var, item_var, for_pred=False):
        """
        The forward propagation used to get recommendation scores, given
        triplet (user, sequence, targets).

        Parameters
        ----------
        seq_var: torch.LongTensor with size [batch_size, max_sequence_length]
            a batch of sequences of item ids (integer indices into the item
            embedding table; the original docstring said FloatTensor, but
            nn.Embedding requires an integer tensor)
        user_var: torch.LongTensor with size [batch_size, 1]
            a batch of user ids
        item_var: torch.LongTensor
            a batch of target items to score
        for_pred: boolean, optional
            Train or Prediction. Set to True when evaluation.

        Returns
        -------
        res: torch.FloatTensor
            recommendation scores for the items in `item_var`.
        """
        # Embedding look-up; unsqueeze(1) adds the channel dim expected by
        # Conv2d: [batch, 1, L, dims]
        item_embs = self.item_embeddings(seq_var).unsqueeze(1)
        user_emb = self.user_embeddings(user_var).squeeze(1)

        # Convolutional Layers
        out_v = None
        out_h = None

        # vertical conv layer
        if self.n_v:
            out_v = self.conv_v(item_embs)
            out_v = out_v.view(-1, self.fc1_dim_v)  # prepare for fully connect

        # horizontal conv layer
        if self.n_h:
            out_hs = []
            for conv in self.conv_h:
                conv_out = self.ac_conv(conv(item_embs).squeeze(3))
                # max-over-time pooling collapses the sequence dimension
                pool_out = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
                out_hs.append(pool_out)
            out_h = torch.cat(out_hs, 1)  # prepare for fully connect

        # Fully-connected Layers.
        # Fix: concatenate only the branches that were actually computed.
        # The original `torch.cat([out_v, out_h], 1)` raised a TypeError
        # whenever n_v == 0 or n_h == 0, because the skipped branch left
        # its output as None.
        conv_outs = [o for o in (out_v, out_h) if o is not None]
        out = torch.cat(conv_outs, 1)
        # apply dropout
        out = self.dropout(out)

        # fully-connected layer
        z = self.ac_fc(self.fc1(out))
        x = torch.cat([z, user_emb], 1)

        w2 = self.W2(item_var)
        b2 = self.b2(item_var)

        if for_pred:
            # NOTE(review): dimensionless squeeze() also drops a batch dim of
            # size 1 — this path appears to assume prediction is run for one
            # user at a time; confirm against the evaluation caller.
            w2 = w2.squeeze()
            b2 = b2.squeeze()
            res = (x * w2).sum(1) + b2
        else:
            # per-target score: b2 + w2 @ x for every target item in item_var
            res = torch.baddbmm(b2, w2, x.unsqueeze(2)).squeeze()

        return res