From a60e8208bb94ea51cd8ddd2ecf84970b26cfe016 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 10 Jan 2025 01:55:19 +0800
Subject: [PATCH] fix flash_attn import on old GPU

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 deepseek_vl2/models/siglip_vit.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/deepseek_vl2/models/siglip_vit.py b/deepseek_vl2/models/siglip_vit.py
index 462d5a5..cc102a4 100644
--- a/deepseek_vl2/models/siglip_vit.py
+++ b/deepseek_vl2/models/siglip_vit.py
@@ -12,11 +12,15 @@
     AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from flash_attn import flash_attn_qkvpacked_func
+from transformers.modeling_utils import is_flash_attn_2_available
 from xformers.ops import memory_efficient_attention
 from functools import partial
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_qkvpacked_func
+
+
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -134,7 +138,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 
         if not self.qk_norm:
-            if self.head_dim % 32 == 0:
+            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                 # flash_attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash_attn
                 x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                               deterministic=self.deterministic)
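
Note (not part of the patch): the sketch below illustrates, as standalone code, the guard pattern this patch introduces: import flash_attn only when flash-attn 2 is available, and check availability again at the call site. It assumes `transformers` is installed; `attention_with_fallback` is a hypothetical helper, and the PyTorch scaled-dot-product fallback here is only a stand-in for the module's actual xformers `memory_efficient_attention` path.

# Minimal sketch of the conditional flash-attn usage applied by this patch.
# Assumptions: `transformers` is installed; `attention_with_fallback` is a
# hypothetical name; the SDPA fallback stands in for the real xformers path.
import torch
from transformers.modeling_utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Only import flash_attn when it can actually be used; on old GPUs the
    # import itself may fail, which is what broke before this patch.
    from flash_attn import flash_attn_qkvpacked_func


def attention_with_fallback(qkv: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    """qkv: packed (B, N, 3, num_heads, head_dim); returns (B, N, num_heads, head_dim)."""
    head_dim = qkv.shape[-1]
    if head_dim % 32 == 0 and is_flash_attn_2_available():
        # flash-attn requires head_dim to be a multiple of 32.
        return flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p)
    # Fallback: plain scaled-dot-product attention (illustrative stand-in).
    q, k, v = qkv.unbind(dim=2)                        # each (B, N, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # (B, H, N, D)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    return out.transpose(1, 2)                         # back to (B, N, H, D)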