-
Notifications
You must be signed in to change notification settings - Fork 23
/
batched_inv_mp.py
145 lines (106 loc) · 4.75 KB
/
batched_inv_mp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
like batched_inv, but this implementation runs the sparse matrix stuff in a set of separate processes to speed things up.
"""
import multiprocessing as mp
import time
import Queue

import numpy as np

import batched_inv
import wmf
def buffered_gen_mp(source_gen, buffer_size=2, sleep_time=1):
    """
    Generator that runs a slow source generator in a separate process.

    source_gen: the (slow) generator to consume in a background process.
    buffer_size: the maximal number of items to pre-generate (length of the buffer).
    sleep_time: seconds to sleep between polls while the buffer is full/empty.

    Yields the items of source_gen in order.
    """
    # py2/py3-compatible import of the Empty exception raised by
    # multiprocessing.Queue.get(timeout=...).
    try:
        from Queue import Empty
    except ImportError:  # Python 3
        from queue import Empty

    buffer = mp.Queue(maxsize=buffer_size)

    def _buffered_generation_process(source_gen, buffer):
        while True:
            # Wait while the buffer is full: generating more data now would
            # only cost extra memory and effectively grow the buffer by one.
            while buffer.full():
                time.sleep(sleep_time)
            try:
                # builtin next() works on both py2 and py3 (was source_gen.next()).
                data = next(source_gen)
            except StopIteration:
                buffer.close()  # signal that we're done putting data in the buffer
                break
            buffer.put(data)

    process = mp.Process(target=_buffered_generation_process, args=(source_gen, buffer))
    # Daemonize so an abandoned generator can't keep the interpreter alive.
    process.daemon = True
    process.start()

    while True:
        try:
            # Blocking indefinitely on get() is unsafe: closing the buffer does
            # NOT wake a blocked get(). So we poll with a timeout, which lets us
            # pick up the shutdown signal.
            try:
                yield buffer.get(True, timeout=sleep_time)
            except Empty:
                if not process.is_alive():
                    # No more data is coming. This check is a workaround because
                    # the buffer.close() signal does not seem to be reliable.
                    break
                # Buffer is momentarily empty; just try again.
        except IOError:
            # get() on a closed buffer raises IOError: iteration is finished.
            break
class CallableObject(object):
    """
    Picklable partial-application wrapper for multiprocessing.

    Stores a function plus extra positional/keyword arguments fixed at
    construction time; calling the wrapper with a single argument invokes
    ``func(arg, *args, **kwargs)``. Unlike a lambda or closure, instances of
    this class can be sent to worker processes.
    """

    def __init__(self, func, *args, **kwargs):
        self.func, self.args, self.kwargs = func, args, kwargs

    def __call__(self, arg):
        # The call-time argument goes first, followed by the stored ones.
        return self.func(arg, *self.args, **self.kwargs)
def get_row(S, i):
    """Return (data, indices) of row ``i`` of the CSR matrix ``S``, without densifying."""
    start, stop = S.indptr[i], S.indptr[i + 1]
    return S.data[start:stop], S.indices[start:stop]
def build_batch(b, S, Y_e, b_y, byY, YTYpR, batch_size, m, f, dtype):
    """
    Build the per-user linear systems for batch ``b``.

    For each user k in the batch, constructs the right-hand side A (length f+1)
    and the system matrix B ((f+1) x (f+1)) of the regularized least-squares
    problem solved to recompute that user's factors + bias.

    b: batch index.
    S: sparse (CSR) confidence matrix, users x items.
    Y_e: item factors with a trailing column of ones (bias trick).
    b_y: vector of item biases.
    byY, YTYpR: precomputed np.dot(b_y, Y_e) and Y_e^T Y_e + regularizer.
    batch_size, m, f: batch size, number of users, number of factors.
    dtype: dtype of the output stacks.

    Returns (A_stack, B_stack) of shapes (batch, f+1) and (batch, f+1, f+1);
    the last batch may be smaller than batch_size.
    """
    lo = b * batch_size
    hi = min((b + 1) * batch_size, m)  # clamp: last batch may be partial
    current_batch_size = hi - lo
    A_stack = np.empty((current_batch_size, f + 1), dtype=dtype)
    B_stack = np.empty((current_batch_size, f + 1, f + 1), dtype=dtype)
    # range instead of xrange: identical values, works on both py2 and py3.
    for ib, k in enumerate(range(lo, hi)):
        s_u, i_u = get_row(S, k)
        Y_u = Y_e[i_u]  # only the items user k interacted with (exploit sparsity)
        b_y_u = b_y[i_u]
        A = (s_u + 1).dot(Y_u)
        A -= np.dot(b_y_u, (Y_u * s_u[:, None]))
        A -= byY
        YTSY = np.dot(Y_u.T, (Y_u * s_u[:, None]))
        B = YTSY + YTYpR
        A_stack[ib] = A
        B_stack[ib] = B
    return A_stack, B_stack
def recompute_factors_bias_batched_mp(Y, S, lambda_reg, dtype='float32', batch_size=1, solve=batched_inv.solve_sequential, num_batch_build_processes=4):
    """
    Recompute user factors (with biases) from the item factors Y and the
    sparse confidence matrix S, building the per-batch linear systems in a
    pool of worker processes.

    Y: item factors, shape (n_items, f+1); the last column holds the biases.
    S: sparse (CSR) confidence matrix, users x items.
    lambda_reg: regularization strength (biases are not regularized).
    dtype: dtype of the result.
    batch_size: number of users per batch.
    solve: callable (A_stack, B_stack) -> X_stack solving the batched systems.
    num_batch_build_processes: size of the batch-building worker pool.

    Returns X_new of shape (n_users, f+1).
    """
    m = S.shape[0]  # m = number of users
    f = Y.shape[1] - 1  # f = number of factors
    b_y = Y[:, f]  # vector of biases
    Y_e = Y.copy()
    Y_e[:, f] = 1  # factors with added column of ones (bias trick)
    YTY = np.dot(Y_e.T, Y_e)  # precompute this
    R = np.eye(f + 1)  # regularization matrix
    R[f, f] = 0  # don't regularize the biases!
    R *= lambda_reg
    YTYpR = YTY + R
    byY = np.dot(b_y, Y_e)  # precompute this as well
    X_new = np.zeros((m, f + 1), dtype=dtype)
    num_batches = int(np.ceil(m / float(batch_size)))
    func = CallableObject(build_batch, S, Y_e, b_y, byY, YTYpR, batch_size, m, f, dtype)
    pool = mp.Pool(num_batch_build_processes)
    # NOTE(review): the original also built an (unused) buffered_gen_mp wrapper
    # around pool.imap here. Consuming a Pool.imap iterator from inside a forked
    # child process is unreliable (the pool's handler threads don't survive the
    # fork), so we iterate the imap iterator directly and drop the dead code.
    try:
        batch_gen = pool.imap(func, range(num_batches))
        for b, (A_stack, B_stack) in enumerate(batch_gen):
            lo = b * batch_size
            hi = min((b + 1) * batch_size, m)
            X_stack = solve(A_stack, B_stack)
            X_new[lo:hi] = X_stack
    finally:
        # BUG FIX: the pool was never shut down, leaking worker processes on
        # every call. Close and join it even if a batch build or solve fails.
        pool.close()
        pool.join()
    return X_new