Adding continued_pretraining task (#131)
* add continued pretraining script

* simplify config; add dataset_config option

* add ds configs in data mixer creator

* use extended sftconfig

* add option to avoid setting chat template

* fix data_configs bug

* add continued pretraining info

* add gpt2-nl recipe for continued pretraining example

* add final newline

* make style

* Update README.md

Co-authored-by: lewtun <[email protected]>

* Update README.md

Co-authored-by: lewtun <[email protected]>

* Update recipes/gpt2-nl/README.md

Co-authored-by: lewtun <[email protected]>

* rename continued pretraining to cpt

* improve README

---------

Co-authored-by: lewtun <[email protected]>
BramVanroy and lewtun authored Mar 14, 2024
1 parent a9b8a50 commit 595023f
Showing 12 changed files with 415 additions and 12 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -8,7 +8,7 @@

# The Alignment Handbook

- Robust recipes to align language models with human and AI preferences.
+ Robust recipes to continue pretraining and to align language models with human and AI preferences.

## What is this?

@@ -33,8 +33,8 @@ The Alignment Handbook aims to fill that gap by providing the community with a s

This project is simple by design and mostly consists of:

- * [`scripts`](./scripts/) to train and evaluate chat models. Each script supports distributed training of the full model weights with DeepSpeed ZeRO-3, or LoRA/QLoRA for parameter-efficient fine-tuning.
- * [`recipes`](./recipes/) to reproduce models like Zephyr 7B. Each recipe takes the form of a YAML file which contains all the parameters associated with a single training run.
+ * [`scripts`](./scripts/) to train and evaluate models. Three steps are included: continued pretraining, supervised fine-tuning (SFT) for chat, and preference alignment with DPO. Each script supports distributed training of the full model weights with DeepSpeed ZeRO-3, or LoRA/QLoRA for parameter-efficient fine-tuning.
+ * [`recipes`](./recipes/) to reproduce models like Zephyr 7B. Each recipe takes the form of a YAML file which contains all the parameters associated with a single training run. A `gpt2-nl` recipe is also included to illustrate how this handbook can be used for language or domain adaptation, e.g. by continuing to pretrain on a different language and then applying SFT and DPO to the result.

We are also working on a series of guides to explain how methods like direct preference optimization (DPO) work, along with lessons learned from gathering human preferences in practice. To get started, we recommend the following:

@@ -48,6 +48,7 @@ If you would like to train chat models on your own datasets, we recommend follow

The initial release of the handbook will focus on the following techniques:

* **Continued pretraining:** adapt language models to a new language or domain, or simply improve them by continued pretraining (causal language modeling) on a new dataset.
* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your own training dataset.
* **Reward modeling:** teach language models to distinguish model responses according to human or AI preferences.
* **Rejection sampling:** a simple, but powerful technique to boost the performance of your SFT model.
43 changes: 43 additions & 0 deletions recipes/gpt2-nl/README.md
@@ -0,0 +1,43 @@
# Language Adaptation through Continued Pretraining

This directory shows a minimal example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain).

Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this toy example we'll continue pretraining gpt2 on raw Dutch data, then tune it with SFT, and finally align it with DPO. Note that no extensive hyperparameter search was done for this example, so the output models are of low quality; the point is simply to show how the scripts can be used for LM adaptation. The scripts work on 4x 3090s (24GB VRAM each). If you have less powerful hardware you may need to reduce the batch size, as sketched below.
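For example (a sketch, not part of the recipe itself): halving the per-device batch size while doubling the gradient accumulation steps keeps the effective batch size unchanged. As noted in the conclusion, any key in the YAML configs can be overridden on the command line:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_cpt.py recipes/gpt2-nl/cpt/config_full.yaml \
    --per_device_train_batch_size=8 --gradient_accumulation_steps=2
```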

## Continued pretraining

This step will further pretrain the original `gpt2` model on plain Dutch text. Note that the script uses the `text` column of the dataset by default, but you can change that by setting `text_column` in the YAML file or on the command line.
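For instance, if your dataset stores its raw text under a different name (the column name `content` below is purely hypothetical), a one-line addition to the config suffices:

```yaml
# Hypothetical override: read raw text from the `content` column instead of `text`
text_column: content
```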

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file recipes/accelerate_configs/multi_gpu.yaml \
--num_processes 4 \
scripts/run_cpt.py \
recipes/gpt2-nl/cpt/config_full.yaml
```

## Supervised finetuning

As other recipes have shown, such as the well-known zephyr-7b-beta recipe, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As the base model we'll use the output of the previous step.

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file recipes/accelerate_configs/multi_gpu.yaml \
--num_processes 4 \
scripts/run_sft.py recipes/gpt2-nl/sft/config_full.yaml
```

## Direct preference optimisation

Finally, to better align the model with feedback, we can finetune the SFT output with the DPO algorithm, which trains on pairs of preferred and rejected responses (see the sketch after the command below). This should improve the model's chat capabilities.

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file recipes/accelerate_configs/multi_gpu.yaml \
--num_processes 4 \
scripts/run_dpo.py recipes/gpt2-nl/dpo/config_full.yaml
```
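For reference, a hypothetical preference record in the shape such datasets commonly use (the column names and Dutch content below are illustrative assumptions, not sampled from `BramVanroy/ultra_feedback_dutch`):

```python
# Hypothetical DPO preference record: one prompt with a preferred and a rejected completion
example = {
    "prompt": "Wat is de hoofdstad van Nederland?",
    "chosen": [
        {"role": "user", "content": "Wat is de hoofdstad van Nederland?"},
        {"role": "assistant", "content": "De hoofdstad van Nederland is Amsterdam."},
    ],
    "rejected": [
        {"role": "user", "content": "Wat is de hoofdstad van Nederland?"},
        {"role": "assistant", "content": "De hoofdstad van Nederland is Rotterdam."},
    ],
}
```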

## Conclusion

With the steps above you can adapt an LM to a new domain, more data, or even a different language, and then build a powerful chatbot on top with SFT and DPO, all within just three simple commands. The three steps follow a very similar pattern, which makes them well suited to parameterized Slurm jobs; a sketch follows below. Conveniently, you can also overwrite any argument in the YAML files by passing it as a command-line argument, which makes the recipes highly adaptable.
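A hypothetical Slurm launcher exploiting that shared structure (the file name, resource flags, and environment setup are assumptions; the launch command mirrors the ones above):

```shell
#!/bin/bash
# Usage: sbatch run_gpt2_nl.slurm {cpt|sft|dpo}   -- hypothetical launcher script
#SBATCH --gres=gpu:4
#SBATCH --job-name=gpt2-nl

TASK=$1  # cpt, sft, or dpo
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    "scripts/run_${TASK}.py" "recipes/gpt2-nl/${TASK}/config_full.yaml"
```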
45 changes: 45 additions & 0 deletions recipes/gpt2-nl/cpt/config_full.yaml
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: gpt2
model_revision: main
torch_dtype: bfloat16

# Data training arguments
dataset_mixer:
  yhavinga/mc4_nl_cleaned: 1.0
dataset_splits:
- train
dataset_configs:
- tiny
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: False
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-cpt-dutch
hub_strategy: every_save
learning_rate: 2.0e-04
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/gpt2-cpt-dutch
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 16
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
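On the 4-GPU launch shown in the recipe README, this config yields the following effective batch size (a quick sanity check, not part of the recipe itself):

```python
# Values from recipes/gpt2-nl/cpt/config_full.yaml and the README's launch command
per_device_train_batch_size = 16
num_processes = 4                 # GPUs, from `accelerate launch --num_processes 4`
gradient_accumulation_steps = 1

# Sequences (of up to max_seq_length=1024 tokens) consumed per optimizer step
effective_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
print(effective_batch_size)  # 64
```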
44 changes: 44 additions & 0 deletions recipes/gpt2-nl/dpo/config_full.yaml
@@ -0,0 +1,44 @@
# Model arguments
model_name_or_path: BramVanroy/gpt2-sft-dutch
model_revision: main
torch_dtype: bfloat16

# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  BramVanroy/ultra_feedback_dutch: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12

# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-dpo-dutch
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 1
optim: adamw_torch
output_dir: data/gpt2-dpo-dutch
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
push_to_hub: true
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
report_to:
- wandb
45 changes: 45 additions & 0 deletions recipes/gpt2-nl/sft/config_full.yaml
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: BramVanroy/gpt2-cpt-dutch
model_revision: main
torch_dtype: bfloat16

# Data training arguments
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  BramVanroy/ultrachat_200k_dutch: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-sft-dutch
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/gpt2-sft-dutch
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
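The `chat_template` in this config is a Jinja template that wraps each turn in `<|user|>`, `<|system|>`, or `<|assistant|>` markers followed by the EOS token. A minimal sketch of how it renders a conversation (assuming `transformers>=4.34`, which introduced `apply_chat_template`; the template string is copied from the YAML above):

```python
from transformers import AutoTokenizer

# Template copied verbatim from recipes/gpt2-nl/sft/config_full.yaml
CHAT_TEMPLATE = (
    "{% for message in messages %}\n{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("BramVanroy/gpt2-cpt-dutch")
tokenizer.chat_template = CHAT_TEMPLATE

messages = [{"role": "user", "content": "Hallo, hoe gaat het?"}]
# Prints the text the SFT trainer would train on, ending in a generation prompt
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```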
2 changes: 1 addition & 1 deletion scripts/README.md
@@ -25,7 +25,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false
```

- Here `{task}` refers to the type of training you wish to run (SFT, DPO, etc), while `{model_name}` refers to the choice of a recipe in the `recipes` directory. For example, to replicate Zephyr-7B-β you can run:
+ Here `{task}` refers to the type of training you wish to run. Currently, the following tasks are supported: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). Note that `cpt` is only present in the `gpt2-nl` example recipe. `{model_name}` refers to the choice of a recipe in the `recipes` directory. For example, to replicate Zephyr-7B-β you can run:

```shell
# Step 1 - train SFT policy