Added partially auto-regressive decoding #5769
Conversation
for more information, see https://pre-commit.ci
With partially auto-regressive (PAR) decoding, we can speed up OWSM inference so that it runs roughly twice as fast as Whisper on Google Colab. PAR inference can be used with the following code:

```python
from espnet2.bin.s2t_inference import Speech2Text

m = Speech2Text.from_pretrained(
    'espnet/owsm_v3.1_ebf',
    partial_ar=True,
    max_mask_parallel=10,
)
```
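For reference, a minimal usage sketch is shown below; the call interface and the result layout (`results[0][0]` as the best hypothesis text) are assumptions based on the usual ESPnet2 inference API rather than something stated in this PR.

```python
# Hypothetical usage sketch (assumed interface, not verified against this PR).
import soundfile as sf

speech, rate = sf.read("sample.wav")  # assumed 16 kHz mono audio
results = m(speech)                   # run PAR decoding
print(results[0][0])                  # best hypothesis text
```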
The implementation from the paper ran into a memory issue, but a minor change in the attention module (commit) resolved it and also helped speed up inference.
Thanks a lot, @Masao-Someki!
Thanks! This speed-up looks very cool. BTW, I haven't read the paper: why does it only support ASR? Does it rely on CTC? If not, wouldn't ASR and ST be similar?
```python
if expand_value:
    v_shape = value.shape
    v = (
        self.linear_v(value[:1, :, :])
```
What does `expand_value` do? It seems to select only the first sample. It would be better to add some comments explaining this argument and its purpose.
Thank you, I will add comments. Actually, this is very important for reducing memory usage, because the batch size in PAR can become very large. We compute all beams for all masks in a batched manner, so, for example, with a beam size of 10 and 10 masks we get batch_size=100. With this trick we perform the memory-heavy computation only once. On my local PC, applying it makes inference faster (3.2 sec -> 2.1 sec).
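The idea is roughly the following (a minimal sketch assuming the key/value tensor is identical across the batch dimension; the helper name and shapes are illustrative, not the actual ESPnet code):

```python
import torch

def project_value_shared(linear_v, value, n_head, d_k, expand_kv=False):
    """Project `value` with the shared linear layer.

    When `expand_kv` is True, `value` is assumed to be one encoder output
    repeated along the batch dimension (one copy per beam x mask), so the
    projection is computed once on the first sample and then expanded,
    without copying, to the full batch size.
    Hypothetical helper for illustration, not the actual ESPnet code.
    """
    n_batch = value.size(0)
    if expand_kv:
        # Project only the first sample: (1, T, d_model) -> (1, T, h, d_k)
        v = linear_v(value[:1]).view(1, -1, n_head, d_k)
        # Expand the view to the full batch: (n_batch, T, h, d_k)
        v = v.expand(n_batch, -1, -1, -1)
    else:
        v = linear_v(value).view(n_batch, -1, n_head, d_k)
    return v.transpose(1, 2)  # (batch, head, time, d_k)
```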
Thanks for the explanation. Have you added some comments about `expand_value`? Sorry if I missed it somewhere.
@Masao-Someki, can you add some comments about `expand_value` in this file?
I realized that we need to expand the key as well, so I renamed it to `expand_kv`. I added the following comment to clarify this change:
https://github.com/Masao-Someki/espnet/blob/7eeb3c0608b63dd46cb4eaad3990c2d6ba1a948b/espnet/nets/pytorch_backend/transformer/attention.py#L121-L126
Another question: does this support an extra text prompt (i.e., the text before ...)?
Thank you @pyf98!
It only supports ASR because PAR uses the CTC output as its basis. In this decoding process, we first get the CTC output with a probability for each token and then apply masking. After that, we use the decoder to predict the masked sections in an AR manner.
Regarding the prompt, it depends on the case. The decoder can use the extra prompt but CTC cannot, so if the CTC output has high probability (meaning there are no masks for the decoder to predict), PAR does not use the extra prompt. If there are tokens with lower probability, PAR uses the decoder with the extra prompt to predict the masked parts.
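To make the flow concrete, here is a minimal sketch of the masking step as I understand it (the function name, mask id, and threshold value are illustrative, not the actual ESPnet implementation):

```python
import torch

MASK_ID = -1  # placeholder id for masked positions (hypothetical)

def mask_low_confidence(ctc_tokens, ctc_probs, threshold=0.9):
    """Sketch of the masking step in partially auto-regressive (PAR) decoding.

    ctc_tokens: (T,) greedy CTC token ids after removing blanks/repeats.
    ctc_probs:  (T,) per-token confidence from the CTC posterior.
    Tokens whose confidence falls below `threshold` are replaced by MASK_ID
    and later re-predicted by the decoder in an AR manner.
    """
    masked = ctc_tokens.clone()
    low_conf = ctc_probs < threshold
    masked[low_conf] = MASK_ID
    return masked, low_conf

# Example: if every token is above the threshold, nothing is masked, so the
# decoder (and hence any extra text prompt) is never used.
tokens = torch.tensor([12, 7, 33, 5])
probs = torch.tensor([0.99, 0.62, 0.95, 0.88])
print(mask_low_confidence(tokens, probs))
```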
Thanks for the explanation. I now understand it at a high level.
- Removed a lot of noisy code
- Added comments
- Removed unrequired comments (automatically generated by the editor)
- Formatted with black, isort, and pycodestyle
@pyf98 I noticed that quite a lot of noisy code from my experiments was still there, so I removed everything that is not required. I also added comments and re-formatted the code.
Codecov Report
Attention: Patch coverage is ...

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5769       +/-   ##
===========================================
- Coverage   62.77%   16.18%   -46.60%
===========================================
  Files         153      772      +619
  Lines       16958    70863    +53905
===========================================
+ Hits        10646    11469      +823
- Misses       6312    59394    +53082
```

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
I've addressed all the review comments and added unit tests!
Do you know why many lines are not covered by the tests?
Since we create an instance in https://github.com/espnet/espnet/pull/5769/files#diff-17a5ec3fb4fd9d507bd7f2f97f11219d6f47310afb5a80ee31ece48696e193d7R36, I thought most lines in this file would be covered, but apparently they are not.
I checked the log file of unit_test_espnet1_and_espnet2_on_debian11, and it shows that most of the lines are covered:

```
...
espnet2/asr/partially_AR_model.py          88    3    97%
...
espnet/nets/beam_search_partially_AR.py   172   15    91%
```

It seems that Codecov hasn't updated its results, likely due to the issue in CI.
Thanks, @Masao-Someki!
What?
Support partially auto-regressive (PAR) decoding for the ASR and S2T tasks.
Since this decoding process is based on BERT-CTC, it can only perform speech recognition.
Why?
To speed up OWSM inference
See also
https://arxiv.org/abs/2309.14922