Fixes the inconsistency of the optionality of attention_mask #37153
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
@@ -280,7 +280,7 @@ def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
-   attention_mask: torch.Tensor,
+   attention_mask: Optional[torch.Tensor],
You should then set it to None by default
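For reference, a minimal sketch of what that suggestion would look like (hypothetical, abbreviated parameter list; not the actual diff in this PR):

```python
from typing import Optional

import torch


# Sketch only: the real _flash_attention_forward takes many more parameters.
def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,  # Optional *and* defaulted to None
):
    ...
```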
Hi @Godofnothing! Yes, I noticed this is different from your proposed solution in issue #37046. My point is that attention_mask will always be passed from flash_attention_forward to _flash_attention_forward, though its type can be either torch.Tensor or NoneType, so I don't think it is necessary to explicitly set the default value to None here.
A similar case appears in modeling_llama.py: LlamaModel uses Optional[torch.Tensor] = None for attention_mask, whereas LlamaAttention simply uses Optional[torch.Tensor], because LlamaModel is responsible for passing the value of attention_mask to LlamaAttention. Technically I could change the type of attention_mask in all involved functions to Optional[torch.Tensor] = None, but I did not do that for the sake of simplicity.
Please correct me if I'm wrong.
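As a toy illustration of the convention described above (the class names and bodies below are made up for this sketch, not the actual transformers code):

```python
from typing import Optional

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Inner layer (LlamaAttention-style): its caller always supplies the mask
    explicitly, so the annotation is Optional[torch.Tensor] with no default."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if attention_mask is not None:
            # A real implementation would apply the mask to the attention scores.
            pass
        return hidden_states


class ToyModel(nn.Module):
    """Top-level module (LlamaModel-style): callers may omit the mask entirely,
    so the annotation is Optional[torch.Tensor] = None."""

    def __init__(self) -> None:
        super().__init__()
        self.attn = ToyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The mask is forwarded unconditionally, whether it is a tensor or None.
        return self.attn(hidden_states, attention_mask)
```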
No, that's fine, thanks for your efforts!
LGTM! I think Optional[torch.Tensor] without a default is okay here.
…face#37153)
* debugging issue 36758
* debugging issue 36758
* debugging issue 36758
* updated attn_mask type specification in _flash_attention_forward
* removed pdb
* added a blank line
* removed indentation
What does this PR do?
This PR fixes Issue #37046.
This issue discusses the inconsistency of the optionality of the parameter attention_mask across different functions. Specifically, the type of attention_mask in LlamaForCausalLM, LlamaModel, and LlamaDecoderLayer is Optional[torch.Tensor] = None. In LlamaAttention (the class that wraps the attention interface), the type is Optional[torch.Tensor] without a default, because attention_mask is always passed to it from LlamaModel, whether its value is a torch.Tensor or NoneType. attention_mask is therefore also passed on to the attention interface, no matter whether it is a tensor or None.
The key problem lies in the type specification in flash attention. In flash_attention_forward, the type of attention_mask is Optional[torch.Tensor], whereas in _flash_attention_forward, a function called inside flash_attention_forward, the type is torch.Tensor. This is unreasonable because:
- attention_mask may be None when it is passed in from flash_attention_forward.
- _flash_attention_forward is usable even if attention_mask is None.
Therefore, this PR fixes the type specification of attention_mask in _flash_attention_forward, as well as its docstring, to address the aforementioned issue.
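To make the behavior concrete, here is a hedged sketch of how a None mask flows through the function (abbreviated signature; the branch bodies are illustrative placeholders, not the library's implementation):

```python
from typing import Optional

import torch


# Sketch of the corrected annotation; the real function takes more parameters.
def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    if attention_mask is None:
        # No padding mask: attention can run over the full, unpadded batch.
        attn_output = query_states  # placeholder for the real kernel call
    else:
        # Padding mask present: variable-length sequences are handled via the mask.
        attn_output = query_states  # placeholder for the real kernel call
    return attn_output
```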
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.