I'm currently working on developing a websocket-based client-server application for live transcription. I've successfully created a client that reads an audio file from disk and sends audio chunks through the websocket, effectively simulating a live stream. On the server side, I receive each chunk and attempt to transcribe it using the logic provided in this example [1].
For transcription, I'm utilizing the "stt_en_fastconformer_hybrid_large_streaming_80ms" model. The audio file I'm using for testing purposes is encoded as 16-bit Signed Integer PCM at a sampling rate of 16 kHz, and it's in mono format. My NeMo version is v1.23.0.
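For reference, the audio format described above can be sanity-checked with Python's standard `wave` module. This is a minimal sketch (the helper name and the idea of checking the file up front are mine, not part of the original setup):

```python
import wave

def check_wav(path):
    """Confirm a WAV file is 16-bit signed PCM, 16 kHz, mono; return its frame count."""
    with wave.open(path, "rb") as wf:
        assert wf.getsampwidth() == 2, "expected 16-bit samples"
        assert wf.getframerate() == 16000, "expected 16 kHz sampling rate"
        assert wf.getnchannels() == 1, "expected mono audio"
        return wf.getnframes()
```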
However, I get this server-side error (detailed below): RuntimeError: start (4) + length (1) exceeds dimension size (4).
Here is my server side code:
```python
import asyncio
import websockets
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F

# Load your model and any other dependencies here
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis  # needed by extract_transcriptions
from omegaconf import OmegaConf, open_dict
from copy import deepcopy

# Load your ASR model
asr_model = None  # Initialize to None
preprocessor = None

# Define global variables for caching
cache_last_channel = None
cache_last_time = None
cache_last_channel_len = None
previous_hypotheses = None
pred_out_stream = None
step_num = 0
pre_encode_cache_size = 0
cache_pre_encode = None


def extract_transcriptions(hyps):
    """
    The transcribed_texts returned by CTC and RNNT models are different.
    This method would extract and return the text section of the hypothesis.
    """
    if isinstance(hyps[0], Hypothesis):
        transcriptions = []
        for hyp in hyps:
            transcriptions.append(hyp.text)
    else:
        transcriptions = hyps
    return transcriptions


def init_preprocessor(asr_model):
    cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)

    # some changes for streaming scenario
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0
    cfg.preprocessor.normalize = "None"

    preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
    preprocessor.to(asr_model.device)

    return preprocessor


def preprocess_audio(audio, asr_model):
    global preprocessor
    device = asr_model.device

    # doing audio preprocessing
    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)
    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
    processed_signal, processed_signal_length = preprocessor(
        input_signal=audio_signal, length=audio_signal_len
    )

    # Print shape of processed_signal
    print("Shape of processed_signal:", processed_signal.shape)

    return processed_signal, processed_signal_length


def transcribe_chunk(new_chunk):
    global cache_last_channel, cache_last_time, cache_last_channel_len
    global previous_hypotheses, pred_out_stream, step_num
    global cache_pre_encode

    # new_chunk is provided as np.int16, so we convert it to np.float32
    # as that is what our ASR models expect
    audio_data = new_chunk.astype(np.float32)
    audio_data = audio_data / 32768.0

    # get mel-spectrogram signal & length
    processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)

    # prepend with cache_pre_encode
    processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)
    processed_signal_length += cache_pre_encode.shape[1]

    # save cache for next time
    cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]

    with torch.no_grad():
        (
            pred_out_stream,
            transcribed_texts,
            cache_last_channel,
            cache_last_time,
            cache_last_channel_len,
            previous_hypotheses,
        ) = asr_model.conformer_stream_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=False,
            previous_hypotheses=previous_hypotheses,
            previous_pred_out=pred_out_stream,
            drop_extra_pre_encoded=None,
            return_transcription=True,
        )

    final_streaming_tran = extract_transcriptions(transcribed_texts)
    step_num += 1

    # Print shape of x before narrow operation
    print("Shape of x before narrow:", processed_signal.shape)

    return final_streaming_tran[0]


async def audio_consumer(websocket, path):
    try:
        while True:
            audio_chunk_str = await websocket.recv()
            audio_chunk = np.frombuffer(audio_chunk_str, dtype=np.int16)
            transcription = transcribe_chunk(audio_chunk)
            await websocket.send(json.dumps({"transcription": transcription}))
    except websockets.exceptions.ConnectionClosed:
        pass


async def start_server():
    global asr_model

    # Load your ASR model
    # Replace 'path_to_your_model' with the actual path to your model
    model_path = "models/stt_en_fastconformer_hybrid_large_streaming_80ms.nemo"
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)

    lookahead_size = 80
    decoder_type = "ctc"

    # specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)
    ENCODER_STEP_LENGTH = 80  # ms

    # update att_context_size
    left_context_size = asr_model.encoder.att_context_size[0]
    asr_model.encoder.set_default_att_context_size(
        [left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)]
    )
    asr_model.encoder.setup_streaming_params()

    # make sure we use the specified decoder_type
    asr_model.change_decoding_strategy(decoder_type=decoder_type)

    # make sure the model's decoding strategy is optimal
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        # save time by doing greedy decoding and not trying to record the alignments
        decoding_cfg.strategy = "greedy"
        decoding_cfg.preserve_alignments = False
        if hasattr(asr_model, 'joint'):  # if an RNNT model
            # restrict max_symbols to make sure not stuck in infinite loop
            decoding_cfg.greedy.max_symbols = 10
            # sensible default parameter, but not necessary since batch size is 1
            decoding_cfg.fused_batch_size = -1
        asr_model.change_decoding_strategy(decoding_cfg)

    # set model to eval mode
    asr_model.eval()

    # get parameters to use as the initial cache state
    # (declared global so transcribe_chunk sees the initialized values)
    global cache_last_channel, cache_last_time, cache_last_channel_len
    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
        batch_size=1
    )

    global preprocessor
    preprocessor = init_preprocessor(asr_model)

    # Initialize global variables for caching
    global pre_encode_cache_size, cache_pre_encode
    pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    cache_pre_encode = torch.zeros(
        (1, asr_model.cfg.preprocessor.features, pre_encode_cache_size),
        device=asr_model.device,
    )

    async with websockets.serve(audio_consumer, "localhost", 8765):
        await asyncio.Future()  # Run server forever


asyncio.run(start_server())
```
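As a side note, the pre-encode caching in `transcribe_chunk` follows a simple pattern: prepend the cached tail of the previous mel-spectrogram to the new frames, then keep the last `pre_encode_cache_size` frames for the next call. A minimal NumPy sketch of that pattern (the shapes and cache size here are illustrative, not the model's actual values):

```python
import numpy as np

PRE_ENCODE_CACHE_SIZE = 4  # illustrative; the real value comes from streaming_cfg


def stream_step(cache, new_frames):
    """Prepend the cached frames, then retain the tail as the next cache.

    cache, new_frames: arrays of shape (features, time).
    Returns (frames_fed_to_encoder, next_cache).
    """
    combined = np.concatenate([cache, new_frames], axis=-1)
    next_cache = combined[:, -PRE_ENCODE_CACHE_SIZE:]
    return combined, next_cache
```

Each call therefore feeds `cache + chunk` frames to the encoder while the cache length stays fixed at `PRE_ENCODE_CACHE_SIZE`, matching the `torch.cat` / negative-slice pair in the server code.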
And here is my client side code:
```python
import asyncio
import websockets
import json
import numpy as np
import pyaudio
import struct

# Set the sample rate and chunk size
SAMPLE_RATE = 16000
lookahead_size = 80
ENCODER_STEP_LENGTH = 80
chunk_size_ms = lookahead_size + ENCODER_STEP_LENGTH
chunk_size_samples = int(SAMPLE_RATE * chunk_size_ms / 1000) - 1


async def send_audio_stream(file_path, websocket):
    # note: reading the .wav file as raw bytes also sends its RIFF header
    # (typically 44 bytes) as if it were audio samples
    with open(file_path, 'rb') as audio_file:
        while True:
            audio_chunk = audio_file.read(chunk_size_samples * 2)  # 2 bytes per sample for 16-bit audio
            if not audio_chunk:
                break
            await websocket.send(audio_chunk)
            response = await websocket.recv()
            transcription = json.loads(response)
            print("Transcription:", transcription['transcription'])


async def main():
    # WebSocket server address
    uri = "ws://localhost:8765"
    # File path of the audio to stream
    audio_file_path = "test-audio/my_audio_file.wav"
    async with websockets.connect(uri) as websocket:
        await send_audio_stream(audio_file_path, websocket)


# Run the main function
asyncio.run(main())
```
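For what it's worth, the chunk framing above works out as follows (pure arithmetic, using the same constants as the client):

```python
SAMPLE_RATE = 16000       # Hz
lookahead_size = 80       # ms
ENCODER_STEP_LENGTH = 80  # ms

chunk_size_ms = lookahead_size + ENCODER_STEP_LENGTH              # 160 ms per chunk
chunk_size_samples = int(SAMPLE_RATE * chunk_size_ms / 1000) - 1  # 2559 samples
chunk_size_bytes = chunk_size_samples * 2                         # 5118 bytes of int16 PCM

print(chunk_size_ms, chunk_size_samples, chunk_size_bytes)
```

The `- 1` mirrors the referenced cache-aware streaming example; note that it makes each chunk one sample short of an exact 160 ms.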
Here is the error I get on the server side:
Any ideas please? Thank you!