/ NLP, DEEPLEARNING

大模型的微调和推理加速

一. 大模型微调

1.1 微调的原因

在业务中直接使用大模型,往往会发现:

  • 模型对prompt非常敏感,需要不断调整,迭代prompt的过程很痛苦
  • 本身能力不满足独有的业务需求
  • 模型的更新/更换 导致相同的prompt却拿到不同的结果

1.2 微调的方式

  • zero-shot prompting(0样本):撰写prompt,迭代prompt,获取答案。无需数据,无需额外开发、部署成本。
  • few-shot prompting(1-10样本):收集样例,迭代样例,获取答案。需要提供少量的样本(正反例)。
  • adaptation(1k-1w样本):收集数据,适配模型,部署模型。
    • 适用百亿参数以下的模型
    • 少量的适配特殊场景/任务的样本,这里其实可以利用更大参数、更优的大模型输出的结果来训练这类“小模型”
  • 更深层次的微调:
    • 领域预训练:用无监督领域数据继续训练基座模型
    • 全参数精调:用有监督领域数据精调基座模型
    • 检索增强:外挂知识库,这里比如一般的RAG,还有比较火热的GraphRAG。
    • 多智能体:多模型/agent共同完成任务。

二. LLM微调工具-Unsloth

git地址:https://github.com/unslothai/unsloth

选择unsloth作为微调工具的理由:

  • 训练更快
  • 显存占用更少
  • 迭代快,模型适配广,目前都能支持Qwen3了

2.1 依赖的安装

这里爬了不少坑,因此记录下

  1. 直接pip安装:pip install unsloth, 这里会默认安装最新的torch版本,而非适配于自身系统cuda版本的torch(用nvcc -V可查看cuda版本)
  2. 重新安装torch,本人这里cuda是12.1的:pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
  3. 重新安装对应cuda版本的xformers:pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
  4. 如果需要微调Qwen3,那么对于transformers有版本要求:pip install transformers==4.51.0 --no-deps
  5. 改了transformers之后,那么也需要对应的trl版本更改:pip install trl==0.17.0 --no-deps
  6. 安装vllm:pip install vllm==0.7 --no-deps
  7. 这里会遇到unsloth和vllm版本不匹配的问题:No module named vllm.lora.peft_helper,简单的处理方案就是到site-packages/unsloth_zoo/vllm_lora_worker_manager.py文件中注释掉所有的和PEFTHelper相关的行即可。
  8. 如果出现No platform detected, vLLM is running on UnspecifiedPlatform,这里只需要安装pip install pynvml==12.0即可。

2.2 微调的例子

这里推荐使用官方的基于Qwen3的微调例子:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(14B)-Reasoning-Conversational.ipynb

当然在上述微调的例子中,也存了坑,就在训练模型过程中:

  • 对于高版本的trl而言,参数从tokenizer变成了processing_class
  • 如果不指定data_collator的话,会出现padding相关的bug。

下面的代码是修复之后的

# 配置训练参数 - 使用SFTTrainer进行监督微调
trainer = SFTTrainer(
    model=model,  # 要训练的模型
    processing_class=tokenizer,  # 分词器
    train_dataset=combined_dataset,  # 训练数据集
    data_collator=DataCollatorForLanguageModeling(tokenizer,mlm=False), # 定义data_collator,一定是这个不然会报错
    # max_seq_length=1024,  # 训练时使用的最大序列长度
    args=SFTConfig(  # SFT配置参数
        dataset_num_proc = 12,
        dataset_text_field="text",  # 数据集中文本字段的名称
        per_device_train_batch_size=2,  # 每个设备的训练批次大小
        gradient_accumulation_steps=4,  # 梯度累积步数,用于模拟更大的批次大小
        warmup_steps=5,  # 学习率预热步数
        # num_train_epochs = 2,  # 训练轮数,这里注释掉了,使用max_steps代替
        max_steps=50,  # 最大训练步数
        max_seq_length=1024,
        learning_rate=2e-4,  # 学习率,对于长时间训练可以降低到2e-5
        logging_steps=1,  # 日志记录间隔步数
        optim="adamw_8bit",  # 优化器,使用8位AdamW
        weight_decay=0.01,  # 权重衰减率
        lr_scheduler_type="linear",  # 学习率调度器类型
        seed=3407,  # 随机种子
        report_to="none",  # 报告工具,可以设置为"wandb"等
    ),
)

三. 推理加速

对于训练完成后的模型,我们需要选择一个合适的加速推理的框架来加速大模型的推理。

3.1 常用框架

常用的推理加速框架:

  • LLM-tensorRT:在TensorRT推理引擎基础上,针对大模型的推理优化框架
  • vllm:通过PagedAttention高效地管理attention中缓存的张量
  • LLMdeploy:LMDeploy 开发了 Persistent Batch(即 Continuous Batch),Blocked K/V Cache,动态拆分和融合,张量并行,高效的计算 kernel等重要特性。推理性能是 vLLM 的 1.8 倍
  • llama.cpp:针对C/C++的加速推理框架

3.2 vLLM的使用

对于企业级的服务以及高并发场景,vLLM会更加适合。 以下是两种构建服务的方式:

  • 第一种:使用vllm.entrypoints.openai.api_server,适用curl调用
python -m vllm.entrypoints.openai.api_server --model "模型地址" --served-model-name "模型名字" --port 8123 --dtype=half
  • 第二种:使用vllm serve,适用于openai的方式调用apikey
vllm serve "模型地址" --tensor-parallel-size 1 --enforce-eager --port 8124 --dtype float16 --trust-remote-code --served-model-name 模型名字

然后可以使用openai来调用vllm启动模型的接口

from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8123/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

chat_response = client.chat.completions.create(
    # model="Qwen/Qwen3-8B",
    model="模型名字",
    messages=[
        {"role": "user", "content": "Give me a short introduction to large language models."},
    ],
    # max_tokens=32768,
    temperature=0.6,
    top_p=0.95,
    extra_body={
        "top_k": 20,
    },
)
print("Chat response:", chat_response)