
# Bridging and Modeling Correlations in Pairwise Data for Direct Preference Optimization (ICLR 2025)

Direct preference optimization (DPO), a widely adopted offline preference optimization algorithm, aims to align large language models (LLMs) with human-desired behaviors using pairwise preference data. However, the winning response and the losing response within pairwise data are generated in isolation, leading to weak correlations between them and suboptimal alignment performance. To address this issue, we propose an effective framework named BMC for bridging and modeling correlations in pairwise data.

- Firstly, we increase the consistency and informativeness of the pairwise preference signals through targeted modifications, synthesizing a pseudo-winning response by improving the losing response with the winning response as a reference.
- Secondly, we identify that DPO alone is insufficient to model these correlations and capture nuanced variations. Therefore, we propose learning token-level correlations by dynamically leveraging the policy model's confidence during training (the standard sequence-level DPO objective that this refines is recalled below for reference).
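
For context, standard DPO compares only the sequence-level log-likelihood ratios of the winning response $y_w$ and the losing response $y_l$ against a frozen reference policy; the token-level weighting above refines this comparison. The objective below is the usual DPO formulation, not the BMC loss itself:

$$
\mathcal{L}_{\mathrm{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

where $\pi_\theta$ is the policy model, $\pi_{\mathrm{ref}}$ the reference model, $\beta$ a scaling hyperparameter, and $\sigma$ the sigmoid function.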



## 🔍 Table of Contents

- ⚙️ Install Requirements
- 💻 Training Scripts
- 💹 Evaluation
- 📝 Citation

## ⚙️ Install Requirements

Our codebase is built upon the alignment-handbook repo. The following steps will guide you through the installation process.

First, create a Python virtual environment using e.g. Conda:

```shell
conda create -n handbook python=3.10 && conda activate handbook
```

Next, install PyTorch v2.2.2. Since this is hardware-dependent, we direct you to the PyTorch Installation Page.
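
For example, on a Linux machine with CUDA 12.1 (an assumed setup; pick the command that matches your hardware from that page), the pinned version can be installed with pip:

```shell
# Example only: installs the CUDA 12.1 build of PyTorch 2.2.2.
# Adjust the index URL (e.g. cu118, cpu) to match your hardware.
pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
```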

You can then install the remaining package dependencies of alignment-handbook as follows:

```shell
git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
python -m pip install .
```

You will also need Flash Attention 2 installed, which can be done by running:

```shell
python -m pip install flash-attn --no-build-isolation
```

Finally, install other required packages:

```shell
pip install -r requirements.txt
```

## 💻 Training Scripts

Before training, we use gpt-4-0125-preview to obtain the pseudo-winning response by making targeted modifications to the losing response. This is done by calling OpenAI's API. The obtained data can be found in the training_data directory.
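
If you want to regenerate or extend this data, a minimal sketch of such an API call is shown below; the prompt wording and the `<instruction>`, `<chosen>`, `<rejected>` placeholders are illustrative assumptions, not the exact prompt used to produce the released data.

```shell
# Illustrative sketch: ask gpt-4-0125-preview to make targeted modifications
# to the losing response, using the winning response as a reference.
# The prompt text below is an assumption, not the paper's exact prompt.
curl https://api.openai.com/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -d '{
    "model": "gpt-4-0125-preview",
    "messages": [
      {"role": "system", "content": "Improve the losing response with targeted modifications, using the winning response as a reference. Change as little as possible."},
      {"role": "user", "content": "Instruction: <instruction>\nWinning response: <chosen>\nLosing response: <rejected>"}
    ]
  }'
```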

We provide training config files for three setups: question answering, mathematical reasoning, and instruction following. The training configs are set for 4×A800 GPUs. You may need to adjust `num_processes` and `per_device_train_batch_size` based on your computation environment.
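
For instance, to run the question-answering SFT recipe on 2 GPUs with a smaller per-device batch size, you might override `num_processes` on the accelerate command line; the trailing `--per_device_train_batch_size` override assumes alignment-handbook-style argument parsing (an assumption; otherwise edit the YAML files directly):

```shell
# Hypothetical example: 2 GPUs instead of 4 and a reduced per-device batch
# size. If the trailing override is not picked up in your environment,
# change per_device_train_batch_size in the training YAML instead.
ACCELERATE_LOG_LEVEL=info accelerate launch \
  --config_file accelerate_configs/deepspeed_zero3.yaml --num_processes=2 \
  scripts/run_sft.py training_configs/question_answering/llama-2-7b-base-sft.yaml \
  --per_device_train_batch_size=2
```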

### Question Answering

- Llama-2-7B-Base (SFT):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py training_configs/question_answering/llama-2-7b-base-sft.yaml
  ```

- Llama-2-7B-Base (DPO-BMC):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo_bmc.py training_configs/question_answering/llama-2-7b-base-dpo-bmc.yaml
  ```

### Mathematical Reasoning

- Llama-2-7B-Base (SFT):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py training_configs/math/llama-2-7b-base-sft.yaml
  ```

- Llama-2-7B-Base (DPO-BMC):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo_bmc.py training_configs/math/llama-2-7b-base-dpo-bmc.yaml
  ```

### Instruction Following

- Llama-3-8B-Base (DPO-BMC):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo_bmc.py training_configs/instruction_following/llama-3-8b-base-dpo-bmc.yaml
  ```

- Mistral-7B-Base (DPO-BMC):

  ```shell
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo_bmc.py training_configs/instruction_following/mistral-7b-base-dpo-bmc.yaml
  ```

## 💹 Evaluation

We evaluate on three downstream scenarios for a comprehensive comparison: question answering, mathematical reasoning, and instruction following.

### Question Answering

```shell
cd eval/QA
bash run_eval.sh
```

### Mathematical Reasoning

```shell
cd eval/math
bash run_eval.sh
```

### Instruction Following

We follow the official implementation for evaluation on AlpacaEval 2 and Arena-Hard, as follows:

- AlpacaEval 2: Please refer to the AlpacaEval repo for evaluation. We provide generation configurations for our models in the eval/alpacaeval2 directory (a sketch of the evaluation command follows this list).

- Arena-Hard: Please refer to the Arena-Hard-Auto repo for evaluation. We provide generation configurations for our models in the eval/arenahard directory.
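
As a rough sketch, once model outputs have been generated with the provided AlpacaEval 2 configuration, the comparison is typically run with the `alpaca_eval` CLI; the output path below is a placeholder, and the AlpacaEval repo remains the authoritative reference:

```shell
# Sketch only: assumes the alpaca_eval package is installed and
# OPENAI_API_KEY is set; replace the placeholder path with your outputs file.
alpaca_eval --model_outputs results/dpo_bmc_alpacaeval2_outputs.json \
  --annotators_config weighted_alpaca_eval_gpt4_turbo
```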

## 📝 Citation

Please cite our paper if you find the repo helpful in your work:

```bibtex
@inproceedings{jiang2025dpobmc,
  title={Bridging and Modeling Correlations in Pairwise Data for Direct Preference Optimization},
  author={Jiang, Yuxin and Huang, Bo and Wang, Yufei and Zeng, Xingshan and Li, Liangyou and Wang, Yasheng and Jiang, Xin and Shang, Lifeng and Tang, Ruiming and Wang, Wei},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=hRwxZmcvW9}
}
```