Hi, I'm using QuantLSTM and I'd like to understand the error below. First I give QuantLSTM an input without hidden and cell states, and that works. Next I try to pass the input together with the hidden and cell states, but I get an error.
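This is my code (only the input tensor and the two forward calls are visible in the traceback below, so the QuantLSTM constructor line here is a sketch: `input_size=132` matches the input tensor, and `hidden_size=2` is just a guess):

```python
import torch
from brevitas.nn import QuantLSTM

# Sketch: the constructor call does not appear in the traceback.
# input_size=132 is inferred from the input tensor; hidden_size=2 is a guess.
lstm = QuantLSTM(input_size=132, hidden_size=2)

lstm_input = torch.randn(5, 4, 132)  # (seq_len, batch, feature_size)

# First call: no initial hidden/cell states. This works.
lstm_out, hidden_state = lstm(lstm_input)

# Second call: feed the returned states back in. This raises the RuntimeError.
lstm_out_1, hidden_state_1 = lstm(inp=lstm_input, hx=hidden_state[0], cx=hidden_state[1])
```

And this is the error: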
```
RuntimeError Traceback (most recent call last)
Cell In[105], line 8
5 lstm_input = torch.randn(5, 4, 132) # (Seq_size, batch, feature_size)
7 lstm_out, hidden_state = lstm(lstm_input)
----> 8 lstm_out_1, hidden_state_1 = lstm(inp=lstm_input, hx=hidden_state[0], cx=hidden_state[1])
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/brevitas/nn/quant_rnn.py:909, in QuantLSTM.forward(self, inp, hx, cx)
907 layer_hidden_state = hx[2 * l + d] if hx is not None else hx
908 layer_cell_state = cx[2 * l + d] if cx is not None else cx
--> 909 out, out_hidden_state, out_cell_state = direction(inp, layer_hidden_state, layer_cell_state)
910 dir_outputs += [out]
911 dir_hidden_states += [out_hidden_state]
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/brevitas/nn/quant_rnn.py:654, in _QuantLSTMLayer.forward(self, inp, hidden_state, cell_state)
652 else:
653 cell = self.cell
--> 654 quant_outputs, quant_hidden_state, quant_cell_state = cell(
655 quant_input.value,
656 quant_hidden_state.value,
657 quant_cell_state.value,
658 quant_weight_ii=quant_weight_ii.value,
659 quant_weight_if=quant_weight_if.value,
660 quant_weight_ic=quant_weight_ic.value,
661 quant_weight_io=quant_weight_io.value,
662 quant_weight_hi=quant_weight_hi.value,
663 quant_weight_hf=quant_weight_hf.value,
664 quant_weight_hc=quant_weight_hc.value,
665 quant_weight_ho=quant_weight_ho.value,
666 quant_bias_input=quant_bias_input,
667 quant_bias_forget=quant_bias_forget,
668 quant_bias_cell=quant_bias_cell,
669 quant_bias_output=quant_bias_output)
670 quant_outputs = self.pack_quant_outputs(quant_outputs)
671 quant_hidden_state = self.pack_quant_state(quant_hidden_state, self.cell.output_quant)
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/brevitas/nn/quant_rnn.py:280, in _QuantLSTMCell.forward(self, quant_input, quant_hidden_state, quant_cell_state, quant_weight_ii, quant_weight_if, quant_weight_ic, quant_weight_io, quant_weight_hi, quant_weight_hf, quant_weight_hc, quant_weight_ho, quant_bias_input, quant_bias_forget, quant_bias_cell, quant_bias_output)
278 for _ in range(end):
279 quant_input = quant_inputs[index]
--> 280 quant_hidden_state_tuple, quant_cell_state_tuple = self.forward_iter(
281 quant_input,
282 quant_hidden_state,
283 quant_cell_state,
284 quant_weight_ii,
285 quant_weight_if,
286 quant_weight_ic,
287 quant_weight_io,
288 quant_weight_hi,
289 quant_weight_hf,
290 quant_weight_hc,
291 quant_weight_ho,
292 quant_bias_input,
293 quant_bias_forget,
294 quant_bias_cell,
295 quant_bias_output)
296 index = index + step
297 quant_hidden_states += [quant_hidden_state_tuple]
File /usr/local/lib/python3.8/dist-packages/brevitas/nn/quant_rnn.py:217, in _QuantLSTMCell.forward_iter(self, quant_input, quant_hidden_state, quant_cell_state, quant_weight_ii, quant_weight_if, quant_weight_ic, quant_weight_io, quant_weight_hi, quant_weight_hf, quant_weight_hc, quant_weight_ho, quant_bias_input, quant_bias_forget, quant_bias_cell, quant_bias_output)
215 quant_ii_gate = F.linear(quant_input, quant_weight_ii)
216 quant_hi_gate = F.linear(quant_hidden_state, quant_weight_hi)
--> 217 quant_input_gate = self.input_acc_quant(quant_ii_gate + quant_hi_gate + quant_bias_input)[0]
218 quant_input_gate = self.input_sigmoid_quant(quant_input_gate)[0]
219 # Forget gate
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1
```
I want to know how to fix this error. Am I making a mistake in how I use QuantLSTM and Brevitas?