Top-K 采样
定义:从模型输出的概率分布中选择概率最高的 $ K $ 个 token 作为候选集,其余 token 概率置零,再归一化后采样。
数学定义:
给定 logits $ z_i $(模型输出),经 softmax 得概率: $$p_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$
选择 $ p_i $ 最大的 $ K $ 个 token,重新归一化: $$p_i’ = \frac{p_i}{\sum_{j \in \text{top-K}} p_j}, \quad p_i = 0 \text{ if } i \notin \text{top-K}$$
作用:
- 限制词汇量,减少低概率、无意义 token(如拼写错误)。
- 提高连贯性,降低创造性。
llama.cpp 中实现:
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
if (k <= 0) {return;}
k = std::min(k, (int) cur_p->size);
// 降序排序
bucket_sort(cur_p->data);
cur_p->size = k;
}
Top-P 采样
定义:选择累积概率达到 $ p $ 的最小 token 集,过滤低于阈值的 token,再归一化后采样。
数学:
按概率 $ p_i $ 降序排列 token,保留最小的集合 $ S $ 满足: $$\sum_{i \in S} p_i \geq p$$
归一化: $$p_i’ = \frac{p_i}{\sum_{j \in S} p_j}, \quad p_i = 0 \text{ if } i \notin S$$
作用:
- 动态调整候选集大小,高概率 token 少时保留更多,低概率 token 多时过滤更多。
- 平衡连贯性和创造性,适合生成多样化但合理输出。
llama.cpp 中实现:
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
if (ctx->p >= 1.0f) {
return;
}
// 计算cur_p 中所以数值的 softmax
llama_sampler_softmax_impl(cur_p);
// 计算累计 概率并求和
float cum_sum = 0.0f;
size_t last_idx = cur_p->size;
for (size_t i = 0; i < cur_p->size; ++i) {
cum_sum += cur_p->data[i].p;
// 如果当前累计值 超过了 p, 则结束
if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
last_idx = i + 1;
break;
}
}
// 缩小长度
cur_p->size = last_idx;
}
Temperature
定义:通过缩放 logits 调整概率分布的平滑度,控制随机性。
数学定义:
原始 logits $ z_i $,经温度 $ T $ 缩放后 softmax: $$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
llama.cpp 实现:
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
if (temp <= 0.0f) {
// find the token with the highest logit and set the rest to -inf
size_t max_i = 0;
float max_l = cur_p->data[0].logit;
for (size_t i = 1; i < cur_p->size; ++i) {
if (cur_p->data[i ].logit > max_l) {
cur_p->data[max_i].logit = -INFINITY;
max_i = i;
max_l = cur_p->data[i].logit;
} else {
cur_p->data[i].logit = -INFINITY;
}
}
return;
}
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= temp;
}
}