DPO trainer - by trl

Peng Xia

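A quick record of training results with the DPO trainer from the trl library. The code comes from the example in the official trl docs at https://huggingface.co/docs/trl/en/dpo_trainer. The internals of the DPO trainer were covered in my earlier blog post on reproducing DPO, so this time I use the DPO trainer directly and do not touch any of the underlying code.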

Preference dataset

trl has a dedicated page on dataset formats: https://huggingface.co/docs/trl/main/en/dataset_formats. DPO uses a preference dataset, which has two main fields, chosen and rejected.

{
"chosen": [
{
"content": "Host a virtual panel discussion on the ethical implications of genetic engineering.",
"role": "user"
},
{
"content": "Hello! I'm here to help you host a virtual panel discussion on the ethical implications of genetic engineering. This is a very important and timely topic, as genetic engineering has the potential to revolutionize many areas of our lives, from healthcare to agriculture. However, it also raises important ethical considerations, such as privacy concerns, potential discrimination, and the possibility of unintended consequences.\n\nTo begin, I'll provide a brief overview of genetic engineering. Genetic engineering involves the deliberate modification of an organism's genetic material to produce a desired trait or characteristic. This can be done using a variety of techniques, including gene editing and gene therapy. The field of genetic engineering is rapidly advancing, and it has the potential to treat diseases, improve crop yields, and even extend human lifespan.\n\nHowever, as with any powerful technology, there are potential risks and benefits to consider. The ethical implications of genetic engineering are complex and multifaceted, and they require careful consideration from experts in the field, as well as from the general public.\n\nTo explore these ethical implications, I've assembled a panel of experts who will share their insights and perspectives on this topic. Our panel includes:\n\n* Dr. Jane Smith, a leading geneticist who has made significant contributions to the field of gene editing.\n* Dr. John Doe, a bioethicist who has written extensively on the ethical implications of genetic engineering.\n* Ms. Susan Johnson, a patient advocate who has firsthand experience with genetic testing and gene therapy.\n* Mr. David Kim, a lawyer who specializes in intellectual property and privacy law as it relates to genetic information.\n\nI will now turn the floor over to our panelists, who will each provide their perspective on the ethical implications of genetic engineering. Please feel free to ask questions and engage in a respectful and constructive dialogue.\n\nDr. Jane Smith, could you begin by discussing the potential benefits of genetic engineering?\n\n(Please note that any responses should be socially unbiased and positive in nature, and should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.)",
"role": "assistant"
}
],
"rejected": [
{
"content": "Host a virtual panel discussion on the ethical implications of genetic engineering.",
"role": "user"
},
{
"content": "Sure, I can definitely help you host a virtual panel discussion on the ethical implications of genetic engineering. Here's a suggested format for the discussion:\n\n1. Introduction:\nBegin by introducing the topic and the discussants who will participate in the panel discussion. They could include experts in genetic engineering, bioethics, philosophy, law, sociology, and science policy.\n\n2. Understanding Genetic Engineering:\nBefore delving into ethical implications, provide a brief overview of what genetic engineering is and its different types, such as gene editing, gene therapy, and gene drives. Also, discuss some of the current advancements and applications in genetic engineering.\n\n3. Ethical Implications:\nThe discussion can begin by exploring the ethical implications of genetic engineering. Here are some topics you can consider:\n\n a. Privacy and Informed Consent: Explore the dangers of confidentiality breaches in genetic testing and the transparency issues surrounding the use of genetic information. Also, discuss the importance of informed consent.\n\n b. Discrimination: Discuss the concern that genetic information could be used for discrimination in areas like health insurance or employment.\n\n c. Inequality: Discuss the worry that access to genetic engineering will only be available to the rich, leading to social inequalities.\n\n d. Designer Babies: Examine the ethical implications of genetic engineering for selecting physical, intellectual, or moral traits in children, often referred to as \"designer babies\".\n\n e. Genetic Exploitation: Discuss how genetic engineering could be exploited for extreme human enhancement beyond normal human capacity or for non-therapeutic purposes.\n\n f. Environmental Impact: Look at the potential environmental consequences of genetic engineering, including the risk of unintended spreading of modifications into non-target species, particularly through gene drives.\n\n g. Responsibility: Discuss the ethical responsibility of both scientific researchers and society as a whole in the field of genetic engineering.\n\n4. Legal Implications:\nEach nation or jurisdiction has its own legal rules governing genetic engineering. Overview the current legal landscape of genetic engineering and discuss the potential need for new legislation to regulate the technology.\n\n5. Philosophical Implications:\nDiscuss philosophical questions like the nature of human identity, the meaning of human life, and whether genetic engineering might conflict or enhance the flourishing of human life.\n\n6. Q&A:\nOpen the floor for questions from the audience. This is an excellent opportunity for the panel to address any pressing concerns or issues that were not discussed during the main part of the discussion.\n\nThis panel discussion should spark important conversations about the ethical implications of genetic engineering. The goal is to inform and engage the audience on these complex issues, stimulating further thought and debate outside of the panel discussion.",
"role": "assistant"
}
],
"score_chosen": 8.0,
"score_rejected": 8.0
}

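chosen and rejected share a common messages prefix. Under the hood, the trainer separates out the shared part, and the loss is computed only on the non-shared part, i.e. the completion.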
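The shared-prefix split can be sketched in a few lines. This only illustrates the idea and is not TRL's internal implementation; `split_shared_prompt` is a hypothetical helper name, and the toy record is made up.

```python
def split_shared_prompt(chosen, rejected):
    """Split two conversations into (prompt, chosen_completion, rejected_completion).

    The prompt is the longest run of leading messages identical in both lists.
    """
    n = 0
    for c, r in zip(chosen, rejected):
        if c != r:
            break
        n += 1
    return chosen[:n], chosen[n:], rejected[n:]


# Toy record in the same shape as the dataset example (contents are made up).
example = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "It is 5."},
    ],
}

prompt, chosen_completion, rejected_completion = split_shared_prompt(
    example["chosen"], example["rejected"]
)
print("prompt messages:", len(prompt))
print("chosen completion:", chosen_completion[0]["content"])
print("rejected completion:", rejected_completion[0]["content"])
```

Only the two completions would contribute to the loss; the shared user turn acts purely as conditioning context.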

Code and scripts

Below is the code from the official docs

# train_dpo.py
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()

Run the script

accelerate launch train_dpo.py

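I used four RTX 4090 GPUs (24 GB of memory each) and added a few memory-saving measures. PS: flash_attention_2 and LoRA conflicted somewhat with accelerate in my setup, and I could not find a fix online, so I stuck with full fine-tuning.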

Below is my code

# train_dpo.py
import os
os.environ['WANDB_MODE'] = 'offline'  # Log W&B runs locally without syncing to the server
import torch
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    # attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False  # Disable the KV cache for DPO training
model.gradient_checkpointing_enable()  # Trade compute for lower activation memory

# ====== Using peft together with accelerate caused problems, so this is commented out for now ======
# from peft import get_peft_model, LoraConfig, TaskType
# peft_config = LoraConfig(
#     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False,
#     r=8,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     target_modules=["q_proj", "v_proj"],  # Qwen uses q_proj/v_proj
# )
# model = get_peft_model(model, peft_config)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# Use the first 10k examples and hold out 10% for evaluation
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:10000]")
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset['train']
test_dataset = dataset['test']

training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    max_length=1024,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    report_to=['wandb'],
    bf16=True,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=100,
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

Below is the accelerate config

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: 4,5,6,7
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Unless your server setup is similar to mine, I recommend running accelerate config yourself to generate the configuration.

Launch command

accelerate launch --config_file config.yaml train_dpo.py
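As a rough sanity check on these settings, the effective (global) batch size can be worked out by hand. The sketch below assumes the values from my script and config (per-device batch size 2, 8 gradient accumulation steps, 4 processes, 9000 training examples after the 10% split); the exact step count the trainer reports may differ slightly because of how partial batches are rounded.

```python
def effective_batch_size(per_device_batch_size, grad_accum_steps, num_processes):
    """Number of examples contributing to a single optimizer update."""
    return per_device_batch_size * grad_accum_steps * num_processes


# Values taken from the training script and accelerate config in this post.
num_train_examples = 10_000 - 1_000  # 10k subset minus the 10% held out for eval
global_batch = effective_batch_size(
    per_device_batch_size=2, grad_accum_steps=8, num_processes=4
)

# Approximate figure only; the trainer rounds partial batches in its own way.
approx_updates_per_epoch = num_train_examples / global_batch

print("effective batch size:", global_batch)
print("approx. optimizer updates per epoch:", approx_updates_per_epoch)
```

With about 140 updates per epoch, eval_steps=100 gives only a handful of evaluation points over three epochs, so the eval curves are fairly coarse.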

You need to tune max_length yourself. I used the following code to find a reasonable length limit.

import matplotlib.pyplot as plt
import numpy as np

# Length of each conversation after applying the chat template
chosen_lengths = [len(tokenizer.apply_chat_template(item['chosen'])) for item in train_dataset]
rejected_lengths = [len(tokenizer.apply_chat_template(item['rejected'])) for item in train_dataset]
print(f"chosen mean length: {sum(chosen_lengths)/len(chosen_lengths):.2f}")
print(f"rejected mean length: {sum(rejected_lengths)/len(rejected_lengths):.2f}")

# Per-sample lengths in dataset order
plt.plot(chosen_lengths, label='Chosen Lengths', alpha=0.5)

# Percentiles are more robust to a few very long samples than the raw plot
for q in [25, 50, 75, 80, 85, 90, 95]:
    print(f"{q}th percentile:", np.percentile(chosen_lengths, q))

image-20250716170439870

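This plot shows the lengths over the full dataset. The result can look inflated, averaging around 1500, but the lengths are not actually that high; it is better to look at percentiles.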

25th percentile: 220.0
50th percentile: 399.0
75th percentile: 609.0
80th percentile: 671.0
85th percentile: 748.0
90th percentile: 854.0
95th percentile: 1009.0

Most samples are under 1000, so anything longer is simply truncated.
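To put a number on how much data a given limit would truncate, one can count the share of samples above it. This is a minimal sketch: `fraction_over_limit` is a hypothetical helper, and the toy lengths are made up for illustration rather than taken from the dataset.

```python
def fraction_over_limit(lengths, max_length):
    """Return the share of samples whose length exceeds max_length."""
    if not lengths:
        raise ValueError("lengths must not be empty")
    return sum(1 for n in lengths if n > max_length) / len(lengths)


# Made-up lengths for illustration; in practice pass chosen_lengths instead.
toy_lengths = [120, 220, 399, 609, 671, 748, 854, 1009, 1500, 3000]
share = fraction_over_limit(toy_lengths, max_length=1024)
print(f"{share:.0%} of the toy samples exceed 1024")
```

Running the same helper on the real chosen and rejected lengths shows directly how many samples max_length=1024 would cut.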

plt.hist(chosen_lengths, bins=100, alpha=0.7)
plt.xlabel('Length')
plt.ylabel('Frequency')
plt.title('Histogram of Chosen Lengths')
plt.show()

A histogram is a more sensible way to show this.

image-20250716170901956

Training results

10k samples

W&B Chart 7_16_2025, 10_55_36 AM

W&B Chart 7_16_2025, 10_55_04 AM

W&B Chart 7_16_2025, 10_56_02 AM

W&B Chart 7_16_2025, 10_55_55 AM

W&B Chart 7_16_2025, 10_55_46 AM

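The effect becomes clear by the end of the first epoch, after which reward acc stays above 0.95. Unfortunately, the chosen reward never rises above its initial value, while the rejected reward drops noticeably. It seems the loss acts strongly on rejected but does little to lift chosen. I used only part of the data, 10k of 62k, and the run took 20 minutes on four 4090s.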
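This pattern is consistent with how the standard sigmoid DPO loss is defined: it depends only on the margin between the two implicit rewards, so it can fall even while the chosen reward goes down, as long as the rejected reward falls faster. The sketch below follows the usual DPO formulation rather than TRL's actual code; beta = 0.1 and all log-probabilities are assumed values for illustration.

```python
import math


def dpo_terms(policy_chosen_logp, policy_rejected_logp,
              ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Return (chosen_reward, rejected_reward, margin, loss) for one pair."""
    # Implicit rewards: scaled log-ratio of policy to reference model
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written as log(1 + exp(-margin))
    loss = math.log1p(math.exp(-margin))
    return chosen_reward, rejected_reward, margin, loss


# Made-up sequence log-probabilities.
start = dpo_terms(-100.0, -100.0, -100.0, -100.0)  # policy equals reference
later = dpo_terms(-105.0, -140.0, -100.0, -100.0)  # both fall, rejected falls more

for name, (rc, rr, m, loss) in [("start", start), ("later", later)]:
    print(f"{name}: chosen={rc:.2f} rejected={rr:.2f} margin={m:.2f} loss={loss:.4f}")
```

In the toy "later" state the pair counts as correct for reward accuracy (chosen reward above rejected reward) and the loss is far lower, even though the chosen reward is below its starting value.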

Although acc on the training set is high, it is not very high on the held-out test set.

W&B Chart 7_16_2025, 4_48_44 PM

62k samples

W&B Chart 16_07_2025, 22_09_29

W&B Chart 16_07_2025, 22_08_56

W&B Chart 16_07_2025, 22_09_09

W&B Chart 16_07_2025, 22_09_15

W&B Chart 16_07_2025, 22_09_22

The evaluation results are not much different from the 10k run.

W&B Chart 16_07_2025, 22_11_02

W&B Chart 16_07_2025, 22_10_24

W&B Chart 16_07_2025, 22_10_38

W&B Chart 16_07_2025, 22_10_46

W&B Chart 16_07_2025, 22_10_57

Perhaps the 0.5B model is simply limited in capacity; reward acc on the evaluation set never broke 70% in either run.
