diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 5af83a53dfc4..a39f4f5746c3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -35,6 +35,7 @@ from nemo.collections.asr.modules import rnnt_abstract from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer +from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.collections.common.parts.rnn import label_collate @@ -2638,8 +2639,20 @@ def __init__( # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique + self._decoding_computer = None if self.decoder.blank_as_pad: - self._greedy_decode = self._greedy_decode_blank_as_pad + # batched "loop frames" is not implemented for TDT + self._decoding_computer = GreedyBatchedTDTLoopLabelsComputer( + decoder=self.decoder, + joint=self.joint, + blank_index=self._blank_index, + durations=self.durations, + max_symbols_per_step=self.max_symbols, + preserve_alignments=preserve_alignments, + preserve_frame_confidence=preserve_frame_confidence, + confidence_method_cfg=confidence_method_cfg, + ) + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels else: self._greedy_decode = self._greedy_decode_masked @@ -2685,179 +2698,33 @@ def forward( return (packed_result,) - def _greedy_decode_blank_as_pad( + def _greedy_decode_masked( self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device, partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, ): - if partial_hypotheses is not None: - raise NotImplementedError("`partial_hypotheses` support is not supported") - - with torch.inference_mode(): - # x: [B, T, D] - # out_len: [B] - # device: torch.device - - # Initialize list of Hypothesis - batchsize = x.shape[0] - hypotheses = [ - rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize) - ] - - # Initialize Hidden state matrix (shared by entire batch) - hidden = None - - # If alignments need to be preserved, register a danling list to hold the values - if self.preserve_alignments: - # alignments is a 3-dimensional dangling list representing B x T x U - for hyp in hypotheses: - hyp.alignments = [[]] - - # If confidence scores need to be preserved, register a danling list to hold the values - if self.preserve_frame_confidence: - # frame_confidence is a 3-dimensional dangling list representing B x T x U - for hyp in hypotheses: - hyp.frame_confidence = [[]] - - # Last Label buffer + Last Label without blank buffer - # batch level equivalent of the last_label - last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) - - # Mask buffers - blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) - - # Get max sequence length - max_out_len = out_len.max() - - # skip means the number of frames the next decoding step should "jump" to. When skip == 1 - # it means the next decoding step will just use the next input frame. 
- skip = 1 - for time_idx in range(max_out_len): - if skip > 1: # if skip > 1 at the current step, we decrement it and skip the current frame. - skip -= 1 - continue - f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] - - # need_to_stay is a boolean indicates whether the next decoding step should remain in the same frame. - need_to_stay = True - symbols_added = 0 - - # Reset blank mask - blank_mask.mul_(False) - - # Update blank mask with time mask - # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) - # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len - blank_mask = time_idx >= out_len - - # Start inner loop - while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols): - # Batch prediction and joint network steps - # If very first prediction step, submit SOS tag (blank) to pred_step. - # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state - if time_idx == 0 and symbols_added == 0 and hidden is None: - g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) - else: - # Perform batch step prediction of decoder, getting new states and scores ("g") - g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize) - - # Batched joint step - Output = [B, V + 1 + num-big-blanks] - # Note: log_normalize must not be True here since the joiner output is contanetation of both token logits and duration logits, - # and they need to be normalized independently. - joined = self._joint_step(f, g, log_normalize=None) - logp = joined[:, 0, 0, : -len(self.durations)] - duration_logp = joined[:, 0, 0, -len(self.durations) :] - - if logp.dtype != torch.float32: - logp = logp.float() - duration_logp = duration_logp.float() - - # get the max for both token and duration predictions. - v, k = logp.max(1) - dv, dk = duration_logp.max(1) - - # here we set the skip value to be the minimum of all predicted durations, hense the "torch.min(dk)" call there. - # Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for explanation of this. - skip = self.durations[int(torch.min(dk))] - - # this is a special case: if all batches emit blanks, we require that skip be at least 1 - # so we don't loop forever at the current frame. 
- if blank_mask.all(): - if skip == 0: - skip = 1 - - need_to_stay = skip == 0 - del g - - # Update blank mask with current predicted blanks - # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) - k_is_blank = k == self._blank_index - blank_mask.bitwise_or_(k_is_blank) - - del k_is_blank - del logp, duration_logp - - # If all samples predict / have predicted prior blanks, exit loop early - # This is equivalent to if single sample predicted k - if not blank_mask.all(): - # Collect batch indices where blanks occurred now/past - blank_indices = (blank_mask == 1).nonzero(as_tuple=False) - - # Recover prior state for all samples which predicted blank now/past - if hidden is not None: - hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) - - elif len(blank_indices) > 0 and hidden is None: - # Reset state if there were some blank and other non-blank predictions in batch - # Original state is filled with zeros so we just multiply - # LSTM has 2 states - hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0) - - # Recover prior predicted label for all samples which predicted blank now/past - k[blank_indices] = last_label[blank_indices, 0] - - # Update new label and hidden state for next iteration - last_label = k.clone().view(-1, 1) - hidden = hidden_prime - - # Update predicted labels, accounting for time mask - # If blank was predicted even once, now or in the past, - # Force the current predicted label to also be blank - # This ensures that blanks propogate across all timesteps - # once they have occured (normally stopping condition of sample level loop). - for kidx, ki in enumerate(k): - if blank_mask[kidx] == 0: - hypotheses[kidx].y_sequence.append(ki) - hypotheses[kidx].timestep.append(time_idx) - hypotheses[kidx].score += float(v[kidx]) - - symbols_added += 1 - - # Remove trailing empty list of alignments at T_{am-len} x Uj - if self.preserve_alignments: - for batch_idx in range(batchsize): - if len(hypotheses[batch_idx].alignments[-1]) == 0: - del hypotheses[batch_idx].alignments[-1] - - # Remove trailing empty list of confidence scores at T_{am-len} x Uj - if self.preserve_frame_confidence: - for batch_idx in range(batchsize): - if len(hypotheses[batch_idx].frame_confidence[-1]) == 0: - del hypotheses[batch_idx].frame_confidence[-1] - - # Preserve states - for batch_idx in range(batchsize): - hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx) - - return hypotheses + raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.") - def _greedy_decode_masked( + @torch.inference_mode() + def _greedy_decode_blank_as_pad_loop_labels( self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device, - partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, - ): - raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.") + partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None, + ) -> list[rnnt_utils.Hypothesis]: + """ + Optimized batched greedy decoding. 
+        The main idea: search for the next labels for the whole batch (evaluating the Joint)
+        and thus always evaluate the prediction network with the maximum possible batch size.
+        """
+        if partial_hypotheses is not None:
+            raise NotImplementedError("`partial_hypotheses` support is not implemented")
+
+        batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
+        hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
+        for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
+            hyp.dec_state = state
+        return hyps
diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
new file mode 100644
index 000000000000..ce34d8362171
--- /dev/null
+++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, ListConfig
+
+from nemo.collections.asr.parts.utils import rnnt_utils
+from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
+
+
+class GreedyBatchedTDTLoopLabelsComputer(ConfidenceMethodMixin):
+    """
+    Loop Labels algorithm implementation. Callable.
+    """
+
+    def __init__(
+        self,
+        decoder,
+        joint,
+        blank_index: int,
+        durations: Union[list[int], ListConfig],
+        max_symbols_per_step: Optional[int] = None,
+        preserve_alignments=False,
+        preserve_frame_confidence=False,
+        confidence_method_cfg: Optional[DictConfig] = None,
+    ):
+        """
+        Init method.
+        Args:
+            decoder: Prediction network from RNN-T
+            joint: Joint module from RNN-T
+            blank_index: index of blank symbol
+            durations: list of TDT durations, e.g., [0, 1, 2, 4, 8]
+            max_symbols_per_step: max symbols to emit on each step (to avoid infinite looping)
+            preserve_alignments: if alignments are needed
+            preserve_frame_confidence: if frame confidence is needed
+            confidence_method_cfg: config for the confidence
+        """
+        super().__init__()
+        self.decoder = decoder
+        self.joint = joint
+        # keep durations on CPU to avoid side effects in a multi-GPU environment
+        self.durations = torch.tensor(list(durations), device="cpu").to(torch.long)
+        self._blank_index = blank_index
+        self.max_symbols = max_symbols_per_step
+        self.preserve_alignments = preserve_alignments
+        self.preserve_frame_confidence = preserve_frame_confidence
+        self._SOS = self._blank_index
+        self._init_confidence_method(confidence_method_cfg=confidence_method_cfg)
+        assert self._SOS == self._blank_index  # "blank as pad" algorithm only
+
+    def __call__(
+        self, x: torch.Tensor, out_len: torch.Tensor,
+    ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]:
+        """
+        Optimized batched greedy decoding.
+        Iterates over labels, finding the next non-blank label on each step
+        (evaluating the Joint multiple times in an inner loop); it uses the minimum possible
+        number of calls to the prediction network (with the maximum possible batch size),
+        which makes it especially useful for scaling the prediction network.
+        During decoding, all active hypotheses ("texts") have the same length.
+
+        Args:
+            x: output from the encoder
+            out_len: lengths of the utterances in `x`
+        """
+        batch_size, max_time, _unused = x.shape
+        device = x.device
+
+        x = self.joint.project_encoder(x)  # do not recalculate joint projection, project only once
+
+        # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state
+        # init empty batched hypotheses
+        batched_hyps = rnnt_utils.BatchedHyps(
+            batch_size=batch_size,
+            init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time,
+            device=x.device,
+            float_dtype=x.dtype,
+        )
+        # placeholder decoder state; will be replaced once decoding for a hypothesis is done
+        last_decoder_state = self.decoder.initialize_state(x)
+        # init alignments if necessary
+        use_alignments = self.preserve_alignments or self.preserve_frame_confidence
+        # always use the alignments variable - for torch.jit adaptation, but keep it as minimal as possible
+        alignments = rnnt_utils.BatchedAlignments(
+            batch_size=batch_size,
+            logits_dim=self.joint.num_classes_with_blank,
+            init_length=max_time * 2 if use_alignments else 1,  # blank for each timestep + text tokens
+            device=x.device,
+            float_dtype=x.dtype,
+            store_alignments=self.preserve_alignments,
+            store_frame_confidence=self.preserve_frame_confidence,
+        )
+
+        # durations
+        all_durations = self.durations.to(device, non_blocking=True)
+        num_durations = all_durations.shape[0]
+
+        # initial state, needed for torch.jit to compile (cannot handle None)
+        state = self.decoder.initialize_state(x)
+        # indices of elements in batch (constant)
+        batch_indices = torch.arange(batch_size, dtype=torch.long, device=device)
+        # last found labels - initially <SOS> (blank) symbol
+        labels = torch.full_like(batch_indices, fill_value=self._SOS)
+
+        # time indices
+        time_indices = torch.zeros_like(batch_indices)
+        safe_time_indices = torch.zeros_like(time_indices)  # time indices, guaranteed to be < out_len
+        time_indices_current_labels = torch.zeros_like(time_indices)
+        last_timesteps = out_len - 1
+
+        # masks for utterances in batch
+        active_mask: torch.Tensor = out_len > 0
+        advance_mask = torch.empty_like(active_mask)
+
+        # for storing the last state we need to know which elements became "inactive" on this step
+        active_mask_prev = torch.empty_like(active_mask)
+        became_inactive_mask = torch.empty_like(active_mask)
+
+        # loop while there are active utterances
+        first_step = True
+        while active_mask.any():
+            active_mask_prev.copy_(active_mask, non_blocking=True)
+            # stage 1: get decoder (prediction network) output
+            if first_step:
+                # start of the loop: the SOS symbol is passed into the prediction network, state is None
+                # we need to separate this for torch.jit
+                decoder_output, state, *_ = self.decoder.predict(
+                    labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size
+                )
+                first_step = False
+            else:
+                decoder_output, state, *_ = self.decoder.predict(
+                    labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size
+                )
+            decoder_output = self.joint.project_prednet(decoder_output)  # do not recalculate joint projection
+
+            # stage 2: get joint output, iteratively searching for non-blank labels
+            # a blank label in the `labels` tensor means "end of hypothesis" (for this index)
+            logits = (
+                self.joint.joint_after_projection(x[batch_indices, safe_time_indices].unsqueeze(1), decoder_output,)
+                .squeeze(1)
+                .squeeze(1)
+            )
+            scores, labels = logits[:, :-num_durations].max(dim=-1)
+            jump_durations_indices = logits[:, -num_durations:].argmax(dim=-1)
+            durations = all_durations[jump_durations_indices]
+
+            # search for non-blank labels using joint, advancing time indices for blank labels
+            # checking max_symbols is not needed, since we already forced advancing time indices for such cases
+            blank_mask = labels == self._blank_index
+            # for blank labels force duration >= 1
+            durations.masked_fill_(torch.logical_and(durations == 0, blank_mask), 1)
+            time_indices_current_labels.copy_(time_indices, non_blocking=True)
+            if use_alignments:
+                alignments.add_results_masked_(
+                    active_mask=active_mask,
+                    time_indices=time_indices_current_labels,
+                    logits=logits if self.preserve_alignments else None,
+                    labels=labels if self.preserve_alignments else None,
+                    confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1))
+                    if self.preserve_frame_confidence
+                    else None,
+                )
+
+            # advance_mask is a mask over the current batch used when searching for non-blank labels;
+            # an element is True if a non-blank symbol has not yet been found AND the time index can be increased
+            time_indices += durations
+            torch.minimum(time_indices, last_timesteps, out=safe_time_indices)
+            torch.less(time_indices, out_len, out=active_mask)
+            torch.logical_and(active_mask, blank_mask, out=advance_mask)
+            # inner loop: find the next non-blank labels (if they exist)
+            while advance_mask.any():
+                # same as: time_indices_current_labels[advance_mask] = time_indices[advance_mask], but non-blocking
+                # store the current time indices to use them later when storing the results
+                torch.where(advance_mask, time_indices, time_indices_current_labels, out=time_indices_current_labels)
+                logits = (
+                    self.joint.joint_after_projection(
+                        x[batch_indices, safe_time_indices].unsqueeze(1), decoder_output,
+                    )
+                    .squeeze(1)
+                    .squeeze(1)
+                )
+                # get labels (greedy) and scores from the current logits, replace labels/scores with the new ones;
+                # labels[advance_mask] are blank, and we are looking for non-blank labels
+                more_scores, more_labels = logits[:, :-num_durations].max(dim=-1)
+                # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking
+                torch.where(advance_mask, more_labels, labels, out=labels)
+                # same as: scores[advance_mask] = more_scores[advance_mask], but non-blocking
+                torch.where(advance_mask, more_scores, scores, out=scores)
+                jump_durations_indices = logits[:, -num_durations:].argmax(dim=-1)
+                durations = all_durations[jump_durations_indices]
+
+                if use_alignments:
+                    alignments.add_results_masked_(
+                        active_mask=advance_mask,
+                        time_indices=time_indices_current_labels,
+                        logits=logits if self.preserve_alignments else None,
+                        labels=more_labels if self.preserve_alignments else None,
+                        confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1))
+                        if self.preserve_frame_confidence
+                        else None,
+                    )
+
+                blank_mask = labels == self._blank_index
+                # for blank labels force duration >= 1
+                durations.masked_fill_(torch.logical_and(durations == 0, blank_mask), 1)
+                # same as: time_indices[advance_mask] += durations[advance_mask], but non-blocking
+                torch.where(advance_mask, time_indices + durations, time_indices, out=time_indices)
+                torch.minimum(time_indices, last_timesteps, out=safe_time_indices)
+                torch.less(time_indices, out_len, out=active_mask)
+                torch.logical_and(active_mask, blank_mask, out=advance_mask)
+
+            # stage 3: filter labels and state, store hypotheses
+            # select states for hyps that became inactive on this step;
+            # this seems to be redundant, but it matches the `loop_frames` output
+            torch.ne(active_mask, active_mask_prev, out=became_inactive_mask)
+            self.decoder.batch_replace_states_mask(
+                src_states=state, dst_states=last_decoder_state, mask=became_inactive_mask,
+            )
+
+            # store hypotheses
+            if self.max_symbols is not None:
+                # pre-allocated memory, no need for checks
+                batched_hyps.add_results_masked_no_checks_(
+                    active_mask, labels, time_indices_current_labels, scores,
+                )
+            else:
+                # auto-adjusted storage
+                batched_hyps.add_results_masked_(
+                    active_mask, labels, time_indices_current_labels, scores,
+                )
+
+            # stage 4: to avoid infinite looping, go to the next frame after max_symbols emissions
+            if self.max_symbols is not None:
+                # for non-blank labels (not end-of-utterance), check the last timestep at which a label was emitted:
+                # if it equals the current time index and the number of emissions there is >= max_symbols, force blank
+                force_blank_mask = torch.logical_and(
+                    active_mask,
+                    torch.logical_and(
+                        torch.logical_and(
+                            labels != self._blank_index, batched_hyps.last_timestep_lasts >= self.max_symbols,
+                        ),
+                        batched_hyps.last_timestep == time_indices,
+                    ),
+                )
+                time_indices += force_blank_mask  # emit blank => advance time indices
+                # update safe_time_indices, non-blocking
+                torch.minimum(time_indices, last_timesteps, out=safe_time_indices)
+                # same as: active_mask = time_indices < out_len
+                torch.less(time_indices, out_len, out=active_mask)
+        if use_alignments:
+            return batched_hyps, alignments, last_decoder_state
+        return batched_hyps, None, last_decoder_state
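
Reviewer note: below is a minimal, self-contained sketch of the TDT duration-jump idea that GreedyBatchedTDTLoopLabelsComputer vectorizes, stripped of decoder states, alignments, and confidence handling. It is a toy under stated assumptions, not the NeMo API: toy_joint, the shapes, and the constants are made up; it returns random logits whose last len(durations) entries are duration logits, mirroring the joint output layout in the patch, and the per-frame emission cap mirrors stage 4.

import torch

torch.manual_seed(0)

B, V = 2, 5                                  # batch size, vocab size incl. blank (assumed)
blank_id = V - 1                             # assume blank is the last token index
durations = torch.tensor([0, 1, 2, 4])       # TDT durations (model-config dependent)
max_symbols = 3                              # cap on emissions per frame
out_len = torch.tensor([8, 6])               # encoder output lengths per utterance


def toy_joint(time_indices: torch.Tensor) -> torch.Tensor:
    # stand-in for joint(encoder_frame, decoder_output): [B, V + num_durations] logits,
    # token logits first, duration logits last (as in the computer above)
    return torch.randn(time_indices.shape[0], V + len(durations))


time_indices = torch.zeros(B, dtype=torch.long)
active = time_indices < out_len
emitted = torch.zeros(B, dtype=torch.long)   # symbols emitted on the current frame
hyps = [[] for _ in range(B)]

while active.any():
    logits = toy_joint(time_indices)
    labels = logits[:, : -len(durations)].argmax(dim=-1)           # token argmax
    jumps = durations[logits[:, -len(durations):].argmax(dim=-1)]  # duration argmax
    blank = labels == blank_id
    # a blank with duration 0 would re-read the same frame forever: force duration >= 1
    jumps = torch.where(blank & (jumps == 0), torch.ones_like(jumps), jumps)
    for b in range(B):                       # store non-blank labels (vectorized in the real code)
        if active[b] and not blank[b]:
            hyps[b].append((int(labels[b]), int(time_indices[b])))
    # after max_symbols emissions on one frame, force a jump (cf. stage 4)
    stay = jumps == 0
    emitted = torch.where(stay, emitted + 1, torch.zeros_like(emitted))
    jumps = torch.where(stay & (emitted >= max_symbols), torch.ones_like(jumps), jumps)
    time_indices = time_indices + jumps * active  # advance only active utterances
    active = time_indices < out_len

print(hyps)  # per-utterance list of (label, frame) pairs

The property the sketch illustrates: blanks always advance time by at least one frame and non-blank labels can consume several frames at once via predicted durations, so whole stretches of input are skipped without evaluating the joint on them, while the max_symbols cap guarantees termination when a non-blank label with duration 0 keeps the loop on the same frame.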