使用 GRPO 微调一个模型
本示例用到 PEFT(Parameter-Efficient Fine-Tuning,参数高效微调)库:通过 LoRA 只训练少量适配器参数,而不是更新模型的全部权重。
# pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0
# pip install -qqq accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
# pip install -qqq flash-attn --no-build-isolation --progress-bar off
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
import wandb
# Log in to Weights & Biases; the trainer reports metrics there (report_to=["wandb"]).
wandb.login()

# Training data: mlabonne/smoltldr.
# NOTE(review): the name suggests TL;DR-style data (longer text paired with a short
# summary) rather than short stories — confirm on the dataset card.
dataset = load_dataset(path="mlabonne/smoltldr")
print(dataset)
# Use a small instruction-tuned model so the example trains quickly.
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",  # needs the flash-attn package (see install line)
    device_map="auto",  # let accelerate decide where the weights go
    torch_dtype="auto",  # use the dtype recorded in the checkpoint config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Attach LoRA adapters.
# A plain transformers model has no `print_trainable_parameters()` (that method
# belongs to PeftModel), so count the parameters by hand before wrapping.
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in model.parameters())
print(f"Before lora: trainable params: {num_trainable:,} || all params: {num_total:,}")
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,  # rank of the low-rank update matrices A and B
    lora_alpha=32,  # scaling numerator: the update is scaled by lora_alpha / r
    target_modules="all-linear",  # apply LoRA to all linear layers (PEFT excludes the output head)
)
# Wraps the model in a PeftModel: base weights are frozen, only adapters train.
model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints its own summary and returns None,
# so call it directly rather than passing its return value to print().
print("After lora:")
model.print_trainable_parameters()
# Reward function: penalize completions by their distance from a target length.
ideal_length = 50  # target completion length, measured in characters


def reward_len(completions, **kwargs):
    """Score each completion by how close its character length is to ``ideal_length``.

    Args:
        completions: the generated completions. Plain strings are measured
            directly. Conversational completions (a list of ``{"role", "content"}``
            message dicts) are measured on the content of their last message;
            an empty message list counts as an empty string.
        **kwargs: extra per-sample data supplied by the trainer; ignored here.

    Returns:
        A list with one non-positive int per completion:
        ``-abs(ideal_length - length)``. 0 is the best possible reward.
    """
    rewards = []
    for completion in completions:
        if isinstance(completion, list):
            # Conversational format: len() of the message list would count
            # messages, not characters, so measure the generated text instead.
            completion = completion[-1]["content"] if completion else ""
        rewards.append(-abs(ideal_length - len(completion)))
    return rewards
# GRPO training configuration (grouped by concern; values unchanged).
training_args = GRPOConfig(
    # Output and logging
    output_dir="GRPO",
    report_to=["wandb"],
    logging_steps=1,
    # Optimization
    learning_rate=2e-5,
    optim="adamw_8bit",  # 8-bit AdamW; relies on bitsandbytes (see install line)
    num_train_epochs=1,
    bf16=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    # Generation: completions sampled per prompt, plus prompt/completion length caps.
    # NOTE(review): TRL 0.14 expects the global batch size to be divisible by
    # num_generations — re-check if the batch size or GPU count changes.
    num_generations=8,
    max_prompt_length=512,
    max_completion_length=96,
    # Keep dataset columns instead of dropping those the model does not use.
    remove_unused_columns=False,
)
# Build the GRPO trainer: the LoRA-wrapped model is optimized against reward_len.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    reward_funcs=[reward_len],
)

# Open a W&B run for this experiment, then train.
wandb.init(project="GRPO")
trainer.train()
# Save and publish.
# merge_and_unload() folds the LoRA updates into the base weights and returns the
# underlying transformers model without the PEFT wrapper.
merged_model = trainer.model.merge_and_unload()

# Keep a self-contained local copy (weights + tokenizer) under the path the
# text-generation example below loads from.
local_dir = "SmolGRPO-135M"
merged_model.save_pretrained(local_dir)
tokenizer.save_pretrained(local_dir)

hub_repo_id = "SmolLM-135M-Instruct-GRPO-135M"
merged_model.push_to_hub(
    hub_repo_id, private=False, tags=["GRPO", "Reasoning-Course"]
)
# Push the tokenizer too, so the Hub repository can be loaded on its own.
tokenizer.push_to_hub(hub_repo_id, private=False)
解释训练结果
GRPOTrainer 记录了奖励函数的奖励值、损失值以及其他一系列指标
- 随着训练 step 增加,奖励值逐渐接近 0。这里奖励定义为"生成文本字符长度与目标长度 50 之差的绝对值"的相反数,0 为最优,这表明模型正在学会生成长度合适的文本。
- Training loss 随 step 增加。原因是:在 GRPO 中,同一组内的优势(advantage)经归一化后均值约为 0,损失值主要来自 KL 惩罚项。因此 loss 上升说明模型为了生成更符合奖励函数的文本,与初始(参考)策略的偏差(KL 散度)越来越大。
使用新模型生成文本
# A long input document for the fine-tuned model to work on.
prompt = """
# A long document about the Cat
The cat (Felis catus), also referred to as the domestic cat or house cat, is a small
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

# Chat-format input: one user turn that carries the document.
messages = [dict(role="user", content=prompt)]
# Generate text with the fine-tuned model.
from transformers import pipeline

# Loads from "SmolGRPO-135M": assumes the merged model and its tokenizer were saved
# to a local directory (or Hub repo) with that name — adjust if saved elsewhere.
generator = pipeline("text-generation", model="SmolGRPO-135M")
# Alternatively, reuse the in-memory merged model and the tokenizer loaded earlier:
# generator = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}
# Generation options must be passed as individual keyword arguments. A nested
# `generate_kwargs=` dict is not unpacked by the text-generation pipeline, so the
# sampling settings would not be applied.
generated_text = generator(messages, **generate_kwargs)
print(generated_text)
KAQ:model = get_peft_model(model, lora_config)
将 LoRA 适配器(低秩矩阵 $A$ 和 $B$)应用到 model 中由 lora_config 的 target_modules 指定的权重矩阵上。对每个目标权重 $W$,LoRA 学习一个低秩增量 $\Delta W = \frac{\alpha}{r} BA$(其中 $r$ 为秩,$\alpha$ 为 lora_alpha)。原始权重被冻结,只有适配器参数可训练。返回值是一个 PeftModel 实例,它同时包含原始模型和 LoRA 适配器,适用于高效微调。
KAQ:merged_model = trainer.model.merge_and_unload()
接上述:训练完成后,需要把 LoRA 适配器合并回原始模型。merge_and_unload() 将适配器的权重更新合并进原始权重,即 $W' = W + \frac{\alpha}{r} BA$,并移除 LoRA 层,返回一个不再依赖 PEFT 的普通模型,也就是微调后的新模型。推理时不再需要额外的适配器计算。