一个 TRL 实现 GRPO 的实例

from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
import re

# 1. 加载数据集: 使用 load_dataset 加载 GSM8K 数据集(数学问题),仅取 5% 数据以简化演示。用于 GRPO 的训练
dataset = load_dataset("openai/gsm8k", split="train[:5%]")  


# 2. 定义奖励函数: 是否与一个指定的 pattern 匹配,
def reward_function(completions, answers, **kwargs):
    """奖励函数:比较生成答案与正确答案,奖励正确格式和答案"""
    rewards = []
    pattern = r"<answer>(.*?)</answer>"  # 假设答案在 <answer> 标签中
    for completion, correct_answer in zip(completions, answers):
        try:
            match = re.search(pattern, completion)
            if match:
                pred_answer = match.group(1).strip()
                # 简单奖励:正确答案得 1.0,错误得 0.0
                reward = 1.0 if pred_answer == str(correct_answer) else 0.0
            else:
                reward = 0.0  # 格式错误
        except:
            reward = 0.0  # 解析失败
        rewards.append(reward)
    return rewards


# 3. 数据预处理,将原始诗句格式转化为模型期望的格式
def format_dataset(example):
    """格式化数据集为 GRPO 所需的 prompt 格式"""
    system_prompt = (
        "Solve the math problem step-by-step, providing reasoning in <think> tags "
        "and the final answer in <answer> tags."
    )
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]}
        ],
        "answer": example["answer"]  # 用于奖励函数
    }


# 应用格式化
dataset = dataset.map(format_dataset)

# 4. 配置 GRPO 训练参数
training_args = GRPOConfig(
    output_dir="./grpo_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    logging_steps=10,
    num_generation=4,  # 每条 prompt 生成 4 个候选答案,即 Group的大小
    max_steps=50,  # 限制步数以便演示
    use_vllm=False,  # 可设为 True 加速(需安装 vLLM)
)

# 5. 初始化 GRPOTrainer
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # 使用小型模型以便演示
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_function,
)

# 6. 开始训练
trainer.train()

KAQ: Reward 函数中的 completions

是 list of completions to evaluate。

completions 是模型为给定输入(prompt)生成的输出列表, 模型的输出列表。在 GRPO 中,每个 prompt 会生成多个候选输出(由 GRPOConfignum_generation 参数控制,例如 4,则每个 prompt 生成 4 个 completion)。

completions 是一个字符串列表,每个字符串是模型对数学问题生成的回答,包含推理(<think> 标签)和答案(<answer> 标签)。

例:对于 prompt “What is 2 + 3?”,completions 可以为:

[
    "<think>2 + 3 = 5</think><answer>5</answer>",
    "<think>Add 2 and 3 to get 5</think><answer>5</answer>",
    "<think>2 + 3 = 6</think><answer>6</answer>",
    "<think>Sum is 5</think><answer>5</answer>"
]

强调,这是 模型生成的输出,而非数据集中的样本。

KAQ: Reward 函数中的 answers

list of answers to the problems from the dataset。

answer 是数据集中标注的正解(ground truth),用于评估 completions 的质量。它来自 GSM8K 数据集的 answer 字段,表示数学问题的正确答案。

比如,对于问题“What is 2 + 3?”,answer 可能是 “5”。

KAQ: Reward 中的 completion 和 answer 是如何给隐式出的

completionsanswer 是由 GRPOTrainer 隐式传入 reward 函数的。GRPOTrainer 在训练时会:

  1. 从 train_dataset 获取 prompt(代码中为 example["prompt"])。

  2. 用模型生成多个候选输出 completions,数量由 num_generation 控制。

  3. 从数据集获取对应的正解(answer,代码中为 example["answer"])。

  4. 将 completions 和 answer 作为参数自动传递给 reward 函数。

reward_function(completions, answers, **kwargs)GRPOTrainer 要求的标准接口。

KAQ: 为什么需要样本预处理

format_dataset 函数用于预处理数据集中的每个样本,将其转换为适合 GRPOTrainer 和模型(code中是 Qwen/Qwen2-0.5B-Instruct)处理的格式。具体讲,它将原始 GSM8K 数据集的样本(包含数学问题和答案)转换为包含 promptanswer 字段的结构化格式。

输入原始数据集:GSM8K 数据集中的每个样本是一个字典

{
    "question": "What is 2 + 3?",
    "answer": "5"
}

包含:

  • question:数学问题(字符串,“What is 2 + 3?")。
  • answer:正确答案(字符串,“5”)。

输出模型适配的数据格式。一个字典,包含:

  • prompt:一个列表,包含系统指令和用户问题的对话格式,如[{"role": "system", "content": ...}, {"role": "user", "content": ...}]
  • answer:原始的正确答案。

本质上是你选择的模型需要这样的格式,数据集不符合要求,故做个转换。

load_dataset 中的数据集不必与后续使用的 model 匹配上,只需要按照模型期望给预处理即可。

KAQ: KL divergence 体现在哪里

KL 散度隐式存在GRPOTrainer 的优化过程中,作为正则化项,防止模型生成分布过分偏离初始策略。公式:

Loss = Reward_Loss(completions, rewards) - β * KL_Div(new_policy, old_policy)

β 是 KL 散度的权重,new_policy 是当前模型输出分布,old_policy 是初始或参考分布。

KAQ: 训练结束后的模型和训练的模型区别

更新后的权重使 model 在数学问题解答任务(GSM8K)上表现更好,例如生成更准确的 <answer> 和更清晰的 <think> 推理。

输出分布也不同:

  • 训练前 概率分布(old_logits)反映初始策略(old_policy),可能偏离正确答案
  • 训练后 概率分布(new_logits)倾向于高奖励输出。

KAQ: 这个实例中训练解数学题,与 RLHF 要解决的问题没啥关系吧?

这里的 code 场景是提高模型在数学问题解决上的准确性和结构化输出能力,属于任务特定优化。而 RLHF 目标是 强调对话的流畅性、礼貌性、伦理性和上下文相关性,典型场景是 chat。

KAQ: 当我有多个不同的 reward function 时如何做

GRPOTrainer 的 reward_funcs 参数接受一个奖励函数的列表,但可以通过以下方式集成多个奖励函数。

reward_funcs=[
    xmlcount_reward_func,
    soft_format_reward_func,
    strict_format_reward_func,
    int_reward_func,
    correctness_reward_func
]

todo

熟悉 TRL 的实现后,再看下 GRPO 的 Open R1 实现