Hi all! I'm in the process of converting some of my code to JAX, and one thing I'm currently struggling to figure out is how to set up my recurrent model correctly. In PyTorch this is pretty easy, as shown below: an LSTM that handles batches of variable-length sequences, with the hidden state passed around during inference for action selection.
```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Actor(nn.Module):
    """Actor model that produces actions given observation sequences."""

    def __init__(self, obs_dim: int, action_dim: int):
        """Initializes the actor model.

        Parameters
        ----------
        obs_dim : int
            Dimensionality of observations.
        action_dim : int
            Dimensionality of actions.
        """
        super().__init__()
        # Configure embedding layer
        self.embedding = nn.Embedding(obs_dim, 256)
        # Set up recurrent layer for processing sequences
        self.recurrence = nn.LSTM(256, 256, batch_first=True)
        # Set up post-processing layers
        self.post_processing_1 = nn.Linear(256, 256)
        self.post_processing_2 = nn.Linear(256, action_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        observations: torch.Tensor,
        seq_lengths: torch.Tensor,
        in_hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculates action probabilities given a batch of sequences.

        Parameters
        ----------
        observations : torch.Tensor
            Batch of padded observation sequences.
        seq_lengths : torch.Tensor
            True (unpadded) length of each sequence in the batch.
        in_hidden : Tuple[torch.Tensor, torch.Tensor], optional
            LSTM hidden and cell state, used during inference.

        Returns
        -------
        Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            Action probabilities and the updated hidden state.
        """
        # Embedding layer
        x = self.embedding(observations)
        # Recurrent layer: pack so the LSTM skips padded timesteps
        x = pack_padded_sequence(x, seq_lengths, batch_first=True, enforce_sorted=False)
        self.recurrence.flatten_parameters()
        x, out_hidden = self.recurrence(x, in_hidden)
        x, x_unpacked_len = pad_packed_sequence(x, batch_first=True)
        # Remaining layers
        x = F.relu(self.post_processing_1(x))
        action_logits = self.post_processing_2(x)
        action_probs = self.softmax(action_logits)
        return action_probs, out_hidden
```
The parts I'm struggling with in particular are how to convert the LSTM layer, how to handle batches of variable-length sequences, and how to pass the hidden state around during inference. I've looked at a few different codebases on GitHub to see how they handle this, but they all seem to do it differently. Based on the docs here, it seems like I can use a `seq_lengths` variable, much like in PyTorch, to handle the variable lengths. But how do I implement the LSTM itself: is it with a `jax.lax.scan` across multiple cells? Thanks in advance for any help!
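For concreteness, here is roughly the shape I'm imagining on the Flax side. This is only a sketch under assumptions: it assumes a recent `flax.linen` where `nn.RNN` (which wraps `jax.lax.scan` over a single cell across the time axis) accepts a `seq_lengths` argument to handle padding, and where `nn.OptimizedLSTMCell` takes its feature size in the constructor; these signatures have shifted across Flax versions, so treat everything below as hypothetical.

```python
import flax.linen as nn
import jax.numpy as jnp


class Actor(nn.Module):
    """Hypothetical Flax counterpart of the PyTorch actor above."""

    obs_dim: int
    action_dim: int
    hidden_size: int = 256

    @nn.compact
    def __call__(
        self,
        observations: jnp.ndarray,  # (batch, time) int32 tokens, zero-padded
        seq_lengths: jnp.ndarray,   # (batch,) true length of each sequence
        in_hidden=None,             # optional (h, c) carry used during inference
    ):
        # Embedding layer
        x = nn.Embed(self.obs_dim, self.hidden_size)(observations)
        # Recurrent layer: nn.RNN scans one LSTM cell over the time axis.
        # seq_lengths plays a role similar to pack_padded_sequence, so padded
        # steps don't corrupt the final carry.
        out_hidden, x = nn.RNN(nn.OptimizedLSTMCell(self.hidden_size))(
            x,
            initial_carry=in_hidden,
            seq_lengths=seq_lengths,
            return_carry=True,
        )
        # Remaining layers
        x = nn.relu(nn.Dense(self.hidden_size)(x))
        action_logits = nn.Dense(self.action_dim)(x)
        return nn.softmax(action_logits), out_hidden
```

Is this roughly right, or is hand-rolling `jax.lax.scan` over a single cell (with an explicit mask built from `seq_lengths`) the preferred pattern?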
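And for single-step action selection during rollouts, I'd expect to carry the `(h, c)` state between environment steps by calling the module on length-1 sequences, something like the following (again hypothetical, reusing the `Actor` sketch above with placeholder dimensions):

```python
import jax
import jax.numpy as jnp

model = Actor(obs_dim=32, action_dim=4)  # placeholder dimensions
params = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((1, 1), jnp.int32),  # batch of 1, sequence of length 1
    jnp.ones((1,), jnp.int32),     # its length is 1
)

hidden = None  # if I read the docs right, nn.RNN zero-initializes a missing carry
for obs in [3, 7, 1]:  # stand-in observation stream
    obs_batch = jnp.array([[obs]], jnp.int32)
    action_probs, hidden = model.apply(
        params, obs_batch, jnp.ones((1,), jnp.int32), in_hidden=hidden
    )
```

Does that match how the hidden state is meant to be threaded through `apply`, or is there a more idiomatic pattern?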