Top-K 采样

定义:从模型输出的概率分布中选择概率最高的 $ K $ 个 token 作为候选集,其余 token 概率置零,再归一化后采样。

数学定义:

  • 给定 logits $ z_i $(模型输出),经 softmax 得概率: $$p_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

  • 选择 $ p_i $ 最大的 $ K $ 个 token,重新归一化: $$p_i’ = \frac{p_i}{\sum_{j \in \text{top-K}} p_j}, \quad p_i = 0 \text{ if } i \notin \text{top-K}$$

作用:

  • 限制词汇量,减少低概率、无意义 token(如拼写错误)。
  • 提高连贯性,降低创造性。

llama.cpp 中实现:

static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
    if (k <= 0) {return;}
    k = std::min(k, (int) cur_p->size);

    // 降序排序
    bucket_sort(cur_p->data);
    
    cur_p->size = k;
}

Top-P 采样

定义:选择累积概率达到 $ p $ 的最小 token 集,过滤低于阈值的 token,再归一化后采样。

数学:

  • 按概率 $ p_i $ 降序排列 token,保留最小的集合 $ S $ 满足: $$\sum_{i \in S} p_i \geq p$$

  • 归一化: $$p_i’ = \frac{p_i}{\sum_{j \in S} p_j}, \quad p_i = 0 \text{ if } i \notin S$$

作用:

  • 动态调整候选集大小,高概率 token 少时保留更多,低概率 token 多时过滤更多。
  • 平衡连贯性和创造性,适合生成多样化但合理输出。

llama.cpp 中实现:

static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;

    if (ctx->p >= 1.0f) {
        return;
    }

    // 计算cur_p 中所以数值的 softmax
    llama_sampler_softmax_impl(cur_p);

    // 计算累计 概率并求和
    float cum_sum = 0.0f;
    size_t last_idx = cur_p->size;

    for (size_t i = 0; i < cur_p->size; ++i) {
        cum_sum += cur_p->data[i].p;
        // 如果当前累计值 超过了 p, 则结束
        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }
    }

    // 缩小长度
    cur_p->size = last_idx;
}

Temperature

定义:通过缩放 logits 调整概率分布的平滑度,控制随机性。

数学定义:

原始 logits $ z_i $,经温度 $ T $ 缩放后 softmax: $$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

llama.cpp 实现:

static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
    if (temp <= 0.0f) {
        // find the token with the highest logit and set the rest to -inf
        size_t max_i = 0;
        float  max_l = cur_p->data[0].logit;

        for (size_t i = 1; i < cur_p->size; ++i) {
            if (cur_p->data[i    ].logit > max_l) {
                cur_p->data[max_i].logit = -INFINITY;
                max_i = i;
                max_l = cur_p->data[i].logit;
            } else {
                cur_p->data[i].logit = -INFINITY;
            }
        }

        return;
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit /= temp;
    }
}