Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flash_attn import in siglip_vit #34

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions deepseek_vl2/models/siglip_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
)
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from flash_attn import flash_attn_qkvpacked_func
from transformers.modeling_utils import is_flash_attn_2_available
from xformers.ops import memory_efficient_attention
from functools import partial


if is_flash_attn_2_available():
from flash_attn import flash_attn_qkvpacked_func


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
Expand Down Expand Up @@ -134,7 +138,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)

if not self.qk_norm:
if self.head_dim % 32 == 0:
if self.head_dim % 32 == 0 and is_flash_attn_2_available():
# flashattn的head_dim必须是32的倍数,SigLIP-SO400M无法使用flashattn
x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
deterministic=self.deterministic)
Expand Down