Adding continued_pretraining task (#131)
* add continued pretraining script
* simplify config; add dataset_config option
* add ds configs in data mixer creator
* use extended sftconfig
* add option to avoid setting chat template
* fix data_configs bug
* add continued pretraining info
* add gpt2-nl recipe for continued pretraining example
* add final newline
* make style
* Update README.md (Co-authored-by: lewtun <[email protected]>)
* Update README.md (Co-authored-by: lewtun <[email protected]>)
* Update recipes/gpt2-nl/README.md (Co-authored-by: lewtun <[email protected]>)
* rename continued pretraining to cpt
* improve README

Co-authored-by: lewtun <[email protected]>
1 parent a9b8a50 · commit 595023f

Showing 12 changed files with 415 additions and 12 deletions.
recipes/gpt2-nl/README.md
@@ -0,0 +1,43 @@
# Language Adaptation through Continued Pretraining

This directory shows a basic example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain).

Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example we'll continue pretraining gpt2 on raw Dutch data, then apply SFT, and finally align the result with DPO. Note that no extensive hyperparameter tuning was done for this example and that the resulting models are not meant for real use; the goal is simply to show how the scripts can be used for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size.

## Continued pretraining

This step further pretrains the original `gpt2` model on plain Dutch text. Note that the script uses the `text` column of the dataset by default, but you can change that by specifying `text_column` in the yaml file or on the command line.
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_cpt.py \
    recipes/gpt2-nl/cpt/config_full.yaml
```
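
If your raw data stores text under a different column, or your GPUs have less memory, you can override those settings on the command line instead of editing the yaml file. This is only a sketch: the exact `--key=value` form is an assumption about how the launcher parses extra arguments, and `content` is a hypothetical column name.

```shell
# Hypothetical overrides; the --key=value syntax is assumed, not confirmed by this recipe.
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_cpt.py \
    recipes/gpt2-nl/cpt/config_full.yaml \
    --text_column=content \
    --per_device_train_batch_size=8
```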

## Supervised finetuning

As other recipes, such as the well-known zephyr-7b-beta, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As the base model we'll use the output of the previous step.
```shell | ||
ACCELERATE_LOG_LEVEL=info accelerate launch \ | ||
--config_file recipes/accelerate_configs/multi_gpu.yaml \ | ||
--num_processes 4 \ | ||
scripts/run_sft.py recipes/gpt2-nl/sft/config_full.yaml | ||
``` | ||
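
The SFT config points at a model on the Hub (`BramVanroy/gpt2-cpt-dutch`). If you instead want to train directly on the local output of your own continued-pretraining run, a sketch like the one below should work, assuming the same command-line override mechanism as above.

```shell
# Assumption: the launcher accepts --key=value overrides after the config path.
# data/gpt2-cpt-dutch is the output_dir of the cpt config in this recipe.
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_sft.py recipes/gpt2-nl/sft/config_full.yaml \
    --model_name_or_path=data/gpt2-cpt-dutch
```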

## Direct preference optimisation

Finally, to align the model better with feedback, we finetune the SFT output with the DPO algorithm. This should improve the model's chat capabilities.
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_dpo.py recipes/gpt2-nl/dpo/config_full.yaml
```
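
As with the SFT step, the DPO config starts from a model on the Hub (`BramVanroy/gpt2-sft-dutch`). To start from your own local SFT output instead, an override along the same assumed lines should work:

```shell
# Assumption: same --key=value override mechanism as above.
# data/gpt2-sft-dutch is the output_dir of the SFT config in this recipe.
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    scripts/run_dpo.py recipes/gpt2-nl/dpo/config_full.yaml \
    --model_name_or_path=data/gpt2-sft-dutch
```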

## Conclusion

With the steps above you can adapt a language model to a new domain, more data, or even a different language, and with SFT and DPO on top you can turn it into a capable chatbot, all within three simple commands. Since the three steps follow a very similar pattern, they lend themselves well to parameterized Slurm jobs (a sketch of one follows below). On top of that, you can override any argument from the yaml files directly on the command line, which keeps the recipes flexible.
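
As a rough illustration of such a parameterized job, the sketch below submits one of the three steps depending on its first argument. The `#SBATCH` resource settings, the script name, and the usage pattern are hypothetical and will need adjusting to your cluster.

```shell
#!/bin/bash
#SBATCH --job-name=gpt2-nl-adapt   # hypothetical job name
#SBATCH --gres=gpu:4               # matches the 4x 3090 setup described above
#SBATCH --time=24:00:00            # illustrative time limit

# Assumed usage: sbatch adapt.slurm cpt|sft|dpo
STEP=$1

ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/multi_gpu.yaml \
    --num_processes 4 \
    "scripts/run_${STEP}.py" \
    "recipes/gpt2-nl/${STEP}/config_full.yaml"
```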
recipes/gpt2-nl/cpt/config_full.yaml
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: gpt2
model_revision: main
torch_dtype: bfloat16

# Data training arguments
dataset_mixer:
  yhavinga/mc4_nl_cleaned: 1.0
dataset_splits:
- train
dataset_configs:
- tiny
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: False
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-cpt-dutch
hub_strategy: every_save
learning_rate: 2.0e-04
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/gpt2-cpt-dutch
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 16
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
recipes/gpt2-nl/dpo/config_full.yaml
@@ -0,0 +1,44 @@
# Model arguments
model_name_or_path: BramVanroy/gpt2-sft-dutch
model_revision: main
torch_dtype: bfloat16

# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  BramVanroy/ultra_feedback_dutch: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12

# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-dpo-dutch
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 1
optim: adamw_torch
output_dir: data/gpt2-dpo-dutch
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
push_to_hub: true
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
report_to:
- wandb
recipes/gpt2-nl/sft/config_full.yaml
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: BramVanroy/gpt2-cpt-dutch
model_revision: main
torch_dtype: bfloat16

# Data training arguments
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  BramVanroy/ultrachat_200k_dutch: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: gpt2-sft-dutch
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/gpt2-sft-dutch
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1