fix: Zephyr LoRA fine-tuning fixed (#139)
Co-authored-by: svbogdanov <[email protected]>
Serega6678 and svbogdanov authored Mar 21, 2024
1 parent 595023f commit c44cb1c
Showing 3 changed files with 17 additions and 3 deletions.
16 changes: 14 additions & 2 deletions recipes/zephyr-7b-beta/README.md
@@ -23,10 +23,22 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

## QLoRA training examples

Train faster with flash-attention 2 (GPUs supporting FA2: A100, H100, etc.):
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml
```

P.S. Using Flash Attention also allows you to drastically increase the batch size (2x in my case).
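
For example, the per-device batch size and gradient accumulation steps can be overridden straight from the command line, just like the flags above; the values below are purely illustrative and depend on your GPU memory and on the defaults in `config_qlora.yaml`:

```shell
# Illustrative override only: pick a larger per-device batch size and adjust
# gradient accumulation to keep the effective batch size you want.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --per_device_train_batch_size=8 --gradient_accumulation_steps=1
```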

Train without flash-attention:
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --use_flash_attention_2=false

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --use_flash_attention_2=false
```
1 change: 1 addition & 0 deletions recipes/zephyr-7b-beta/dpo/config_qlora.yaml
@@ -1,6 +1,7 @@
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
torch_dtype: bfloat16
+use_flash_attention_2: true

# LoRA arguments
use_peft: true
3 changes: 2 additions & 1 deletion recipes/zephyr-7b-beta/sft/config_qlora.yaml
@@ -1,7 +1,8 @@
# Model arguments
model_name_or_path: mistralai/Mistral-7B-v0.1
model_revision: main
-torch_dtype: float16
+torch_dtype: bfloat16
+use_flash_attention_2: true

# LoRA arguments
load_in_4bit: true
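
Note that `use_flash_attention_2: true` in these configs assumes the `flash-attn` package is installed in the training environment; a typical install (with a CUDA toolchain that matches your PyTorch build) looks like:

```shell
# flash-attn compiles against the local CUDA/PyTorch setup, so the build can
# take a while; --no-build-isolation is the install flag recommended by the
# flash-attn project.
pip install flash-attn --no-build-isolation
```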
