-
I can suggest a possible implementation of this in JAX/Flax. I am not 100% certain that this Flax code behaves exactly like the PyTorch implementation, but let's give it a try. So, to implement this:
In PyTorch this is done in 2 lines. In JAX we need to be careful about how the input is passed to the LSTM cell, and then track the hidden state and cell state ourselves. As you know, for each time step t we need to compute the hidden state h and cell state c: carry = [ but we also need to loop over the time steps, which is: So, as far as I can tell, the full code should be: class LSTM(nn.Module):
If anyone thinks I'm missing something or got something wrong, please let me know!
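For anyone who wants to see the "carry plus loop over time" idea spelled out, here is a minimal sketch in plain JAX rather than Flax, so it does not depend on any particular Flax version's API. It is not the code from the reply above, and it is not guaranteed to match PyTorch's `nn.LSTM` numerically: it uses a single bias vector where PyTorch keeps two, and the initialization is only an approximation of PyTorch's default. The helper names (`init_lstm_params`, `lstm_step`, `lstm_forward`) and the hidden size of 32 are assumptions made for illustration; only the input shape (64, 553, 1) comes from the question.

```python
import jax
import jax.numpy as jnp


def init_lstm_params(key, input_size, hidden_size):
    """Create weights for one LSTM layer (hypothetical helper, not a Flax API)."""
    k_x, k_h = jax.random.split(key)
    scale = hidden_size ** -0.5
    return {
        # The four gates (input, forget, candidate, output) are stacked on the last axis.
        "w_x": jax.random.uniform(k_x, (input_size, 4 * hidden_size), minval=-scale, maxval=scale),
        "w_h": jax.random.uniform(k_h, (hidden_size, 4 * hidden_size), minval=-scale, maxval=scale),
        "b": jnp.zeros((4 * hidden_size,)),
    }


def lstm_step(params, carry, x_t):
    """One time step: update the carry (h, c) from x_t of shape (batch, input_size)."""
    h, c = carry
    gates = x_t @ params["w_x"] + h @ params["w_h"] + params["b"]
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    # Return the new carry and the per-step output (the hidden state).
    return (h, c), h


def lstm_forward(params, x):
    """Run the LSTM over x of shape (batch, time, features), i.e. batch-first input."""
    batch = x.shape[0]
    hidden_size = params["w_h"].shape[0]
    carry0 = (jnp.zeros((batch, hidden_size)), jnp.zeros((batch, hidden_size)))
    # lax.scan iterates over the leading axis, so move time to the front.
    xs = jnp.swapaxes(x, 0, 1)
    (h_last, c_last), hs = jax.lax.scan(
        lambda carry, x_t: lstm_step(params, carry, x_t), carry0, xs
    )
    # Move time back to axis 1 so the output is (batch, time, hidden).
    return jnp.swapaxes(hs, 0, 1), (h_last, c_last)


key = jax.random.PRNGKey(0)
params = init_lstm_params(key, input_size=1, hidden_size=32)  # hidden size is an assumption
x = jnp.ones((64, 553, 1))  # (batch, timesteps, features) as in the question
outputs, (h_last, c_last) = lstm_forward(params, x)
print(outputs.shape)  # (64, 553, 32)
```

Using `jax.lax.scan` instead of a Python `for` loop keeps the 553 steps as one compiled loop, which matters once the function is wrapped in `jax.jit`. The same step function could be reused inside a Flax module, but how the parameters are declared there depends on whether you use Linen or NNX.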
-
Hi 👋,
I'm new to JAX and I'm finding it challenging to convert my models from PyTorch to Flax NNX, particularly the LSTMs. For example, the following model processes batches of 64 sequences, each with 553 timesteps and 1 feature, so the input shape is (64, 553, 1). How would the same model be implemented in Flax NNX? I think it would help to supplement the package's documentation with more examples for new users.