Wav2Vec2Pretrain (HFTransformersInterface implementation) samples padded values for mask_time_indices and negative_sample_indices #2386
Labels: bug
Hey @TParcollet, could you please have a look?
My local fix is something like this (using features_padding_mask):
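The fix itself was not captured in this page, but the idea described above can be sketched as follows. This is a hedged, simplified stand-in for Hugging Face's span-mask sampling, not SpeechBrain's actual patch; `compute_mask_indices_with_padding` and `seq_lens` are hypothetical names, with `seq_lens` standing in for the per-utterance lengths that `features_padding_mask` encodes:

```python
import numpy as np

def compute_mask_indices_with_padding(shape, seq_lens, mask_prob, mask_length, rng):
    """Sample span-mask starts only within each utterance's real (unpadded) frames.

    Simplified, hypothetical stand-in for HF's mask sampling: instead of drawing
    span starts over the full padded length, each row is restricted to its true
    length, so the resulting mask_time_indices never land on padding.
    """
    batch, max_len = shape
    mask = np.zeros(shape, dtype=bool)
    for b, true_len in enumerate(seq_lens):
        # Number of masked spans is proportional to the *real* length, not max_len.
        num_spans = max(1, int(mask_prob * true_len / mask_length))
        starts = rng.choice(
            max(1, true_len - mask_length), size=num_spans, replace=False
        )
        for s in starts:
            # Clamp the span so it cannot spill into the padded region.
            mask[b, s : min(s + mask_length, true_len)] = True
    return mask
```

With this restriction, an utterance padded from frame 60 onward can never have masked indices in frames 60 and above, so its loss no longer depends on how much padding the batch happens to contain.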
This is quite a late answer, but yes, it certainly is true. The reason is that we rely on HF functions here, and back when we wrote this code, I believe there was no alternative. @porfirythelaw, could you propose a PR with this fix? I will test it. Many thanks.
Describe the bug
I've been using the SpeechBrain Wav2Vec2 pretraining recipe (with HF integration) on my own data and noticed that I get significantly different metrics from the same model on the validation dataset depending on the amount of padding in the batch. My hypothesis was that padding is somehow not ignored during the index sampling process, and I believe this is in fact what is happening.
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 261 to 265 in eba7714
As you can see, no attention mask is provided in this function call, so masked indices are drawn from padded values as well.
The same applies to the negative sample indices, which are drawn from the whole sequence:
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 278 to 286 in eba7714
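To illustrate the negative-sampling side of the issue, here is a minimal sketch (a hypothetical helper, not SpeechBrain's or HF's actual function) of drawing negatives only from the other masked positions. Assuming the mask itself already excludes padded frames, the sampled negatives can then never point at padding:

```python
import numpy as np

def sample_negatives_from_valid(mask_time_indices, num_negatives, rng):
    """For each masked position, sample negatives from the *other* masked frames.

    If mask_time_indices already excludes padded frames, the negatives can never
    land on padding either. Unmasked positions keep a placeholder index of 0.
    """
    batch, seq_len = mask_time_indices.shape
    negatives = np.zeros((batch, seq_len, num_negatives), dtype=np.int64)
    for b in range(batch):
        valid = np.flatnonzero(mask_time_indices[b])  # masked, real frames only
        for t in valid:
            pool = valid[valid != t]  # never sample the anchor position itself
            negatives[b, t] = rng.choice(pool, size=num_negatives, replace=True)
    return negatives
```

The key contrast with the current behavior is the sampling pool: indices come from `valid` rather than from `range(seq_len)`, so padded frames cannot be used as negatives in the contrastive loss.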
You do provide the attention mask in this call to the model:
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 289 to 296 in eba7714
However, if you check the Hugging Face source code, the attention mask does not affect the loss calculation; it only affects the encoder self-attention.
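The leak can be seen with a toy NumPy reduction: the pretraining loss is averaged over every position flagged in `mask_time_indices`, so a masked span that fell into padding shifts the result even when the attention mask is set. The numbers below are illustrative stand-ins, not real model losses:

```python
import numpy as np

# Toy demonstration: the loss is reduced over all positions flagged in
# mask_time_indices; the attention mask does not gate this reduction.
rng = np.random.default_rng(0)
per_pos_loss = rng.normal(size=(1, 100))  # stand-in per-position loss values

masked_real = np.zeros((1, 100), dtype=bool)
masked_real[0, 20:30] = True              # spans inside real audio only

masked_leaky = masked_real.copy()
masked_leaky[0, 90:95] = True             # a span that landed in padding

loss_real = per_pos_loss[masked_real].mean()
loss_leaky = per_pos_loss[masked_leaky].mean()
# The two reductions differ: the padded frames changed the loss, which is
# exactly the padding-dependent metric drift described in this report.
```

This matches the observed symptom: the same model yields different validation metrics depending on how much padding each batch contains.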
I'm not sure if this behavior was intended or not.
Expected behaviour
Padded values should not influence the model loss or metrics.
To Reproduce
No response
Environment Details
SpeechBrain v0.5.16
Relevant Log Output
No response
Additional Context
No response