diff --git a/recipes/zephyr-7b-beta/README.md b/recipes/zephyr-7b-beta/README.md
index 1134e719..d27de43a 100644
--- a/recipes/zephyr-7b-beta/README.md
+++ b/recipes/zephyr-7b-beta/README.md
@@ -23,10 +23,22 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 
 ## QLoRA training examples
 
+Train faster with Flash Attention 2 (requires a GPU that supports FA2, e.g. A100, H100):
 ```shell
 # Step 1 - SFT
 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true
 
 # Step 2 - DPO
 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml
-```
\ No newline at end of file
+```
+
+P.S. Using Flash Attention also allows you to drastically increase the batch size (2x in my case).
+
+Train without Flash Attention:
+```shell
+# Step 1 - SFT
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --use_flash_attention_2=false
+
+# Step 2 - DPO
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --use_flash_attention_2=false
+```
\ No newline at end of file
diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml
index 3928341f..77742558 100644
--- a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml
+++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml
@@ -1,6 +1,7 @@
 # Model arguments
 model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
 torch_dtype: bfloat16
+use_flash_attention_2: true
 
 # LoRA arguments
 use_peft: true
diff --git a/recipes/zephyr-7b-beta/sft/config_qlora.yaml b/recipes/zephyr-7b-beta/sft/config_qlora.yaml
index 3b09218b..19840830 100644
--- a/recipes/zephyr-7b-beta/sft/config_qlora.yaml
+++ b/recipes/zephyr-7b-beta/sft/config_qlora.yaml
@@ -1,7 +1,8 @@
 # Model arguments
 model_name_or_path: mistralai/Mistral-7B-v0.1
 model_revision: main
-torch_dtype: float16
+torch_dtype: bfloat16
+use_flash_attention_2: true
 
 # LoRA arguments
 load_in_4bit: true
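
For context, a minimal sketch of what the config flags in this diff (`use_flash_attention_2: true`, `load_in_4bit: true`, `torch_dtype: bfloat16`) typically translate to when the model is loaded with `transformers`. This is an illustrative assumption, not the actual wiring inside `run_sft.py`/`run_dpo.py`; the NF4 quantization settings below are likewise assumed rather than taken from this diff, and newer `transformers` releases spell the Flash Attention switch as `attn_implementation="flash_attention_2"` rather than `use_flash_attention_2=True`.

```python
# Illustrative sketch only -- the real loading logic lives in scripts/run_sft.py / run_dpo.py.
# Assumes transformers, bitsandbytes, and flash-attn are installed and an FA2-capable GPU is available.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# load_in_4bit: true  ->  4-bit quantization via bitsandbytes (NF4 is an assumed choice here)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches torch_dtype: bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,               # torch_dtype: bfloat16
    quantization_config=bnb_config,           # load_in_4bit: true
    attn_implementation="flash_attention_2",  # use_flash_attention_2: true (older API: use_flash_attention_2=True)
)
```

The QLoRA part then comes from attaching LoRA adapters on top of this 4-bit quantized base model, which is why the recipe can run on a single GPU with `--num_processes=1`.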