
Commit f9500d3

Author: Nicholas Leonard (committed)
initial commit for C-lib

1 parent ef98a97 · commit f9500d3

10 files changed (+791 lines, −37 lines)

CMakeLists.txt

Lines changed: 56 additions & 10 deletions
@@ -1,15 +1,61 @@
+SET(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR})
 
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
 CMAKE_POLICY(VERSION 2.6)
-IF(LUAROCKS_PREFIX)
-   MESSAGE(STATUS "Installing Torch through Luarocks")
-   STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}")
-   MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}")
-ENDIF()
+
 FIND_PACKAGE(Torch REQUIRED)
 
-SET(src)
-FILE(GLOB luasrc *.lua)
-SET(luasrc ${luasrc})
-ADD_SUBDIRECTORY(test)
-ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "Recurrent Neural Networks")
+SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE
+
+SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
+SET(src
+   init.c
+)
+SET(luasrc
+   init.lua
+   AbstractRecurrent.lua
+   AbstractSequencer.lua
+   BiSequencer.lua
+   BiSequencerLM.lua
+   CopyGrad.lua
+   Dropout.lua
+   ExpandAs.lua
+   FastLSTM.lua
+   GRU.lua
+   LinearNoBias.lua
+   LookupTableMaskZero.lua
+   LSTM.lua
+   MaskZero.lua
+   MaskZeroCriterion.lua
+   Module.lua
+   Mufuru.lua
+   NormStabilizer.lua
+   Padding.lua
+   Recurrence.lua
+   Recurrent.lua
+   RecurrentAttention.lua
+   recursiveUtils.lua
+   Recursor.lua
+   Repeater.lua
+   RepeaterCriterion.lua
+   SAdd.lua
+   SeqBRNN.lua
+   SeqGRU.lua
+   SeqLSTM.lua
+   SeqLSTMP.lua
+   SeqReverseSequence.lua
+   Sequencer.lua
+   SequencerCriterion.lua
+   TrimZero.lua
+   ZeroGrad.lua
+   test/bigtest.lua
+   test/test.lua
+)
+
+ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
+
+TARGET_LINK_LIBRARIES(rnn luaT TH)
+
+SET_TARGET_PROPERTIES(rnn_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH")
+
+INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/rnn")
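The new build compiles init.c into the rnn target (linked against luaT and TH) and installs the listed Lua sources next to it. A minimal post-install sanity check, sketched in Lua; it assumes the rock built cleanly, and nn.SeqLSTM's constructor arguments (input size, output size) are an assumption taken from general rnn usage rather than from this diff:

require 'rnn'                     -- init.lua loads the modules listed in luasrc above
local lstm = nn.SeqLSTM(10, 20)   -- smoke test: construct one of the packaged modules
print(torch.type(lstm))           -- expected: nn.SeqLSTM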

TrimZero.lua

Lines changed: 20 additions & 20 deletions
@@ -6,10 +6,10 @@
 -- Decorator that zeroes the output rows of the encapsulated module
 -- for commensurate input rows which are tensors of zeros
 
--- The only difference from `MaskZero` is that it reduces computational costs
--- by varying a batch size, if any, for the case that varying lengths
--- are provided in the input. Notice that when the lengths are consistent,
--- `MaskZero` will be faster, because `TrimZero` has an operational cost.
+-- The only difference from `MaskZero` is that it reduces computational costs
+-- by varying a batch size, if any, for the case that varying lengths
+-- are provided in the input. Notice that when the lengths are consistent,
+-- `MaskZero` will be faster, because `TrimZero` has an operational cost.
 
 -- In short, the result is the same with `MaskZero`'s, however, `TrimZero` is
 -- faster than `MaskZero` only when sentence lengths is costly vary.
@@ -38,7 +38,7 @@ function TrimZero:recursiveMask(output, input, mask)
    else
      assert(torch.isTensor(input))
     output = torch.isTensor(output) and output or input.new()
-
+
      -- make sure mask has the same dimension as the input tensor
      if torch.type(mask) ~= 'torch.LongTensor' then
        local inputSize = input:size():fill(1)
@@ -48,7 +48,7 @@ function TrimZero:recursiveMask(output, input, mask)
        end
        mask:resize(inputSize)
      end
-
+
      -- build mask
      if self.batchmode then
        assert(torch.find, 'install torchx package : luarocks install torchx')
@@ -67,11 +67,11 @@ function TrimZero:recursiveMask(output, input, mask)
        else
          output:index(input, 1, torch.LongTensor{1}):zero()
        end
-     else
-       if mask:dim() == 0 or mask:view(-1)[1] == 1 then
-         output:resize(input:size()):zero()
-       else
-         output:resize(input:size()):copy(input)
+     else
+       if mask:dim() == 0 or mask:view(-1)[1] == 1 then
+         output:resize(input:size()):zero()
+       else
+         output:resize(input:size()):copy(input)
        end
      end
   end
@@ -87,14 +87,14 @@ function TrimZero:recursiveUnMask(output, input, mask)
   else
      assert(torch.isTensor(input))
      output = torch.isTensor(output) and output or input.new()
-
+
      -- make sure output has the same dimension as the mask
      local inputSize = input:size()
      if self.batchmode then
        inputSize[1] = mask:size(1)
      end
      output:resize(inputSize):zero()
-
+
      -- build mask
      if self.batchmode then
        assert(self._maskindices)
@@ -103,7 +103,7 @@ function TrimZero:recursiveUnMask(output, input, mask)
        output:indexCopy(1, mask, input)
      end
   else
-     if mask:view(-1)[1] == 0 then
+     if mask:view(-1)[1] == 0 then
        output:copy(input)
      end
   end
@@ -123,17 +123,17 @@ function TrimZero:updateOutput(input)
   else
      error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
   end
-
+
   -- build mask
-  local vectorDim = rmi:dim()
+  local vectorDim = rmi:dim()
   self._zeroMask = self._zeroMask or rmi.new()
   self._zeroMask:norm(rmi, 2, vectorDim)
   self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaTensor() or torch.ByteTensor())
   self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
-
+
   -- forward through decorated module
   self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
-  output = self.module:updateOutput(self.temp)
+  output = self.modules[1]:updateOutput(self.temp)
   self.output = self:recursiveUnMask(self.output, output, self.zeroMask, true)
 
   return self.output
@@ -143,7 +143,7 @@ function TrimZero:updateGradInput(input, gradOutput)
   self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
   self.gradTemp = self:recursiveMask(self.gradTemp, gradOutput, self.zeroMask)
 
-  local gradInput = self.module:updateGradInput(self.temp, self.gradTemp)
+  local gradInput = self.modules[1]:updateGradInput(self.temp, self.gradTemp)
 
   self.gradInput = self:recursiveUnMask(self.gradInput, gradInput, self.zeroMask)
 
@@ -152,5 +152,5 @@ end
 
 function TrimZero:accGradParameters(input, gradOutput, scale)
   self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
-  self.module:accGradParameters(self.temp, gradOutput, scale)
+  self.modules[1]:accGradParameters(self.temp, gradOutput, scale)
 end
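The decorator change above (self.module becoming self.modules[1]) is easiest to see from the caller's side. A minimal usage sketch, assuming the nn.MaskZero and nn.TrimZero constructors take (module, nInputDim) as elsewhere in the rnn package, and that the torchx rock provides torch.find as the assert in the diff indicates:

require 'rnn'
require 'torchx'   -- provides torch.find, which TrimZero's batch mode asserts on

-- decorate a module so that all-zero input rows yield all-zero output rows
local step    = nn.Linear(5, 3)
local masked  = nn.MaskZero(step:clone(), 1)   -- nInputDim = 1: each sample is a 1D vector
local trimmed = nn.TrimZero(step:clone(), 1)   -- identical weights, so outputs should agree

-- a batch whose second row is zero padding
local input = torch.Tensor{{1, 2, 3, 4, 5}, {0, 0, 0, 0, 0}}
print(masked:forward(input))    -- row 2 of the output is zeroed
print(trimmed:forward(input))   -- same result; TrimZero skips the zero rows internally

Per the comments in the file header above, TrimZero only pays off when sequence lengths vary; with uniform lengths MaskZero is faster.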

VariableLength.lua

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+local VariableLength, parent = torch.class("nn.VariableLength", "nn.Decorator")
+
+-- make sure your module has been set-up for zero-masking (that is, module:maskZero())
+function VariableLength:__init(module, lastOnly)
+   parent.__init(self, module)
+   -- only extract the last element of each sequence
+   self.lastOnly = lastOnly -- defaults to false
+end
+
+-- recursively masks input (inplace)
+function VariableLength.recursiveMask(input, mask)
+   if torch.type(input) == 'table' then
+      for k,v in ipairs(input) do
+         VariableLength.recursiveMask(v, mask)
+      end
+   else
+      assert(torch.isTensor(input))
+
+      -- make sure mask has the same dimension as the input tensor
+      assert(mask:dim() == 2, "Expecting batchsize x seqlen mask tensor")
+      -- expand mask to input (if necessary)
+      local zeroMask
+      if input:dim() == 2 then
+         zeroMask = mask
+      elseif input:dim() > 2 then
+         local inputSize = input:size():fill(1)
+         inputSize[1] = input:size(1)
+         inputSize[2] = input:size(2)
+         mask:resize(inputSize)
+         zeroMask = mask:expandAs(input)
+      else
+         error"Expecting batchsize x seqlen [ x ...] input tensor"
+      end
+      -- zero-mask input in between sequences
+      input:maskedFill(zeroMask, 0)
+   end
+end
+
+function VariableLength:updateOutput(input)
+   -- input is a table of batchSize tensors
+   assert(torch.type(input) == 'table')
+   assert(torch.isTensor(input[1]))
+   local batchSize = #input
+
+   self._input = self._input or input[1].new()
+   -- mask is a binary tensor with 1 where self._input is zero (between sequence zero-mask)
+   self._mask = self._mask or torch.ByteTensor()
+
+   -- now we process input into _input.
+   -- indexes and mappedLengths are meta-information tables, explained below.
+   self.indexes, self.mappedLengths = self._input.nn.VariableLength_FromSamples(input, self._input, self._mask)
+
+   -- zero-mask the _input where mask is 1
+   self.recursiveMask(self._input, self._mask)
+
+   -- feedforward the zero-mask format through the decorated module
+   local output = self.modules[1]:updateOutput(self._input)
+
+   if self.lastOnly then
+      -- Extract the last time step of each sample.
+      -- self.output tensor has shape: batchSize [x outputSize]
+      self.output = torch.isTensor(self.output) and self.output or output.new()
+      self.output.nn.VariableLength_ToFinal(self.indexes, self.mappedLengths, output, self.output)
+   else
+      -- This is the reverse operation of everything before updateOutput
+      self.output = output.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, output)
+   end
+
+   return self.output
+end
+
+function VariableLength:updateGradInput(input, gradInput)
+
+   return self.gradInput
+end
+
+function VariableLength:accGradParameters(input, gradInput, scale)
+
+end
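Since only updateOutput is implemented in this commit (updateGradInput and accGradParameters are still stubs), a forward-only usage sketch is all that can be grounded here. It assumes nn.SeqLSTM accepts batch-first input via its batchfirst flag and supports module:maskZero() as the comment at the top of the file requires; the sizes are illustrative:

require 'rnn'

-- decorated module: must consume batchsize x seqlen x inputsize and be zero-mask aware
local lstm = nn.SeqLSTM(4, 7)
lstm.batchfirst = true
lstm:maskZero()                            -- as required by the comment above

local vl = nn.VariableLength(lstm, true)   -- lastOnly = true: keep only each sequence's last step

-- input is a table of batchSize tensors of varying length (seqlen x inputsize)
local input = { torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4) }
local output = vl:forward(input)           -- per the comment above: batchSize [x outputSize]
print(output:size())

The packing itself (VariableLength_FromSamples, _ToFinal, _ToSamples) is delegated to the new C code introduced by this commit.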

error.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#ifndef _ERROR_H_
+#define _ERROR_H_
+
+#include "luaT.h"
+#include <string.h>
+
+static inline int _lua_error(lua_State *L, int ret, const char* file, int line) {
+   int pos_ret = ret >= 0 ? ret : -ret;
+   return luaL_error(L, "ERROR: (%s, %d): (%d, %s)\n", file, line, pos_ret, strerror(pos_ret));
+}
+
+static inline int _lua_error_str(lua_State *L, const char *str, const char* file, int line) {
+   return luaL_error(L, "ERROR: (%s, %d): (%s)\n", file, line, str);
+}
+
+static inline int _lua_error_str_str(lua_State *L, const char *str, const char* file, int line, const char *extra) {
+   return luaL_error(L, "ERROR: (%s, %d): (%s: %s)\n", file, line, str, extra);
+}
+
+#define LUA_HANDLE_ERROR(L, ret) _lua_error(L, ret, __FILE__, __LINE__)
+#define LUA_HANDLE_ERROR_STR(L, str) _lua_error_str(L, str, __FILE__, __LINE__)
+#define LUA_HANDLE_ERROR_STR_STR(L, str, extra) _lua_error_str_str(L, str, __FILE__, __LINE__, extra)
+
+#endif

0 commit comments