一个 TRL 实现 GRPO 的实例
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
import re
# 1. 加载数据集: 使用 load_dataset 加载 GSM8K 数据集(数学问题),仅取 5% 数据以简化演示。用于 GRPO 的训练
dataset = load_dataset("openai/gsm8k", split="train[:5%]")
# 2. 定义奖励函数: 是否与一个指定的 pattern 匹配,
def reward_function(completions, answers, **kwargs):
"""奖励函数:比较生成答案与正确答案,奖励正确格式和答案"""
rewards = []
pattern = r"<answer>(.*?)</answer>" # 假设答案在 <answer> 标签中
for completion, correct_answer in zip(completions, answers):
try:
match = re.search(pattern, completion)
if match:
pred_answer = match.group(1).strip()
# 简单奖励:正确答案得 1.0,错误得 0.0
reward = 1.0 if pred_answer == str(correct_answer) else 0.0
else:
reward = 0.0 # 格式错误
except:
reward = 0.0 # 解析失败
rewards.append(reward)
return rewards
# 3. 数据预处理,将原始诗句格式转化为模型期望的格式
def format_dataset(example):
"""格式化数据集为 GRPO 所需的 prompt 格式"""
system_prompt = (
"Solve the math problem step-by-step, providing reasoning in <think> tags "
"and the final answer in <answer> tags."
)
return {
"prompt": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": example["question"]}
],
"answer": example["answer"] # 用于奖励函数
}
# 应用格式化
dataset = dataset.map(format_dataset)
# 4. 配置 GRPO 训练参数
training_args = GRPOConfig(
output_dir="./grpo_output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
learning_rate=1e-5,
logging_steps=10,
num_generation=4, # 每条 prompt 生成 4 个候选答案,即 Group的大小
max_steps=50, # 限制步数以便演示
use_vllm=False, # 可设为 True 加速(需安装 vLLM)
)
# 5. 初始化 GRPOTrainer
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct", # 使用小型模型以便演示
args=training_args,
train_dataset=dataset,
reward_funcs=reward_function,
)
# 6. 开始训练
trainer.train()
KAQ: Reward 函数中的 completions
是 list of completions to evaluate。
completions 是模型为给定输入(prompt)生成的输出列表, 模型的输出列表。在 GRPO 中,每个 prompt 会生成多个候选输出(由 GRPOConfig 的 num_generation 参数控制,例如 4,则每个 prompt 生成 4 个 completion)。
completions 是一个字符串列表,每个字符串是模型对数学问题生成的回答,包含推理(<think> 标签)和答案(<answer> 标签)。
例:对于 prompt “What is 2 + 3?”,completions 可以为:
[
"<think>2 + 3 = 5</think><answer>5</answer>",
"<think>Add 2 and 3 to get 5</think><answer>5</answer>",
"<think>2 + 3 = 6</think><answer>6</answer>",
"<think>Sum is 5</think><answer>5</answer>"
]
强调,这是 模型生成的输出,而非数据集中的样本。
KAQ: Reward 函数中的 answers
list of answers to the problems from the dataset。
answer 是数据集中标注的正解(ground truth),用于评估 completions 的质量。它来自 GSM8K 数据集的 answer 字段,表示数学问题的正确答案。
比如,对于问题“What is 2 + 3?”,answer 可能是 “5”。
KAQ: Reward 中的 completion 和 answer 是如何给隐式出的
completions 和 answer 是由 GRPOTrainer 隐式传入 reward 函数的。GRPOTrainer 在训练时会:
从 train_dataset 获取 prompt(代码中为
example["prompt"])。用模型生成多个候选输出 completions,数量由
num_generation控制。从数据集获取对应的正解(answer,代码中为
example["answer"])。将 completions 和 answer 作为参数自动传递给 reward 函数。
reward_function(completions, answers, **kwargs) 是 GRPOTrainer 要求的标准接口。
KAQ: 为什么需要样本预处理
format_dataset 函数用于预处理数据集中的每个样本,将其转换为适合 GRPOTrainer 和模型(code中是 Qwen/Qwen2-0.5B-Instruct)处理的格式。具体讲,它将原始 GSM8K 数据集的样本(包含数学问题和答案)转换为包含 prompt 和 answer 字段的结构化格式。
输入原始数据集:GSM8K 数据集中的每个样本是一个字典:
{
"question": "What is 2 + 3?",
"answer": "5"
}
包含:
question:数学问题(字符串,“What is 2 + 3?")。answer:正确答案(字符串,“5”)。
输出模型适配的数据格式。一个字典,包含:
prompt:一个列表,包含系统指令和用户问题的对话格式,如[{"role": "system", "content": ...}, {"role": "user", "content": ...}]answer:原始的正确答案。
本质上是你选择的模型需要这样的格式,数据集不符合要求,故做个转换。
load_dataset 中的数据集不必与后续使用的 model 匹配上,只需要按照模型期望给预处理即可。
KAQ: KL divergence 体现在哪里
KL 散度隐式存在于 GRPOTrainer 的优化过程中,作为正则化项,防止模型生成分布过分偏离初始策略。公式:
Loss = Reward_Loss(completions, rewards) - β * KL_Div(new_policy, old_policy)
β 是 KL 散度的权重,new_policy 是当前模型输出分布,old_policy 是初始或参考分布。
KAQ: 训练结束后的模型和训练的模型区别
更新后的权重使 model 在数学问题解答任务(GSM8K)上表现更好,例如生成更准确的 <answer> 和更清晰的 <think> 推理。
输出分布也不同:
- 训练前 概率分布(old_logits)反映初始策略(old_policy),可能偏离正确答案
- 训练后 概率分布(new_logits)倾向于高奖励输出。
KAQ: 这个实例中训练解数学题,与 RLHF 要解决的问题没啥关系吧?
这里的 code 场景是提高模型在数学问题解决上的准确性和结构化输出能力,属于任务特定优化。而 RLHF 目标是 强调对话的流畅性、礼貌性、伦理性和上下文相关性,典型场景是 chat。
KAQ: 当我有多个不同的 reward function 时如何做
GRPOTrainer 的 reward_funcs 参数接受一个奖励函数的列表,但可以通过以下方式集成多个奖励函数。
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func
]
todo
熟悉 TRL 的实现后,再看下 GRPO 的 Open R1 实现