
Commit ab033f4
feat: batch size implementation
1 parent 33539c5

7 files changed: +42 -26 lines changed

copy_task.py
Lines changed: 2 additions & 1 deletion

@@ -53,6 +53,7 @@ def train(epochs=50_000):
     writer.add_scalar("hidden_layer_size", hidden_layer_size)
     writer.add_scalar("lstm_controller", lstm_controller)
     writer.add_scalar("seed", seed)
+    writer.add_scalar("batch_size", batch_size)

     model = NTM(vector_length, hidden_layer_size, memory_size, lstm_controller)

@@ -76,7 +77,7 @@ def train(epochs=50_000):
             _, state = model(vector, state)
         y_out = torch.zeros(target.size())
         for j in range(len(target)):
-            y_out[j], state = model(torch.zeros(1, vector_length + 1), state)
+            y_out[j], state = model(torch.zeros(batch_size, vector_length + 1), state)
         loss = F.binary_cross_entropy(y_out, target)
         loss.backward()
         optimizer.step()
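Note on shapes: the tasks keep sequences in (seq_len, batch, features) layout, so the zero vector fed at each read-back step needs a leading batch dimension rather than a hard-coded 1. A minimal sketch of that contract (illustrative shapes only; the model call itself is elided):

import torch
import torch.nn.functional as F

batch_size, vector_length = 2, 8

# One zero-input query per batch element at every output step;
# before this commit the leading dimension was hard-coded to 1.
query = torch.zeros(batch_size, vector_length + 1)
assert query.shape == (2, 9)

# binary_cross_entropy reduces elementwise over all dimensions,
# so batched y_out/target of matching shape need no other change.
y_out = torch.rand(5, batch_size, vector_length)
target = torch.bernoulli(torch.full_like(y_out, 0.5))
loss = F.binary_cross_entropy(y_out, target)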

ntm/head.py
Lines changed: 8 additions & 2 deletions

@@ -32,17 +32,23 @@ def get_head_weight(self, x, previous_state, memory_read):
         w_c = F.softmax(beta * F.cosine_similarity(memory_read + 1e-16, k.unsqueeze(1) + 1e-16, dim=-1), dim=1)
         # Focusing by location
         w_g = g * w_c + (1 - g) * previous_state
-        w_t = _convolve(w_g, s)
+        w_t = self.shift(w_g, s)
         w = w_t ** gamma
         w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16)
         return w

+    def shift(self, w_g, s):
+        result = torch.zeros(w_g.size())
+        for b in range(len(w_g)):
+            result[b] = _convolve(w_g[b], s[b])
+        return result
+

 class ReadHead(Head):
     def forward(self, x, previous_state):
         memory_read = self.memory.read()
         w = self.get_head_weight(x, previous_state, memory_read)
-        return torch.matmul(w, memory_read), w
+        return torch.matmul(w.unsqueeze(1), memory_read).squeeze(1), w


 class WriteHead(Head):
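The new shift method applies _convolve one batch row at a time, and the read now contracts a per-example weighting against a batched memory. A small sketch of the batched read, assuming memory has shape (batch, N, M) and the head weighting w has shape (batch, N):

import torch

batch, N, M = 2, 128, 20
memory = torch.randn(batch, N, M)
w = torch.softmax(torch.randn(batch, N), dim=1)

# (batch, 1, N) @ (batch, N, M) -> (batch, 1, M) -> (batch, M)
read = torch.matmul(w.unsqueeze(1), memory).squeeze(1)

# Same result computed one example at a time:
manual = torch.stack([w[b] @ memory[b] for b in range(batch)])
assert torch.allclose(read, manual)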

ntm/memory.py
Lines changed: 3 additions & 3 deletions

@@ -7,7 +7,7 @@ class Memory(nn.Module):
     def __init__(self, memory_size):
         super(Memory, self).__init__()
         self._memory_size = memory_size
-        print(self._memory_size)
+
         # Initialize memory bias
         stdev = 1 / (np.sqrt(memory_size[0] + memory_size[1]))
         intial_state = torch.Tensor(memory_size[0], memory_size[1]).uniform_(-stdev, stdev)

@@ -29,8 +29,8 @@ def read(self):
         return self.memory

     def write(self, w, e, a):
-        self.memory = self.memory * (1 - torch.t(w) * e)
-        self.memory = self.memory + torch.t(w) * a
+        self.memory = self.memory * (1 - torch.matmul(torch.t(w), e))
+        self.memory = self.memory + torch.matmul(torch.t(w), a)
         return self.memory

     def size(self):
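Replacing the elementwise torch.t(w) * e with torch.matmul(torch.t(w), e) turns the erase and add terms into proper matrix products. With w of shape (batch, N) and e, a of shape (batch, M), the product is an (N, M) update equal to the sum over the batch of the per-example outer products. A quick check of that equivalence (shapes assumed from the rest of the diff):

import torch

batch, N, M = 2, 128, 20
w = torch.softmax(torch.randn(batch, N), dim=1)  # head weightings
e = torch.sigmoid(torch.randn(batch, M))         # erase vectors

erase = torch.matmul(torch.t(w), e)              # (N, M)
manual = sum(torch.outer(w[b], e[b]) for b in range(batch))
assert torch.allclose(erase, manual)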

ntm/ntm.py
Lines changed: 1 addition & 2 deletions

@@ -17,12 +17,11 @@ def __init__(self, vector_length, hidden_size, memory_size, lstm_controller=True):
         nn.init.xavier_uniform_(self.fc.weight, gain=1)
         nn.init.normal_(self.fc.bias, std=0.01)

-    def get_initial_state(self, batch_size):
+    def get_initial_state(self, batch_size=1):
         self.memory.reset(batch_size)
         controller_state = self.controller.get_initial_state(batch_size)
         read = self.memory.get_initial_state(batch_size)
         read_head_state = self.read_head.get_initial_state(batch_size)
-        print("read_head_state.shape", read_head_state.shape)
         write_head_state = self.write_head.get_initial_state(batch_size)
         return (read, read_head_state, write_head_state, controller_state)
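Defaulting batch_size to 1 keeps existing unbatched call sites working. Hypothetical usage, assuming model is an NTM instance as constructed in the task scripts:

state = model.get_initial_state()              # unbatched callers: batch_size = 1
state = model.get_initial_state(batch_size=2)  # batched training, as in repeat_task.py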

ntm/utils.py
Lines changed: 3 additions & 6 deletions

@@ -15,12 +15,9 @@ def circular_convolution(w, s):

 def _convolve(w, s):
     """Circular convolution implementation."""
-    assert s.size(1) == 3
-    print(w.shape)
-    t = torch.cat([w[:, -1:], w, w[:, :1]], dim=1)
-    print(t.shape)
-    c = F.conv1d(t.unsqueeze(1), s.view(1, 1, -1))
-    print(c.shape)
+    assert s.size(0) == 3
+    t = torch.cat([w[-1:], w, w[:1]], dim=0)
+    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
     return c
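The rewritten _convolve pads the weighting with one wrapped element on each side and runs F.conv1d over it. Note that F.conv1d computes a cross-correlation (the kernel is not flipped) and expects floating-point inputs. A self-contained copy for reference, renamed convolve here to avoid shadowing the repo's function:

import torch
import torch.nn.functional as F

def convolve(w, s):
    """Circular convolution of a 1-D weighting w with a 3-tap shift kernel s."""
    assert s.size(0) == 3
    t = torch.cat([w[-1:], w, w[:1]], dim=0)  # wrap one element on each side
    return F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)

w = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
s = torch.tensor([0.5, 0.0, 0.5])
print(convolve(w, s))  # tensor([0.0000, 0.5000, 0.0000, 0.5000, 0.0000])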

repeat_task.py
Lines changed: 12 additions & 11 deletions

@@ -23,21 +23,20 @@
 # torch.manual_seed(seed)


-def get_training_sequence(sequence_min_length, sequence_max_length, vector_length):
+def get_training_sequence(sequence_min_length, sequence_max_length, vector_length, batch_size=1):
     sequence_length = random.randint(sequence_min_length, sequence_max_length)
     repeat = random.randint(sequence_min_length, sequence_max_length)

-    target = torch.bernoulli(torch.Tensor(sequence_length, vector_length).uniform_(0, 1))
-    target = torch.unsqueeze(target, 1)
+    target = torch.bernoulli(torch.Tensor(sequence_length, batch_size, vector_length).uniform_(0, 1))

-    input = torch.zeros(sequence_length + 2, 1, vector_length + 2)
+    input = torch.zeros(sequence_length + 2, batch_size, vector_length + 2)
     input[:sequence_length, :, :vector_length] = target
     # delimiter vector
     input[sequence_length, :, vector_length] = 1.0
     # repeat channel
-    input[sequence_length+1, :, vector_length+1] = repeat / sequence_max_length
+    input[sequence_length + 1, :, vector_length + 1] = repeat / sequence_max_length

-    output = torch.zeros(sequence_length * repeat + 1, 1, vector_length + 1)
+    output = torch.zeros(sequence_length * repeat + 1, batch_size, vector_length + 1)
     output[:sequence_length * repeat, :, :vector_length] = target.clone().repeat(repeat, 1, 1)
     # delimiter vector
     output[-1, :, -1] = 1.0

@@ -53,6 +52,7 @@ def train(epochs=50_000):
     vector_length = 8
     memory_size = (128, 20)
     hidden_layer_size = 100
+    batch_size = 2
     lstm_controller = not args.ff

     writer.add_scalar("sequence_min_length", sequence_min_length)

@@ -63,6 +63,7 @@ def train(epochs=50_000):
     writer.add_scalar("hidden_layer_size", hidden_layer_size)
     writer.add_scalar("lstm_controller", lstm_controller)
     writer.add_scalar("seed", seed)
+    writer.add_scalar("batch_size", batch_size)

     model = NTM(vector_length + 1, hidden_layer_size, memory_size, lstm_controller)

@@ -78,15 +79,15 @@ def train(epochs=50_000):
         checkpoint = torch.load(model_path)
         model.load_state_dict(checkpoint)

-    for epoch in range(epochs):
+    for epoch in range(epochs + 1):
         optimizer.zero_grad()
-        input, target = get_training_sequence(sequence_min_length, sequence_max_length, vector_length)
-        state = model.get_initial_state()
+        input, target = get_training_sequence(sequence_min_length, sequence_max_length, vector_length, batch_size)
+        state = model.get_initial_state(batch_size)
         for vector in input:
             _, state = model(vector, state)
         y_out = torch.zeros(target.size())
         for j in range(len(target)):
-            y_out[j], state = model(torch.zeros(1, vector_length + 2), state)
+            y_out[j], state = model(torch.zeros(batch_size, vector_length + 2), state)
         loss = F.binary_cross_entropy(y_out, target)
         loss.backward()
         optimizer.step()

@@ -95,7 +96,7 @@ def train(epochs=50_000):
         y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)
         cost = torch.sum(torch.abs(y_out_binarized - target)) / len(target)
         total_cost.append(cost.item())
-        if epoch % feedback_frequence == feedback_frequence - 1:
+        if epoch % feedback_frequence == 0:
             running_loss = sum(total_loss) / len(total_loss)
             running_cost = sum(total_cost) / len(total_cost)
             print(f"Loss at step {epoch}: {running_loss}")
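With the batch dimension built into get_training_sequence, the explicit torch.unsqueeze(target, 1) is gone and both tensors come out in (seq_len, batch, features) layout directly. A quick shape check, assuming get_training_sequence from this file is in scope:

input, target = get_training_sequence(2, 5, vector_length=8, batch_size=2)

assert input.shape[1] == 2 and input.shape[2] == 8 + 2    # bits + delimiter + repeat channel
assert target.shape[1] == 2 and target.shape[2] == 8 + 1  # bits + delimiter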

tests/test_utils.py
Lines changed: 13 additions & 1 deletion

@@ -1,6 +1,6 @@
 import torch

-from ..ntm.utils import circular_convolution
+from ..ntm.utils import circular_convolution, _convolve


 def test_circular_convolution():

@@ -23,3 +23,15 @@ def test_circular_convolution():
     b = torch.tensor([[1, 0, 1, 0, 0]])
     res = torch.tensor([[5, 7, 4, 6, 8]])
     assert torch.equal(circular_convolution(a, b), res)
+
+
+def test_convolve():
+    w = torch.tensor([0, 0, 1, 0, 0])
+    s = torch.tensor([0, 1, 0])
+    res = torch.tensor([0, 0, 1, 0, 0])
+    assert torch.equal(_convolve(w, s), res)
+
+    w = torch.tensor([0, 0, 1.0, 0, 0])
+    s = torch.tensor([0.5, 0, 0.5])
+    res = torch.tensor([0, 0.5, 0, 0.5, 0])
+    assert torch.equal(_convolve(w, s), res)
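These tests rely on a relative import (from ..ntm.utils import ...), so presumably they are collected as part of the package, e.g. with python -m pytest run from the repository root.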
