forked from IRC-SPHERE/ADL-TM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathADL_TM.py
134 lines (107 loc) · 2.97 KB
/
ADL_TM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# coding: utf-8
# In[309]:
import pandas as pd
import scipy as sp
import numpy as np
import csv
import sys
import os
import math
import random as rd
def GibbsSampler(widf,docs,wd,T,ITER,lidf,lwd,zdin=[]):
    """Collapsed Gibbs sampler assigning ONE topic per document.

    Each document contributes (a) first-order word-transition counts
    (wd[i], wd[i+1]) to a per-topic transition table and (b) counts from a
    second observation channel `lwd` to a per-topic table.  A single topic
    is sampled per document from the product of the two channel likelihoods
    and a Dirichlet-smoothed topic prior (mixture-of-multinomials style,
    not per-token LDA).

    Parameters
    ----------
    widf : unused here; kept for interface compatibility with the caller.
    docs : sequence of (start, end) index pairs delimiting each document's
        span inside the flat arrays `wd` and `lwd`.
    wd : 1-D integer array of 1-based word ids over all documents.
    T : number of topics.
    ITER : number of Gibbs sweeps over all documents.
    lidf : unused here; kept for interface compatibility with the caller.
    lwd : 1-D integer array of 1-based ids for the second channel
        (presumably location/sensor words -- confirm with caller).
    zdin : optional initial topic id per document; when empty, topics are
        initialised uniformly at random.
        NOTE(review): mutable default argument -- the shared `[]` is only
        read here, never mutated, so it is benign but fragile.

    Returns
    -------
    (wtp, lwtp, zd, totz) : per-topic transition counts (T,V,V), per-topic
        channel counts (T,M), final topic per document (D,1), and documents
        per topic (T,1).
    """
    D = len(docs)
    V = max(wd)     # vocabulary size; word ids are 1-based
    M = max(lwd)    # second-channel vocabulary size; ids are 1-based
    # Symmetric Dirichlet hyperparameters (standard 50/T heuristic for the
    # topic prior; 5/V and 5/M for the two observation channels).
    ALPHA = 50/float(T)
    BETA = 5/float(V)
    MU = 5/float(M)
    WBETA = V*BETA      # total smoothing mass per transition row
    TALPHA = T*ALPHA    # computed but not used below
    MMU = M * MU        # total smoothing mass per channel row
    # Count tables, all updated in place during sampling.
    wtp = np.zeros((T,V,V))              # topic x word x next-word counts
    lwtp = np.zeros((T,M))               # topic x channel-word counts
    totz = np.zeros((T,1),dtype=int)     # documents currently in each topic
    zd = np.zeros((D,1),dtype=int)       # current topic of each document
    # initialise
    di = 0
    if (len(zdin)>0):
        # Seed the count tables from the supplied assignment.
        for d in docs:
            k = zdin[di]
            zd[di] = k
            totz[k] +=1
            # NOTE(review): numpy fancy-index `+=` counts a repeated (i,j)
            # pair only once, whereas the random-init branch below counts
            # duplicates one at a time -- confirm which is intended.
            wtp[k,wd[d[0]:d[1]]-1,wd[d[0]+1:d[1]+1]-1]+=1
            # NOTE(review): this loop starts at d[0]+1, so lwd[d[0]] is
            # skipped; the random-init branch covers d[0]..d[1] inclusive.
            for i in range(d[0]+1,d[1]+1):
                lwtp[k,lwd[i]-1]+= 1
            di+=1
    else:
        # Random initialisation: each document drawn uniformly over topics.
        for d in docs:
            k = rd.randint(0,T-1)
            zd[di] = k
            totz[k] +=1
            #lwtp[k,lwd[d[0]:d[1]+1]-1]+=1
            for i in range(d[0],d[1]):
                lwtp[k,lwd[i]-1]+= 1
                wtp[k,wd[i]-1,wd[i+1]-1]+=1
            lwtp[k,lwd[d[1]]-1]+= 1
            di+=1
    # Gibbs sweeps: resample one topic per document per iteration.
    for itr in range(ITER):
        print('iter: {}'.format(itr))
        di = 0
        for d in docs:
            #print d
            u = zd[di]
            # Remove this document's contribution from the counts.
            totz[u]-=1
            #lwtp[u,lwd[d[0]:d[1]]-1]-=1
            # NOTE(review): decrements lwd indices d[0]+1..d[1] only; after a
            # random initialisation (which also added lwd[d[0]]) that first
            # count is never removed -- confirm against the model definition.
            for i in range(d[0]+1,d[1]+1):
                lwtp[u,lwd[i]-1]-= 1
            wtp[u,wd[d[0]:d[1]]-1,wd[d[0]+1:d[1]+1]-1]-=1
            # Unnormalised posterior over topics:
            #   P(k) ~ (n_k + ALPHA) * prod_t P(w_t -> w_{t+1} | k) * prod_t P(l_t | k)
            prob = totz + ALPHA
            #tmp1 = (wtp[:,wd[d[0]:d[1]]-1,wd[d[0]+1:d[1]+1]-1] + BETA)/(np.sum(wtp[:,wd[d[0]:d[1]]-1,:],min(2,d[1]-d[0]))+WBETA)
            tmp1 = (wtp[:,wd[d[0]:d[1]]-1,wd[d[0]+1:d[1]+1]-1] + BETA)/(np.sum(wtp[:,wd[d[0]:d[1]]-1,:],2)+WBETA)
            tmp2 = (lwtp[:,lwd[d[0]:d[1]+1]-1] + MU)/np.repeat(np.array(np.sum(lwtp,1) + MMU,ndmin=2).transpose(), d[1]-d[0]+1, 1)
            #print('prob shape',prob.shape, tmp1.shape, np.prod(tmp2,1).shape)
            #if d[1]-d[0] > 1:
            prob = (prob.transpose() * (np.prod(tmp1,1) * np.prod(tmp2,1)))[0]
            #else:
            #prob = (prob.transpose() * (tmp1 * np.prod(tmp2,1)))[0]
            #print prob
            totprob = sum(prob)
            # Sample a topic from the unnormalised distribution by walking
            # the cumulative sum until it passes a uniform draw (inverse CDF).
            r = rd.random() * totprob
            maxprob = prob[0]
            k=0
            while(r > maxprob):
                #print k,r,maxprob
                k+=1
                maxprob += prob[k]
            # Add the document back under its newly sampled topic.
            zd[di] = k
            wtp[k,wd[d[0]:d[1]]-1,wd[d[0]+1:d[1]+1]-1]+=1
            #lwtp[k,lwd[d[0]:d[1]+1]-1]+=1
            for i in range(d[0]+1,d[1]+1):
                lwtp[k,lwd[i]-1]+= 1
            totz[k]+=1
            di+=1
    return(wtp,lwtp, zd, totz)
# In[ ]:
def predictDoc(wtp, lwtp, wd, lwd, docs):
    """Assign each document to its maximum-log-likelihood topic.

    For each (start, end) span in `docs`, accumulates the log of the
    per-topic word-transition weights `wtp[:, wd[i]-1, wd[i+1]-1]` for
    consecutive word pairs, plus the log of the per-topic channel weights
    `lwtp[:, lwd[i]-1]` over the span, and picks the argmax topic.

    Parameters
    ----------
    wtp : array (T, V, V) of per-topic word-transition weights
        (e.g. smoothed counts from GibbsSampler); entries must be > 0
        or the corresponding log is -inf.
    lwtp : array (T, M) of per-topic second-channel weights; same caveat.
    wd, lwd : 1-D integer arrays of 1-based ids, indexed by `docs` spans.
    docs : sequence of (start, end) index pairs per document.

    Returns
    -------
    (zd, prob) : argmax topic per document (len(docs),) and the full
        per-document log-score matrix (len(docs), T).

    Bug fixed: the original special-cased d[1]-d[0] == 1 with
    `np.log(tmp1)` (shape (T, 1)), which broadcast against the (T,)
    channel term into a (T, T) array and raised ValueError on the
    `prob[di, :] +=` assignment for T > 1 (the empty-document branch had
    the same (T, 1) shape bug via tmp2).  Summing the logs over axis 1
    handles every document length uniformly and is identical to the
    original for lengths > 1.
    """
    zd = np.zeros(len(docs))
    prob = np.zeros((len(docs), wtp.shape[0]))
    for di, d in enumerate(docs):
        # Per-topic weights for each observed transition / channel word.
        tmp1 = wtp[:, wd[d[0]:d[1]] - 1, wd[d[0] + 1:d[1] + 1] - 1]
        tmp2 = lwtp[:, lwd[d[0]:d[1] + 1] - 1]
        if d[1] - d[0] >= 1:
            # Sum of logs over the document positions -> shape (T,).
            prob[di, :] += np.sum(np.log(tmp1), 1) + np.sum(np.log(tmp2), 1)
        else:
            # Empty span: no transitions, only the single channel word.
            prob[di, :] += np.sum(np.log(tmp2), 1)
        zd[di] = np.argmax(prob[di, :])
    return (zd, prob)