Training a Reasoning Model with the GRPO Trainer

Peng Xia

A GRPO training script for Qwen models on the GSM8K dataset

This script uses GRPO (Group Relative Policy Optimization) to train a Qwen model on the GSM8K (Grade School Math 8K) dataset. A short sketch of the group-relative advantage idea behind GRPO follows the list below.

  • Dataset: load GSM8K and build prompts, with an optional one-shot example.
  • Reward functions: define an XML tag reward, a format reward, and a correctness reward for training.
  • PEFT setup: use LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.
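
For intuition, here is a minimal sketch (not part of the training script) of the group-relative advantage idea behind GRPO: for each prompt the trainer samples a group of completions, scores them with the reward functions, and normalizes the rewards within the group, so completions that beat their group's average get reinforced. trl's GRPOTrainer handles this internally; the numbers below are made up.

# group_advantage_sketch.py -- illustrative only; GRPOTrainer does this internally.
from statistics import mean, pstdev

# Rewards for a group of G = 4 completions sampled for the same prompt (made-up values).
group_rewards = [3.0, 2.5, 0.5, 0.0]

mu = mean(group_rewards)
sigma = pstdev(group_rewards)

# Group-relative advantage: each reward normalized within its own group.
advantages = [(r - mu) / (sigma + 1e-4) for r in group_rewards]
print(advantages)  # positive for completions above the group mean, negative below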

Environment Setup

Install the dependencies with pip:

pip install \
  torch \
  "transformers>=4.28.1" \
  "peft>=0.3.0" \
  "trl>=0.14.0" \
  "datasets>=2.10.1" \
  math-verify \
  accelerate \
  deepspeed \
  wandb

Environment Variables

Before running the script, set the following environment variables:

Variable     Description                 Default
MODEL_NAME   Hugging Face model name     Qwen/Qwen2.5-1.5B-Instruct
OUTPUT_DIR   Training output directory   outputs/default-GRPO
RUN_NAME     Name of this training run   default-GRPO-gsm8k

Example:

export MODEL_NAME="Qwen/Qwen2.5-1.5B-Instruct"
export OUTPUT_DIR="./outputs/default-GRPO"
export RUN_NAME="default-GRPO-gsm8k"

Running the Script

Run the following in a terminal:

python train_grpo.py

Code Structure

  1. Logging setup
    • Configure logging to track training progress and surface errors.
  2. Dataset preparation
    • Load the GSM8K dataset.
    • Build each prompt from the system prompt, with an optional one-shot example.
  3. Reward functions (their common interface is sketched right after this list)
    • Correctness reward: scores the extracted answer as correct or not.
    • Format reward: rewards outputs that follow the XML format.
    • XML tag reward: rewards correctly placed XML tags and penalizes missing or extra content.
  4. Model loading
    • Load the pretrained Qwen model and its tokenizer.
    • (Optional) use LoRA for parameter-efficient fine-tuning.
  5. Training setup
    • Configure the GRPO training arguments.
    • Initialize the GRPOTrainer with the model, tokenizer, reward functions, and dataset.
  6. Training loop
    • Run the training loop.
    • Save the trained model and logs to the specified output directory.
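
All reward functions in the script follow the same GRPOTrainer interface: they receive the batch of completions (plus any dataset columns, such as answer, as keyword arguments) and return one float per completion. As a hypothetical example of how an extra reward could be plugged in (it is not used in the script below), a simple length penalty might look like this:

# Hypothetical extra reward function (illustrative only, not part of train_grpo.py).
# Completions are conversational, so each one is a list of messages; the generated
# text lives in completion[0]["content"], as in the script's own reward functions.
def length_penalty_reward_func(completions, **kwargs) -> list[float]:
    """Gives a small penalty to completions longer than ~1000 characters."""
    contents = [completion[0]["content"] for completion in completions]
    return [0.0 if len(c) <= 1000 else -0.25 for c in contents]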

Full Code

# train_grpo.py
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from math_verify import parse, verify
import os
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load and prep dataset
SYSTEM_PROMPT = """
The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>\n reasoning process here \n</think>\n<answer>\n answer here \n</answer>.
"""

XML_COT_FORMAT = """\
<think>
{reasoning}
</think>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Extracts the answer from XML-formatted text."""
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        logger.warning("Failed to extract answer from XML format.")
        return ""

def extract_hash_answer(text: str) -> str | None:
    """Extracts the answer from a hash-formatted string."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Validate environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")
if not MODEL_NAME:
    raise ValueError("MODEL_NAME environment variable is not set.")

OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/default-GRPO")
if not OUTPUT_DIR:
    raise ValueError("OUTPUT_DIR environment variable is not set.")

RUN_NAME = os.getenv("RUN_NAME", "default-GRPO-gsm8k")
if not RUN_NAME:
    raise ValueError("RUN_NAME environment variable is not set.")

# Configurable one-shot prompting
def get_gsm8k_questions(split="train", use_one_shot=False) -> Dataset:
    """Loads and prepares the GSM8K dataset with optional one-shot prompting."""
    try:
        data = load_dataset('openai/gsm8k', 'main')[split]
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    def format_example(x):
        prompt = [{'role': 'system', 'content': SYSTEM_PROMPT}]
        if use_one_shot:
            prompt.extend([
                {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
                {'role': 'assistant', 'content': XML_COT_FORMAT.format(
                    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                    answer="7 is the largest single-digit prime number."
                )}
            ])
        prompt.append({'role': 'user', 'content': x['question']})
        return {'prompt': prompt, 'answer': extract_hash_answer(x['answer'])}

    return data.map(format_example)

dataset = get_gsm8k_questions(use_one_shot=False)

# Reward functions

def math_verify_answer(answer, golden_answer, **kwargs):
    """Verifies the answer against the gold answer using math_verify."""
    gold_parsed = parse(golden_answer)
    answer_parsed = parse(answer)
    try:
        return verify(gold_parsed, answer_parsed)
    except Exception as e:
        logger.warning(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
        return False

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Calculates reward based on correctness of the response."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    logger.info(f"Question:\n{q}\nAnswer:\n{answer[0]}\nResponse:\n{responses[0]}\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if math_verify_answer(r, a) else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Calculates reward if the extracted response is a digit."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def format_reward_func(completions, strict=False, **kwargs) -> list[float]:
    """Calculates reward based on XML formatting."""
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$" if strict else r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL | re.MULTILINE) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Calculates reward based on XML tag counts."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

def count_xml(text) -> float:
    """Scores 0.125 per correctly placed XML tag (at most 0.5 in total)."""
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
    if text.count("\n</answer>") == 1:
        count += 0.125
    return count

# Model setup
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        # attn_implementation="flash_attention_2",
        # device_map="auto"
    ).to("cuda")
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# PEFT config (optional)
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# Training config
training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,        # completions sampled per prompt (the GRPO "group" size)
    max_prompt_length=256,
    max_completion_length=512,
    num_train_epochs=1,
    save_steps=100,
    save_total_limit=2,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

# Trainer setup
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        format_reward_func,
        # int_reward_func,  # disabled: the user should get a full answer, not just a bare integer
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config  # uncomment to fine-tune with LoRA (PEFT)
)

# Train the model and save the final weights
try:
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
except Exception as e:
    logger.error(f"Training failed: {e}")
    raise
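
After training, a quick way to try the model is to load the saved weights and chat with the same system prompt. This is a minimal sketch, not part of the original script: it assumes the final model was saved to OUTPUT_DIR (as done by trainer.save_model above) and uses an abridged copy of the system prompt; the sample question is taken from the GSM8K training set.

# inference_example.py -- minimal sketch for trying the trained model (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "outputs/default-GRPO"  # set this to your OUTPUT_DIR

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16).to("cuda")

# Abridged copy of the SYSTEM_PROMPT used in train_grpo.py.
system_prompt = (
    "The user asks a question, and the Assistant solves it. The assistant first thinks about the "
    "reasoning process in the mind and then provides the user with the answer. The reasoning process "
    "and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively."
)

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold "
                                "half as many clips in May. How many clips did Natalia sell altogether "
                                "in April and May?"},  # a GSM8K training question
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=512)

# Print only the newly generated tokens (the <think>/<answer> block).
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))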

Here is my accelerate config (default_config.yaml):

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

And here is my launch script:

export CUDA_VISIBLE_DEVICES=1,2,3,5  # GPUs to use
export WANDB_MODE=offline            # log to Weights & Biases in offline mode (no upload)
export MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
export RUN_NAME=default-GRPO-fixed_format_reward
export OUTPUT_DIR=outputs/${RUN_NAME}

accelerate launch --config_file default_config.yaml train_grpo.py

Training Results

[Four W&B screenshots of the reward curves from this training run]

The GRPO-trainer run performs a bit better. The xml reward is capped at 0.5, the format reward at 0.5, and the correctness reward at 2.0, so the total reward is capped at 3.0; in this run it fluctuates between roughly 2 and 2.5, averaging about 2.25.
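
As a sanity check of that ceiling, the snippet below replays the same tag-counting and regex logic from the reward functions on an ideally formatted, correct completion (the answer value is just a placeholder):

# reward_ceiling_check.py -- recomputes the per-sample reward ceiling implied by the reward functions.
import re

ideal = "<think>\n72 / 2 = 36\n</think>\n<answer>\n36\n</answer>\n"

# xmlcount_reward_func: 0.125 per correctly placed tag -> at most 0.5
xml_reward = 0.125 * sum([
    ideal.count("<think>\n") == 1,
    ideal.count("\n</think>\n") == 1,
    ideal.count("\n<answer>\n") == 1,
    ideal.count("\n</answer>") == 1,
])

# format_reward_func (non-strict pattern): 0.5 if the <think>/<answer> structure matches
format_reward = 0.5 if re.match(r"<think>.*?</think>\s*<answer>.*?</answer>", ideal, re.DOTALL) else 0.0

# correctness_reward_func: 2.0 when math_verify accepts the extracted answer
correctness_reward = 2.0

print(xml_reward, format_reward, correctness_reward, xml_reward + format_reward + correctness_reward)
# 0.5 0.5 2.0 3.0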
