@ KL 离散度自实现

KL 离散作为 loss 的一部分,续训练中避免新旧模型输出的 logits 概率分布变化过大。

import torch
from torch.nn.functional import kl_div, log_softmax, softmax
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re
from copy import deepcopy

# 1. 加载数据集
dataset = load_dataset("openai/gsm8k", split="train[:5%]")

# 2. 奖励函数
# 组件:奖励函数,用于评估每一个响应的好坏
def reward_function(completions, answers, **kwargs):
    rewards = []
    pattern = r"<answer>(.*?)</answer>"
    for completion, correct_answer in zip(completions, answers):
        try:
            match = re.search(pattern, completion)
            reward = 1.0 if match and match.group(1).strip() == str(correct_answer) else 0.0
        except:
            reward = 0.0
        rewards.append(reward)
    return rewards

# 3. 数据预处理
def format_dataset(example):
    system_prompt = (
        "Solve the math problem step-by-step, providing reasoning in <think> tags "
        "and the final answer in <answer> tags."
    )
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]}
        ],
        "answer": example["answer"]
    }

dataset = dataset.map(format_dataset)

# 4. 初始化模型和分词器
model_name = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model_old = deepcopy(model)  # 参考模型(旧策略),保持不变
model.to("cuda" if torch.cuda.is_available() else "cpu")
model_old.to("cuda" if torch.cuda.is_available() else "cpu")

# 5. 自定义训练循环
def train_with_kl(dataset, num_epochs=3, batch_size=4, num_generation=4, kl_weight=0.1):
    # Adam 优化器的目标是最小化
    # 这里默认是更新所有参数 0.5B 
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    for epoch in range(num_epochs):
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i:i + batch_size]
            prompts = batch["prompt"]
            answers = batch["answer"]

            # 生成多个候选输出
            completions = []
            new_logits = []
            old_logits = []
            for prompt in prompts:
                inputs = tokenizer([f"{prompt[0]['content']}\n{prompt[1]['content']}"], 
                                    return_tensors="pt", 
                                    padding=True).to(model.device)
                # 组件分组采样:每一个prompt 生成4个响应,形成一个响应组
                for _ in range(num_generation):
                    outputs = model.generate(**inputs, 
                                             max_new_tokens=50, 
                                             do_sample=True, 
                                             return_dict_in_generate=True, 
                                             output_scores=True)
                    # 生成响应,evla值
                    completion = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
                    completions.append(completion)
                    # 获取 logits , 最后一层的 logits
                    logits = torch.stack(outputs.scores, dim=1)[:, -1, :]
                    new_logits.append(logits)
                # 旧策略 logits,通过no_grad() 确保旧policy不参与梯度计算
                with torch.no_grad():
                    ref_outputs = model_old(**inputs, return_dict=True)
                    old_logits.append(ref_outputs.logits[:, -1, :])

            # 计算奖励
            # 评估响应,响应值与真值比较
            rewards = reward_function(completions, answers)

            # 计算 KL 散度
            kl_loss = 0
            for new_logit, old_logit in zip(new_logits, old_logits):
                kl_loss += kl_div(
                    log_softmax(new_logit, dim=-1),
                    softmax(old_logit, dim=-1),
                    reduction="batchmean"
                )
            kl_loss /= num_generation

            # 计算奖励loss
            reward_loss = -torch.tensor(rewards, device=model.device).mean()
            # 总loss
            total_loss = reward_loss + kl_weight * kl_loss
            # 优化
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Batch {i//batch_size+1}, Total Loss: {total_loss.item()}, KL Loss: {kl_loss.item()}")

# 6. 开始训练
train_with_kl(dataset, num_epochs=3, batch_size=4, num_generation=4, kl_weight=0.1)

KAQ:梯度更新时,如何体现当前 model(policy)的更新

model_old 表示了旧 policy,model 表示当前 policy, 在梯度更新时,如何体现对当前policy的更新?

在梯度更新时,通过 torch.no_torch() 确保 model_old 的参数不会被更新。只更新当前model的参数。并且更新的是model中可训练参数。包括Embedding 层权重、self-attention 权重、FFN 权重、归一化权重。

model 的初始参数来自预训练模型(Qwen/Qwen2-0.5B-Instruct),在训练过程中,optimizer.step() 通过梯度下降直接修改这些权重和偏置的值。

对于语言模型,输出分布 $ \pi_\theta(o|q) $ 是通过 softmax 函数从 logits 计算得到的。logit_o 是模型最后一层(模型 Head)对输出 token $ o $ 的预测分数,依赖于所有 Transformer 层的权重。所以,当高奖励输出对应的 logits 会增加,导致 $ \pi_\theta(o|q) $ 对这些输出的概率更高。

上述 code 是更新 model 多有参数,参数量有参数量有 0.5B,考虑引入 LoRA,只更新部分参数

KAQ: 如何体现新旧模型的不同,其实是新旧 Policy 的不同

model_old = deepcopy(model) model_old 保持不变,即旧 Policy 的表达无变化。model 是在学习的模型,它表达 Policy 的更新。

model 的行为体现了 Policy。

KAQ:与 ppo 一样,总 loss 是几个 loss 的和

total_loss = reward_loss + kl_weight * kl_loss

KAQ:new/old policy 是什么样的?含有概率分布?什么的概率分布?

new policy 就是模型的直接输出 new logits。然后通过 softmax 等转换为概率分布。

要想正确理解,首先理解解数学题的训练数据是什么样的?既是 LLM,就有分词器,就有下一个 token 的概率分布。生成过程是自回归的,即模型按 token 顺序生成序列(token-by-token),即按 token 顺序逐个生成,每个 token 的生成依赖于之前的 token 和输入 prompt。

更具体讲,自回归逐步生成条件概率 自回归:模型在时间步 $ t $ 生成 token $ t $,基于之前生成的 token $ 1, 2, …, t-1 $ 和输入 prompt 的条件概率: $$ P(token_t | prompt, token_1, …, token_{t-1}) $$ 它是逐步生成的。

api model.generate 就是通过自回归的方式生成序列的。

比如:

输入:inputs 是分词后的 prompt(如“What is 2 + 3?”格式化为 [{"role": "system", ...}, {"role": "user", "What is 2 + 3?"}])。

输出:outputs.sequences 是生成的 token 序列(如 <think>2 + 3 = 5</think><answer>5</answer>),outputs.scores 是每个时间步的 logits。

自回归过程:模型从 prompt 开始,逐个生成 token:

  • 时间步 1:生成 <think>,基于 prompt 的概率 $ P(think | prompt) $.
  • 时间步 2:生成 2,基于 $ P(2 | prompt, think) $.
  • 时间步 3:生成 +,基于 $ P(+ | prompt, think, 2) $.
  • 以此类推,直到生成 </answer> 或达到 max_new_tokens。

每个 token 的生成依赖于之前的所有 token 和 prompt。

生成 5 的时间步,logits 是一个向量(如 [2.1, -0.3, 1.5, 4.8, 0.2, ...],长度等于 vocab_size),表示词汇表中每个 token 的得分。

然后通过 softmax(logits),得到概率分布(如 [0.05, 0.01, 0.03, 0.75, 0.02, ...]),选择概率最高的 token(或采样)作为下一个 token。

熟悉 TRL 的实现后,再看下 GRPO 的 Open R1 实现 【todo】

KAQ:KL divergence 为什么一个是 log_softmax 一个是 softmax

kl-div 作用是起着正则化的作用,即约束策略更新,避免策略过度偏离原始策略。

对新策略(new_policy)使用 log_softmax,而对旧策略(old_policy)使用 softmax,这是因为 KL 散度的数学定义和实现中的数值稳定性要求。也是 PyTorch kl_div 的标准做法

不管使用softmax 还是logsoftmax,目的都是得到概率分布,这是使用KL divergence的前体。

KAQ:KL divergence 数学表达

kl_loss = 0
for new_logit, old_logit in zip(new_logits, old_logits):
    kl_loss += kl_div(
        log_softmax(new_logit, dim=-1),
        softmax(old_logit, dim=-1),
        reduction="batchmean"
    )
kl_loss /= num_generation

KL 散度是信息论中衡量两个概率分布之间的差异:

对于两个离散概率分布 $P$(新策略 $\pi_\theta$)和 $Q$(老策略 $\pi_{\text{ref}}$),KL 散度定义为: $$D_{\text{KL}}(P | Q) = \sum_{i} P(i) \log \left( \frac{P(i)}{Q(i)} \right)$$

  • $P(i)$:新策略在状态 $x$ 下生成 token $i$ 的概率($\pi_\theta(i|x)$)

  • $Q(i)$:老策略生成 token $i$ 的概率($\pi_{\text{ref}}(i|x)$)

直观含义:KL 散度衡量 $P$ 和 $Q$ 的“距离”,值越大表示分布差异越大。我们的目的是让新策略尽可能接近老策略,会略有改进,但不会偏离太多。故 kl loss 符号为正,Adam 优化器最小化 total loss,意味着最小化 kl loss,即让新策略尽可能接近老策略。

就我的 case。Qwen2 模型的词汇表通常约 15 万 具体为 151936,new_logits/old_logits 都是列表,包含 4 (响应个数) 个 [batch_size, vocab_size] 形状的张量:

new_logits = [
    torch.tensor([[2.1, 0.5, ..., -1.2], ..., [1.8, 0.3, ..., -0.9]]),  # 第一个响应
    torch.tensor([[1.9, 0.7, ..., -1.0], ..., [2.0, 0.4, ..., -1.1]]),  # 第二个响应
    torch.tensor([[...]]),
    torch.tensor([[...]])
]
old_logits = [
    torch.tensor([[1.5, 0.6, ..., -0.8], ..., [1.7, 0.5, ..., -0.7]]),  # 老模型输出
    torch.tensor([[...]]),
    torch.tensor([[...]]),
    torch.tensor([[...]])
]

将上述 logits 替换掉 这里 的 $P$ 和 $Q$。然后就理解了

实际上

$$ \pi_\theta(i|x) = P(i | x; \theta) $$

  • $x$:输入上下文(token 序列,例如 “What is the”)。
  • $i$:词汇表中的 token 索引(如 “meaning” 的 ID=123)。
  • $\pi_\theta(i|x)$:给定上下文 $x$,模型生成下一个 token 为 $i$ 的概率。

KAQ:既然是条件概率与x有关,也就是输入的 token 有关。那么概率分布的长度也与输入 token 的长度有关喽?

无关。

$P(i | x; \theta)$ 的长度(即概率分布的大小)由 词汇表大小 决定,与输入上下文 $x$ 的长度无关。

无论 $x$ 是 1 个 token(如 “What”)还是 100 个 token(如 “Solve the math problem step-by-step…"),模型输出的 logits 和 $P(i | x; \theta)$ 始终是 [vocab_size] 长度(151936)。

原因:LLM(如 Qwen2)在每次预测下一个 token 时,基于整个上下文 $x$(通过 Transformer 的注意力机制)生成一个固定大小的概率分布,覆盖整个词汇表。 !!!

$x$ 的长度影响模型的计算过程(如注意力计算的复杂度),但不改变输出分布的长度。 例如:

  • $x = \text{“What”}$(1 token)=> $P(i | x; \theta)$:151936 个概率。
  • $x = \text{“What is the meaning of life”}$(5 tokens)=> $P(i | x; \theta)$:仍为 151936 个概率。

KAQ:这个概率的得到其实是decode中transformer 的输出,所以概率分布的长度是与输入长度无关的?

是的。