Commit
training code for hybrid-autoregressive inference model (NVIDIA#10841)
* training code for hybrid-autoregressive inference model

Signed-off-by: Hainan Xu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hainan-xv <[email protected]>

---------

Signed-off-by: Hainan Xu <[email protected]>
Signed-off-by: hainan-xv <[email protected]>
Co-authored-by: Hainan Xu <[email protected]>
Co-authored-by: hainan-xv <[email protected]>
3 people authored Oct 14, 2024
1 parent f51c04f commit b08d46c
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions nemo/collections/asr/modules/rnnt.py
@@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "targets": NeuralType(('B', 'T'), LabelsType()),
             "target_length": NeuralType(tuple('B'), LengthsType()),
@@ -84,8 +83,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
             "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -382,15 +380,20 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:

     @classmethod
     def batch_replace_states_mask(
-        cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
+        cls,
+        src_states: list[torch.Tensor],
+        dst_states: list[torch.Tensor],
+        mask: torch.Tensor,
     ):
         """Replace states in dst_states with states from src_states using the mask"""
         # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
         torch.where(mask.unsqueeze(-1), src_states[0], dst_states[0], out=dst_states[0])

     @classmethod
     def batch_replace_states_all(
-        cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor],
+        cls,
+        src_states: list[torch.Tensor],
+        dst_states: list[torch.Tensor],
     ):
         """Replace states in dst_states with states from src_states"""
         dst_states[0].copy_(src_states[0])
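The `torch.where(..., out=...)` call above replaces states in place without the synchronization that boolean-index assignment can trigger. A minimal standalone sketch of the same pattern, with illustrative shapes (plain tensors, not NeMo's state objects):

    import torch

    B, H = 4, 8  # batch size and hidden size (illustrative)
    src = torch.ones(B, H)
    dst = torch.zeros(B, H)
    mask = torch.tensor([True, False, True, False])

    # Same result as `dst[mask] = src[mask]`, but non-blocking:
    # broadcast the (B, 1) mask over H and write into dst in place.
    torch.where(mask.unsqueeze(-1), src, dst, out=dst)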
@@ -591,8 +594,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "targets": NeuralType(('B', 'T'), LabelsType()),
             "target_length": NeuralType(tuple('B'), LengthsType()),
@@ -601,8 +603,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
             "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -1018,19 +1019,19 @@ def batch_score_hypothesis(

     def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
         """
-        Create batch of decoder states.
+        Create batch of decoder states.

-        Args:
-            batch_states (list): batch of decoder states
-                ([L x (B, H)], [L x (B, H)])
+        Args:
+            batch_states (list): batch of decoder states
+                ([L x (B, H)], [L x (B, H)])

-            decoder_states (list of list): list of decoder states
-                [B x ([L x (1, H)], [L x (1, H)])]
+            decoder_states (list of list): list of decoder states
+                [B x ([L x (1, H)], [L x (1, H)])]

-        Returns:
-            batch_states (tuple): batch of decoder states
-                ([L x (B, H)], [L x (B, H)])
-        """
+        Returns:
+            batch_states (tuple): batch of decoder states
+                ([L x (B, H)], [L x (B, H)])
+        """
         # LSTM has 2 states
         new_states = [[] for _ in range(len(decoder_states[0]))]
         for layer in range(self.pred_rnn_layers):
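As an aside on the state layout documented above, a hedged sketch of stacking per-sample LSTM states into a batch; the (h, c) tuple layout and shapes are assumptions taken from the docstring, not NeMo's exact API:

    import torch

    L, B, H = 2, 4, 8  # layers, batch, hidden size (illustrative)

    # B single-sample states, each a (h, c) pair shaped (L, 1, H)
    decoder_states = [(torch.zeros(L, 1, H), torch.zeros(L, 1, H)) for _ in range(B)]

    # Concatenate along the batch dimension to form ([L x (B, H)], [L x (B, H)])
    h = torch.cat([s[0] for s in decoder_states], dim=1)  # (L, B, H)
    c = torch.cat([s[1] for s in decoder_states], dim=1)  # (L, B, H)
    batch_states = (h, c)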
@@ -1109,7 +1110,9 @@ def batch_replace_states_mask(

     @classmethod
     def batch_replace_states_all(
-        cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor],
+        cls,
+        src_states: Tuple[torch.Tensor, torch.Tensor],
+        dst_states: Tuple[torch.Tensor, torch.Tensor],
     ):
         """Replace states in dst_states with states from src_states"""
         dst_states[0].copy_(src_states[0])
@@ -1249,12 +1252,15 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin):

         fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the
             sub-batches. Should be any value below the actual batch size per GPU.
+        masking_prob: Optional float, the probability of masking out the decoder output in the HAINAN
+            (Hybrid Autoregressive Inference Transducer) model, described in https://arxiv.org/pdf/2410.02597
+            Defaults to -1.0, which runs the standard Joint network computation; if > 0, the decoder output
+            is masked out with the specified probability.
     """

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
             "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
@@ -1266,8 +1272,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         if not self._fuse_loss_wer:
             return {
                 "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
@@ -1313,6 +1318,7 @@ def __init__(

         fuse_loss_wer: bool = False,
         fused_batch_size: Optional[int] = None,
         experimental_fuse_loss_wer: Any = None,
+        masking_prob: float = -1.0,
     ):
         super().__init__()
@@ -1322,6 +1328,10 @@ def __init__(

         self._num_extra_outputs = num_extra_outputs
         self._num_classes = num_classes + 1 + num_extra_outputs  # 1 is for blank

+        self.masking_prob = masking_prob
+        if self.masking_prob > 0.0:
+            assert self.masking_prob < 1.0, "masking_prob must be between 0 and 1"
+
         if experimental_fuse_loss_wer is not None:
             # Override fuse_loss_wer from deprecated argument
             fuse_loss_wer = experimental_fuse_loss_wer
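For context, a usage sketch of the new argument; every constructor key other than `masking_prob` is an assumption based on typical NeMo RNNT configs, and 0.5 is only an illustrative value:

    # Hypothetical instantiation; only `masking_prob` is introduced by this commit.
    joint = RNNTJoint(
        jointnet={'encoder_hidden': 512, 'pred_hidden': 512, 'joint_hidden': 512, 'activation': 'relu'},
        num_classes=1024,
        masking_prob=0.5,  # must lie strictly between 0 and 1 to take effect
    )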
@@ -1578,6 +1588,13 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:

         """
         f = f.unsqueeze(dim=2)  # (B, T, 1, H)
         g = g.unsqueeze(dim=1)  # (B, 1, U, H)
+
+        if self.training and self.masking_prob > 0:
+            [B, _, U, _] = g.shape
+            rand = torch.rand([B, 1, U, 1]).to(g.device)
+            rand = torch.gt(rand, self.masking_prob)
+            g = g * rand
+
         inp = f + g  # [B, T, U, H]

         del f, g
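This hunk is the core of the change: during training, each label position of the decoder output g is zeroed with probability `masking_prob` before being combined with the encoder output, so the joint also learns to operate without autoregressive context, which is the hybrid-autoregressive inference (HAINAN) idea. A minimal standalone sketch of the same masking step (illustrative shapes; not the NeMo module itself):

    import torch

    def mask_decoder_output(g: torch.Tensor, masking_prob: float, training: bool = True) -> torch.Tensor:
        """Zero each (batch, label-position) slice of g with probability masking_prob."""
        if not training or masking_prob <= 0:
            return g
        B, _, U, _ = g.shape  # g: (B, 1, U, H), as in joint_after_projection
        keep = torch.rand([B, 1, U, 1], device=g.device) > masking_prob  # True = keep
        return g * keep  # boolean mask broadcasts over H

    # Example: batch of 2, U = 5 label positions, hidden size 8
    g = torch.randn(2, 1, 5, 8)
    g_masked = mask_decoder_output(g, masking_prob=0.5)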
@@ -2047,7 +2064,11 @@ def forward(

         return losses, wer, wer_num, wer_denom

     def sampled_joint(
-        self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor,
+        self,
+        f: torch.Tensor,
+        g: torch.Tensor,
+        transcript: torch.Tensor,
+        transcript_lengths: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the sampled joint step of the network.
