Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"RuntimeError: start (4) + length (1) exceeds dimension size (4)." when running cache aware streaming inference #9190

Open
lucgeo opened this issue May 14, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@lucgeo
Copy link

lucgeo commented May 14, 2024

Hello,

I'm currently working on developing a websocket-based client-server application for live transcription. I've successfully created a client that reads an audio file from disk and sends audio chunks through the websocket, effectively simulating a live stream. On the server side, I receive each chunk and attempt to transcribe it using the logic provided in this example [1].

For transcription, I'm utilizing the "stt_en_fastconformer_hybrid_large_streaming_80ms" model. The audio file I'm using for testing purposes is encoded as 16-bit Signed Integer PCM at a sampling rate of 16 kHz, and it's in mono format. My NeMo version is v1.23.0.

However, I get this server side error (detailed below): RuntimeError: start (4) + length (1) exceeds dimension size (4).

Here is my server side code:

import asyncio
import websockets
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F

# Load your model and any other dependencies here
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf, open_dict
from copy import deepcopy

# Model handle and feature preprocessor; both start out unset.
asr_model = preprocessor = None

# Streaming state read and updated by transcribe_chunk() on every call.
cache_last_channel = cache_last_time = cache_last_channel_len = None
previous_hypotheses = pred_out_stream = None
cache_pre_encode = None

# Counters / sizes start at zero.
step_num = 0
pre_encode_cache_size = 0



def extract_transcriptions(hyps):
    """Return the text of each hypothesis in *hyps*.

    The transcriptions returned by CTC and RNNT models are different: some
    are plain strings, others are hypothesis objects carrying the string in a
    ``text`` attribute. This helper normalises both forms.

    The original code tested ``isinstance(hyps[0], Hypothesis)``, but no name
    ``Hypothesis`` is imported in this file, so every call raised
    ``NameError``. The check is done on the ``text`` attribute instead, which
    needs no extra import.

    Args:
        hyps: Sequence of transcriptions; either plain values or objects
            exposing a ``text`` attribute. As before, only the first element
            is inspected to decide which form the sequence is in.

    Returns:
        A list of ``hyp.text`` values when the first element has a ``text``
        attribute; otherwise *hyps* itself, unchanged (including when it is
        empty).
    """
    if hyps and hasattr(hyps[0], "text"):
        return [hyp.text for hyp in hyps]
    return hyps



def init_preprocessor(asr_model):
    """Create a preprocessor from the model's config, adjusted for streaming.

    The model's config is deep-copied so the overrides below do not touch
    the model's own configuration.

    Args:
        asr_model: Loaded ASR model; its ``_cfg`` and ``device`` are read.

    Returns:
        The preprocessor built from the modified config, moved to the
        model's device.
    """
    streaming_cfg = copy.deepcopy(asr_model._cfg)
    # Lift struct mode on the preprocessor section before overriding it.
    OmegaConf.set_struct(streaming_cfg.preprocessor, False)

    # some changes for streaming scenario
    streaming_overrides = {"dither": 0.0, "pad_to": 0, "normalize": "None"}
    for option, value in streaming_overrides.items():
        setattr(streaming_cfg.preprocessor, option, value)

    feature_extractor = EncDecCTCModelBPE.from_config_dict(streaming_cfg.preprocessor)
    feature_extractor.to(asr_model.device)
    return feature_extractor




def preprocess_audio(audio, asr_model):
    """Run the module-level preprocessor on one chunk of audio samples.

    Args:
        audio: NumPy array of samples; ``audio.shape[0]`` is passed to the
            preprocessor as the signal length.
        asr_model: Loaded ASR model; only its ``device`` is read here.

    Returns:
        The ``(processed_signal, processed_signal_length)`` pair returned by
        the preprocessor.
    """
    target_device = asr_model.device

    # Add a leading axis of size 1 and move the samples to the model's device.
    waveform = torch.from_numpy(audio).unsqueeze(0).to(target_device)
    waveform_len = torch.Tensor([audio.shape[0]]).to(target_device)

    features, features_len = preprocessor(input_signal=waveform, length=waveform_len)

    # Debug output: shape of the processed signal.
    print("Shape of processed_signal:", features.shape)
    return features, features_len





def transcribe_chunk(new_chunk):
    """Run one streaming step on a chunk of audio and return its transcription.

    Reads and updates the module-level streaming state (encoder caches,
    previous hypotheses, previous predictions, the pre-encode cache and the
    step counter) so that consecutive calls continue the same stream.

    Args:
        new_chunk: NumPy array of ``np.int16`` samples for this chunk.

    Returns:
        The first element of the transcriptions produced by the model for
        this step.
    """
    global cache_last_channel, cache_last_time, cache_last_channel_len
    global previous_hypotheses, pred_out_stream, step_num
    global cache_pre_encode

    # new_chunk is provided as np.int16, so we convert it to np.float32
    # as that is what our ASR models expect
    audio_data = new_chunk.astype(np.float32)
    audio_data = audio_data / 32768.0

    # get mel-spectrogram signal & length
    processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)

    # Prepend the cached frames along the last axis. The reported length must
    # grow by the number of frames actually prepended, i.e. the cache's size
    # along that same last axis. cache_pre_encode is created with shape
    # (1, features, pre_encode_cache_size), so shape[1] (used previously) is
    # the feature count and over-reported the length.
    processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)
    processed_signal_length += cache_pre_encode.shape[-1]

    # save cache for next time
    cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]

    with torch.no_grad():
        (
            pred_out_stream,
            transcribed_texts,
            cache_last_channel,
            cache_last_time,
            cache_last_channel_len,
            previous_hypotheses,
        ) = asr_model.conformer_stream_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=False,
            previous_hypotheses=previous_hypotheses,
            previous_pred_out=pred_out_stream,
            drop_extra_pre_encoded=None,
            return_transcription=True,
        )

    final_streaming_tran = extract_transcriptions(transcribed_texts)
    step_num += 1

    # Debug output: shape of the signal that was fed to the model.
    print("Shape of x before narrow:", processed_signal.shape)

    return final_streaming_tran[0]





async def audio_consumer(websocket, path):
    """Handle one websocket connection: transcribe each received chunk.

    Every received message is interpreted as a buffer of ``np.int16``
    samples, transcribed, and answered with a JSON object holding the result
    under the ``"transcription"`` key. The loop ends quietly when the
    connection is closed.

    Args:
        websocket: The connection to receive chunks from and reply on.
        path: Unused by this handler.
    """
    try:
        while True:
            payload = await websocket.recv()
            samples = np.frombuffer(payload, dtype=np.int16)
            reply = json.dumps({"transcription": transcribe_chunk(samples)})
            await websocket.send(reply)
    except websockets.exceptions.ConnectionClosed:
        # The peer disconnected; nothing more to do for this connection.
        return



async def start_server():
    """Load and configure the streaming ASR model, then serve websocket clients.

    Populates the module-level model, preprocessor and streaming-cache
    variables that transcribe_chunk() reads, then runs a websocket server on
    localhost:8765 until the process is stopped.
    """
    global asr_model, preprocessor
    # The initial encoder caches must land in the module-level variables.
    # Without this declaration the assignment below created function locals
    # and transcribe_chunk() kept seeing the initial None values.
    global cache_last_channel, cache_last_time, cache_last_channel_len
    global pre_encode_cache_size, cache_pre_encode

    # Load your ASR model
    model_path = "models/stt_en_fastconformer_hybrid_large_streaming_80ms.nemo"
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)

    lookahead_size = 80
    decoder_type = "ctc"

    # specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)
    ENCODER_STEP_LENGTH = 80  # ms

    # update att_context_size
    left_context_size = asr_model.encoder.att_context_size[0]
    asr_model.encoder.set_default_att_context_size([left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)])

    asr_model.encoder.setup_streaming_params()

    # make sure we use the specified decoder_type
    asr_model.change_decoding_strategy(decoder_type=decoder_type)

    # make sure the model's decoding strategy is optimal
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        # save time by doing greedy decoding and not trying to record the alignments
        decoding_cfg.strategy = "greedy"
        decoding_cfg.preserve_alignments = False
        if hasattr(asr_model, 'joint'):  # if an RNNT model
            # restrict max_symbols to make sure not stuck in infinite loop
            decoding_cfg.greedy.max_symbols = 10
            # sensible default parameter, but not necessary since batch size is 1
            decoding_cfg.fused_batch_size = -1
        # NOTE(review): this call does not pass decoder_type, and the reported
        # traceback goes through RNNT decoding even though decoder_type is
        # "ctc" above - confirm whether this call resets the active decoder.
        asr_model.change_decoding_strategy(decoding_cfg)

    # set model to eval mode
    asr_model.eval()

    # get parameters to use as the initial cache state
    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
        batch_size=1
    )

    preprocessor = init_preprocessor(asr_model)

    # Initialize the pre-encode cache with zeros, shaped
    # (1, features, pre_encode_cache_size).
    pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    cache_pre_encode = torch.zeros((1, asr_model.cfg.preprocessor.features, pre_encode_cache_size), device=asr_model.device)

    async with websockets.serve(audio_consumer, "localhost", 8765):
        await asyncio.Future()  # Run server forever

# Start the server only when this file is executed as a script, so that
# merely importing the module does not load the model or open a socket.
if __name__ == "__main__":
    asyncio.run(start_server())

And here is my client side code:

import asyncio
import websockets
import json
import numpy as np
import pyaudio
import struct

# Audio format and chunking parameters for the simulated live stream.
SAMPLE_RATE = 16000  # samples per second
ENCODER_STEP_LENGTH = 80  # ms
lookahead_size = 80  # ms

# One chunk spans the encoder step plus the lookahead.
chunk_size_ms = ENCODER_STEP_LENGTH + lookahead_size
# Samples per chunk (one less than the exact count for chunk_size_ms).
chunk_size_samples = int(chunk_size_ms * SAMPLE_RATE / 1000) - 1

async def send_audio_stream(file_path, websocket):
    """Stream a file to the server chunk by chunk and print each reply.

    After every chunk is sent, the function waits for the server's JSON
    reply and prints its ``'transcription'`` field before sending the next.

    Args:
        file_path: Path of the file to read in binary mode.
        websocket: Open connection used to send chunks and receive replies.
    """
    # 2 bytes per sample for 16-bit audio
    bytes_per_chunk = chunk_size_samples * 2

    # NOTE(review): the file is read as raw bytes; if it is a WAV container,
    # its header bytes are sent as audio samples too - confirm this is intended.
    with open(file_path, 'rb') as audio_file:
        while audio_chunk := audio_file.read(bytes_per_chunk):
            await websocket.send(audio_chunk)
            reply = json.loads(await websocket.recv())
            print("Transcription:", reply['transcription'])



async def main():
    """Connect to the local websocket server and stream the test audio file."""
    server_uri = "ws://localhost:8765"  # WebSocket server address
    source_path = "test-audio/my_audio_file.wav"  # audio file to stream

    async with websockets.connect(server_uri) as connection:
        await send_audio_stream(source_path, connection)

# Run the client only when this file is executed as a script, so that
# importing the module does not open a connection.
if __name__ == "__main__":
    asyncio.run(main())

Here is the error I get on server side:

connection handler failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 236, in handler
    await self.ws_handler(self)
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 1175, in _ws_handler
    return await cast(
  File "/home/apps/ASR/server_streaming.py", line 142, in audio_consumer
    transcription = transcribe_chunk(audio_chunk)
  File "/home/apps/ASR/server_streaming.py", line 112, in transcribe_chunk
    ) = asr_model.conformer_stream_step(
  File "/home/apps/ASR/nemo/collections/asr/parts/mixins/mixins.py", line 676, in conformer_stream_step
    best_hyp, all_hyp_or_transcribed_texts = self.decoding.rnnt_decoder_predictions_tensor(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_decoding.py", line 455, in rnnt_decoder_predictions_tensor
    hypotheses_list = self.decoding(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 180, in __call__
    return self.forward(*args, **kwargs)
  File "/home/apps/ASR/nemo/core/classes/common.py", line 1098, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 388, in forward
    hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 431, in _greedy_decode
    f = x.narrow(dim=0, start=time_idx, length=1)
RuntimeError: start (4) + length (1) exceeds dimension size (4).

Any ideas please? Thank you!

@lucgeo lucgeo added the bug Something isn't working label May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant