
Added partially auto-regressive decoding #5769

Merged
15 commits merged into espnet:master on May 25, 2024

Conversation

Masao-Someki
Contributor

What?

Support partially auto-regressive (PAR) decoding for the ASR and S2T tasks.
Since this decoding process is based on BERT-CTC, it can only perform speech recognition.

Why?

To speed up OWSM inference

See also

https://arxiv.org/abs/2309.14922

@Masao-Someki
Contributor Author

With partially auto-regressive decoding, OWSM inference becomes roughly twice as fast as Whisper on Google Colab. PAR inference can be enabled with the following code:

from espnet2.bin.s2t_inference import Speech2Text

m = Speech2Text.from_pretrained(
    'espnet/owsm_v3.1_ebf',
    partial_ar=True,       # enable partially auto-regressive decoding
    max_mask_parallel=10,  # maximum number of masks decoded in one batch
)

@Masao-Someki
Contributor Author

The method in the paper ran into a memory issue, but a minor change in the attention module (commit) resolved this problem and additionally sped up inference.

@sw005320 sw005320 requested a review from pyf98 April 30, 2024 11:52
@sw005320 sw005320 added this to the v.202405 milestone Apr 30, 2024
@sw005320
Contributor

Thanks a lot, @Masao-Someki!
@pyf98, can you review this PR?

espnet2/bin/s2t_inference.py (outdated review thread, resolved)
@pyf98
Collaborator

pyf98 commented Apr 30, 2024

Thanks! This speed-up looks very cool.
Since it only supports ASR, can you add some assertions in s2t_inference to ensure the task is ASR?

BTW, I didn't read the paper. Why does it only support ASR? Does it rely on CTC? If not, ASR and ST would be similar?
Also, is there any degradation in performance?

if expand_value:
    v_shape = value.shape
    v = (
        self.linear_v(value[:1, :, :])
Collaborator

What does expand_value do? It seems to select only the first sample?
It would be better to add some comments to explain this argument and its purpose.

Contributor Author

Thank you, I will add comments. This is actually very important for reducing memory usage, because the effective batch size in PAR can become very large: we compute all beams for all masks in a single batch, so with a beam size of 10 and 10 masks the batch size is 100.
With this trick, the memory-heavy projection is computed only once. On my local PC, it makes inference faster (3.2 s -> 2.1 s).
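The batching argument above can be sketched in plain Python. This is an illustrative toy, not the actual espnet attention code: when every item in a batch is a copy of the same encoder output (beam_size * n_masks copies of one utterance), the value projection can be computed once on a single copy and then expanded, instead of being recomputed for the whole batch. The function names here are invented for the sketch.

```python
def linear_proj(x, w=2.0):
    """Toy stand-in for the linear value projection: scale every feature."""
    return [[f * w for f in row] for row in x]

def project_value(value, batch_size, expand_kv):
    if expand_kv:
        # Project only the first sample, then replicate it batch_size times.
        first = linear_proj(value[:1])
        return first * batch_size
    # Naive path: project every (identical) sample in the batch.
    return linear_proj(value * batch_size)

enc_out = [[1.0, 2.0, 3.0]]  # one shared encoder frame (toy data)
fast = project_value(enc_out, batch_size=100, expand_kv=True)   # 1 projection
slow = project_value(enc_out, batch_size=100, expand_kv=False)  # 100 projections
assert fast == slow  # identical result with 1/batch_size the projections
```

With beam size 10 and 10 masks, the naive path would run the projection 100 times on identical inputs; the expand path runs it once.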

Collaborator

Thanks for the explanation. Have you added some comments about expand_value? Sorry if I missed it somewhere.

Contributor

@Masao-Someki, can you add some comments about expand_value in this file?

Contributor Author

I realized that we need to expand the key as well, so I renamed it to expand_kv. I added the following comment to clarify this change:
https://github.com/Masao-Someki/espnet/blob/7eeb3c0608b63dd46cb4eaad3990c2d6ba1a948b/espnet/nets/pytorch_backend/transformer/attention.py#L121-L126

espnet2/bin/s2t_inference.py (outdated review thread, resolved)
espnet2/bin/s2t_inference.py (outdated review thread, resolved)
espnet2/bin/asr_inference.py (outdated review thread, resolved)
@pyf98
Collaborator

pyf98 commented Apr 30, 2024

Another question: Does this support an extra text prompt (i.e., the text before <sos> in OWSM)?

@Masao-Someki
Contributor Author

Thank you @pyf98!

BTW, I didn't read the paper. Why does it only support ASR? Does it rely on CTC? If not, ASR and ST would be similar?
Also, is there any degradation in performance?

It only supports ASR because PAR uses the CTC output as its basis. In this decoding process, we first obtain the CTC output with a probability for each token and then apply masking. After that, the decoder predicts the masked sections in an autoregressive manner.
The accuracy depends on several hyperparameters, but there is essentially no degradation.
[image: accuracy comparison results]

Another question: Does this support an extra text prompt (i.e., the text before <sos> in OWSM)?

It depends on the case. The decoder can use the extra prompt, but CTC cannot. If the CTC output has high probability everywhere (meaning there are no masks for the decoder to predict), PAR does not use the extra prompt. If some tokens have lower probability, PAR uses the decoder, with the extra prompt, to predict the masked parts.
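The mask-then-infill flow described above can be sketched in a few lines. This is a hedged toy illustration, not the espnet API; all names (mask_low_confidence, infill, the probability threshold) are invented for the sketch.

```python
MASK = "<mask>"

def mask_low_confidence(tokens, probs, threshold=0.9):
    """Keep tokens whose CTC probability clears the threshold; mask the rest."""
    return [t if p >= threshold else MASK for t, p in zip(tokens, probs)]

def infill(masked, decoder):
    """Run the decoder only on masked positions; keep confident CTC tokens."""
    return [decoder(i) if t == MASK else t for i, t in enumerate(masked)]

ctc_tokens = ["the", "cat", "sat", "down"]
ctc_probs = [0.99, 0.40, 0.97, 0.35]

masked = mask_low_confidence(ctc_tokens, ctc_probs)
assert masked == ["the", MASK, "sat", MASK]

# If no position were masked, infill (and hence any extra text prompt fed to
# the decoder) would never run, matching the behavior described above.
hyp = infill(masked, decoder=lambda i: f"tok{i}")
assert hyp == ["the", "tok1", "sat", "tok3"]
```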

@pyf98
Collaborator

pyf98 commented May 1, 2024

> It only supports ASR because PAR uses the CTC output as its basis. In this decoding process, we first obtain the CTC output with a probability for each token and then apply masking. After that, the decoder predicts the masked sections in an autoregressive manner. The accuracy depends on several hyperparameters, but there is essentially no degradation.
>
> It depends on the case. The decoder can use the extra prompt, but CTC cannot. If the CTC output has high probability everywhere (meaning there are no masks for the decoder to predict), PAR does not use the extra prompt. If some tokens have lower probability, PAR uses the decoder, with the extra prompt, to predict the masked parts.

Thanks for the explanation. I now understand it at a high level.

- Removed a lot of leftover experimental code
- Added comments
- Removed unneeded comments (automatically generated by the editor)
- Formatted with black, isort, pycodestyle
@Masao-Someki
Contributor Author

@pyf98 I noticed that quite a lot of noisy code from my experiments was left over, so I removed everything that was not required. I also added comments and re-formatted the code.


codecov bot commented May 4, 2024

Codecov Report

Attention: Patch coverage is 3.68421%, with 366 lines in your changes missing coverage. Please review.

Project coverage is 16.18%. Comparing base (1ea3fdf) to head (1e6becd).

Files Patch % Lines
espnet/nets/beam_search_partially_AR.py 0.00% 172 Missing ⚠️
espnet2/asr/partially_AR_model.py 0.00% 88 Missing ⚠️
.../nets/pytorch_backend/transformer/decoder_layer.py 4.76% 40 Missing ⚠️
espnet2/bin/s2t_inference.py 0.00% 27 Missing ⚠️
espnet2/asr/decoder/transformer_decoder.py 13.79% 25 Missing ⚠️
espnet2/bin/asr_inference.py 0.00% 9 Missing ⚠️
...pnet/nets/pytorch_backend/transformer/attention.py 60.00% 4 Missing ⚠️
espnet/nets/scorer_interface.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #5769       +/-   ##
===========================================
- Coverage   62.77%   16.18%   -46.60%     
===========================================
  Files         153      772      +619     
  Lines       16958    70863    +53905     
===========================================
+ Hits        10646    11469      +823     
- Misses       6312    59394    +53082     
Flag Coverage Δ
test_integration_espnet1 ?
test_python_espnetez 14.18% <2.36%> (?)
test_utils 20.61% <18.18%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Masao-Someki
Contributor Author

I've completed all the reviews and added unit tests!

Contributor

Do you know why many lines are not covered by the tests?
Since we create an instance in https://github.com/espnet/espnet/pull/5769/files#diff-17a5ec3fb4fd9d507bd7f2f97f11219d6f47310afb5a80ee31ece48696e193d7R36, I thought most lines in this file would be covered, but they are not.

Contributor Author

I checked the log of unit_test_espnet1_and_espnet2_on_debian11, and it shows that most of the lines are covered:

...
 espnet2/asr/partially_AR_model.py    88    3    97%
...
 espnet/nets/beam_search_partially_AR.py    172    15    91%

It seems that Codecov hasn't updated its results, likely due to the issue in CI.

@sw005320 sw005320 merged commit 721fb64 into espnet:master May 25, 2024
32 of 35 checks passed
@sw005320
Contributor

Thanks, @Masao-Someki!
