import os
import re
import logging

import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from math_verify import parse, verify

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
The user asks a question, and the Assistant solves it. The assistant first thinks about the
reasoning process in the mind and then provides the user with the answer. The reasoning
process and answer are enclosed within <think> </think> and <answer> </answer> tags,
respectively, i.e., <think>\n reasoning process here \n</think>\n<answer>\n answer here \n</answer>.
"""

XML_COT_FORMAT = """\
<think>
{reasoning}
</think>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Extracts the answer from XML-formatted text."""
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        logger.warning("Failed to extract answer from XML format.")
        return ""

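# Illustrative example (assumes the well-formed completion layout used below):
# extract_xml_answer("<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>") -> "4"
# If the tags are missing, the splits leave the input intact, so the stripped original text
# is returned rather than an exception being raised.
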
def extract_hash_answer(text: str) -> str | None:
    """Extracts the answer from a hash-formatted string."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

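# Illustrative example: GSM8K reference solutions end with "#### <final answer>", so
# extract_hash_answer("Natalia sold 72 clips altogether. #### 72") -> "72"
# extract_hash_answer("a solution without the marker") -> None
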
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")
if not MODEL_NAME:
    raise ValueError("MODEL_NAME must not be set to an empty value.")

OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/default-GRPO")
if not OUTPUT_DIR:
    raise ValueError("OUTPUT_DIR must not be set to an empty value.")

RUN_NAME = os.getenv("RUN_NAME", "default-GRPO-gsm8k")
if not RUN_NAME:
    raise ValueError("RUN_NAME must not be set to an empty value.")

def get_gsm8k_questions(split="train", use_one_shot=False) -> Dataset:
    """Loads and prepares the GSM8K dataset with optional one-shot prompting."""
    try:
        data = load_dataset('openai/gsm8k', 'main')[split]
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    def format_example(x):
        prompt = [{'role': 'system', 'content': SYSTEM_PROMPT}]
        if use_one_shot:
            prompt.extend([
                {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
                {'role': 'assistant', 'content': XML_COT_FORMAT.format(
                    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                    answer="7 is the largest single-digit prime number."
                )}
            ])
        prompt.append({'role': 'user', 'content': x['question']})
        return {'prompt': prompt, 'answer': extract_hash_answer(x['answer'])}

    return data.map(format_example)

dataset = get_gsm8k_questions(use_one_shot=False)
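# Each mapped record now pairs a chat-style prompt with a plain-string gold answer,
# e.g. (illustrative, based on the first GSM8K training example):
# {'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},
#             {'role': 'user', 'content': 'Natalia sold clips to 48 of her friends in April, ...'}],
#  'answer': '72'}
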
def math_verify_answer(answer, golden_answer, **kwargs):
    """Verifies the answer against the gold answer using math_verify."""
    gold_parsed = parse(golden_answer)
    answer_parsed = parse(answer)
    try:
        return verify(gold_parsed, answer_parsed)
    except Exception as e:
        logger.warning(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
        return False

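# Illustrative checks (assuming math_verify's parse/verify compare mathematical
# equivalence rather than raw strings):
# math_verify_answer("72", "72")   -> True
# math_verify_answer("72.0", "72") -> expected True, since the parsed values are equal
# math_verify_answer("71", "72")   -> False
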
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Rewards 2.0 when the extracted answer verifies against the gold answer, else 0.0."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    logger.info(
        f"Question:\n{q}\nAnswer:\n{answer[0]}\nResponse:\n{responses[0]}\nExtracted:\n{extracted_responses[0]}"
    )
    return [2.0 if math_verify_answer(r, a) else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Calculates reward if the extracted response is a digit."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def format_reward_func(completions, strict=False, **kwargs) -> list[float]:
    """Rewards completions that follow the <think>/<answer> format; strict mode requires the exact newline layout."""
    pattern = (
        r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
        if strict
        else r"<think>.*?</think>\s*<answer>.*?</answer>"
    )
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL | re.MULTILINE) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

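# Illustrative example: with the default soft pattern, a completion such as
#   "<think>2 + 2 = 4</think>\n<answer>4</answer>"
# earns 0.5, while strict=True additionally requires each tag on its own line,
# matching the XML_COT_FORMAT layout.
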
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Calculates reward based on XML tag counts."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

def count_xml(text) -> float:
    """Gives partial credit (0.125 each) for every correctly placed XML tag."""
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
    if text.count("\n</answer>") == 1:
        count += 0.125
    return count

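# Illustrative example: a completion following XML_COT_FORMAT, e.g.
#   "<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>"
# contains each tag exactly once in the expected position and scores 4 * 0.125 = 0.5.
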
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    model.config.use_cache = False  # caching is incompatible with gradient checkpointing
    model.gradient_checkpointing_enable()
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=512,
    num_train_epochs=1,
    save_steps=100,
    save_total_limit=2,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

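# Note (hedged): num_generations=16 means each prompt is sampled 16 times and those 16
# completions form one GRPO group whose rewards are normalized against each other.
# Depending on the TRL version, the effective batch size (per_device_train_batch_size *
# number of processes) may need to be divisible by num_generations, so this configuration
# may require multiple GPUs or a larger per-device batch.
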
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        format_reward_func,
        correctness_reward_func,
        # int_reward_func is defined above and can be appended here as an extra shaping reward
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # wire in the LoRA adapters configured above
)

try:
    trainer.train()
except Exception as e:
    logger.error(f"Training failed: {e}")
    raise