
Fine-tuning chatglm4-9b-chat with flash_attn fails: RuntimeError: FlashAttention only support fp16 and bf16 data type #2850

DanLiu0623 opened this issue Jan 3, 2025 · 1 comment

@DanLiu0623

Describe the bug
Fine-tuning the chatglm4-9b-chat model on my own data, with the following training arguments:
```shell
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model /home/models/glm-4-9b-chat \
    --model_type glm4 \
    --torch-dtype bfloat16 \
    --train_type lora \
    --dataset /home/train.jsonl \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --gradient_accumulation_steps 16 \
    --attn_impl flash_attn \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --model_author ld \
    --model_name glm4-chat-lora
```
The error:

```text
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/cli/sft.py", line 5, in <module>
    sft_main()
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/llm/train/sft.py", line 320, in sft_main
    return SwiftSft(args).main()
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/llm/base.py", line 45, in main
    result = self.run()
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/llm/train/sft.py", line 159, in run
    return self.train(trainer)
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/llm/train/sft.py", line 214, in train
    trainer.train(trainer.args.resume_from_checkpoint)
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/trainers/mixin.py", line 261, in train
    res = super().train(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 2164, in train
    return inner_training_loop(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 2524, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 3654, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/trainers/trainers.py", line 142, in compute_loss
    outputs = model(**inputs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/peft/peft_model.py", line 1719, in forward
    return self.base_model(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 994, in forward
    transformer_outputs = self.transformer(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 892, in forward
    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 712, in forward
    layer_ret = torch.utils.checkpoint.checkpoint(
  File "/home/user1/liudan32/ms-swift-3.0.1/swift/trainers/arguments.py", line 46, in _new_checkpoint
    return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 264, in forward
    outputs = run_function(*args)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 625, in forward
    attention_output, kv_cache = self.self_attention(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 522, in forward
    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user1/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 334, in forward
    attn_output = flash_attn_func(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1168, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 815, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_ops.py", line 721, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 324, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/user1/miniconda3/envs/swift/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 96, in _flash_attn_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

Your hardware and system info
CUDA: 12.4
GPU: RTX 4090

```text
absl-py 2.1.0
accelerate 1.2.1
addict 2.4.0
aiofiles 23.2.1
aiohappyeyeballs 2.4.4
aiohttp 3.11.11
aiosignal 1.3.2
aliyun-python-sdk-core 2.16.0
aliyun-python-sdk-kms 2.16.5
annotated-types 0.7.0
anyio 4.7.0
async-timeout 5.0.1
attrdict 2.0.1
attrs 24.3.0
binpacking 1.5.2
dacite 1.8.1
cffi 1.17.1
charset-normalizer 3.4.1
click 8.1.8
contourpy 1.3.1
cpm-kernels 1.0.11
crcmod 1.7
cryptography 44.0.0
cycler 0.12.1
ffmpy 0.5.0
datasets 3.0.1
dill 0.3.8
distro 1.9.0
docstring_parser 0.16
einops 0.8.0
exceptiongroup 1.2.2
fastapi 0.115.6
filelock 3.16.1
flash_attn 2.7.2.post1
fonttools 4.55.3
frozenlist 1.5.0
fsspec 2024.6.1
future 1.0.0
gradio 5.9.1
gradio_client 1.5.2
grpcio 1.68.1
h11 0.14.0
httpcore 1.0.7
httpx 0.28.1
huggingface-hub 0.27.0
idna 3.10
importlib_metadata 8.5.0
Jinja2 3.1.5
jiter 0.8.2
jmespath 0.10.0
joblib 1.4.2
kiwisolver 1.4.8
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.10.0
mdurl 0.1.2
modelscope 1.21.1
mpmath 1.3.0
ms-swift 3.0.1 /home/user1/liudan32/ms-swift-3.0.1
multidict 6.1.0
multiprocess 0.70.16
networkx 3.4.2
ninja 1.11.1.3
nltk 3.9.1
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
openai 1.58.1
orjson 3.10.13
oss2 2.19.1
packaging 24.2
pandas 2.2.3
peft 0.14.0
pillow 11.0.0
pip 24.2
propcache 0.2.1
protobuf 5.29.2
psutil 6.1.1
pyarrow 18.1.0
pycparser 2.22
pycryptodome 3.21.0
pydantic 2.10.4
pydantic_core 2.27.2
pydub 0.25.1
Pygments 2.18.0
pyparsing 3.2.0
python-dateutil 2.9.0.post0
python-multipart 0.0.20
pytz 2024.2
PyYAML 6.0.2
regex 2024.11.6
requests 2.32.3
rich 13.9.4
rouge 1.0.1
ruff 0.8.4
safehttpx 0.1.6
safetensors 0.4.5
scipy 1.14.1
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
shtab 1.7.1
simplejson 3.19.3
six 1.17.0
sniffio 1.3.1
starlette 0.41.3
sympy 1.13.1
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tiktoken 0.8.0
tokenizers 0.21.0
tomlkit 0.13.2
torch 2.5.1
tqdm 4.67.1
transformers 4.47.1
transformers-stream-generator 0.0.5
triton 3.1.0
trl 0.11.4
typeguard 4.4.1
typer 0.15.1
typing_extensions 4.12.2
tyro 0.9.5
tzdata 2024.2
urllib3 2.3.0
uvicorn 0.34.0
websockets 14.1
Werkzeug 3.1.3
wheel 0.44.0
xxhash 3.5.0
yarl 1.18.3
zipp 3.21.0
```

@Jintao-Huang (Collaborator)

`--torch-dtype float16`

Give this a try.
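
For what it's worth, the RTX 4090 (Ada generation) supports bf16 as well as fp16, so either dtype satisfies the FlashAttention kernel as long as the tensors actually reach it in that dtype. A quick environment sanity check (an illustrative sketch, not part of the original exchange):

```python
# Environment sanity check (illustrative sketch).
import torch

print(torch.cuda.get_device_name(0))   # expect "NVIDIA GeForce RTX 4090"
print(torch.cuda.is_bf16_supported())  # True on Ada-generation GPUs
```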
