LoRA:高效微调预训练模型的利器

article/2025/7/18 19:01:21

LoRA(Low-Rank Adaptation) 的思想:冻结预训练模型权重,将可训练的低秩分解矩阵注入到Transformer架构的每一层(也可单独配置某一层)中, 从而大大减少在下游任务的可训练参数量。

核心原理

对于预训练权重矩阵 ,LoRA限制了其更新方式,将全参微调的增量参数矩阵  表示为两个参数量更小的矩阵  和  的低秩近似:

其中:

  • •  和  为LoRA低秩适应的权重矩阵

  • • 秩  远小于 (即 )

此时,微调的参数量从原来 的,变成了和的。由于(满足),显著降低了训练参数量。
方法:

图片

优势:

  1. 1. 高效训练:大大减少需要训练的参数数量(只训练 A 和 B,而不是 W₀),降低对GPU内存的需求,缩短训练时间。

  2. 2. 高效存储/切换:对每个新任务,只需要存储和加载小的 LoRA 权重(A 和 B),而不是整个模型的副本,这样就可以为一个基础模型配备多个任务的“适配器”。

  3. 3. 性能保持:LoRA能在降低训练成本的同时,达到接近完全微调的性能。

通过代码理解原理

下列代码拷贝合在一起,更换数据集与模型文件路径后,可直接运行,PEFT版本为0.14.0。重点关注第四步配置LoRA第八步模型推理, 其余代码在往期文章中已有详细介绍。

  • • 数据集:alpaca_data_zh

  • • 预训练模型:bloom-389m-zh

第一步: 导入相关包

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit, PeftModel

第二步: 加载数据集

# 包含 'instruction' (指令), 'input' (可选的额外输入), 'output' (期望的回答)
ds = Dataset.load_from_disk("../data/alpaca_data_zh/") 

第三步: 数据集预处理

将每个样本处理成包含 input_ids, attention_mask, 和 labels 的字典。

tokenizer = AutoTokenizer.from_pretrained("D:\\git\\model-download\\bloom-389m-zh") 
defprocess_func(example):MAX_LENGTH = 256# 构建输入文本:将指令和输入(可选)组合到一起,并添加明确的 "Human:" 和 "Assistant:" 标识符。"\n\nAssistant: " 是提示模型开始生成回答的关键分隔符。prompt = "\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: "# 对输入+提示进行分词,这里暂时不添加特殊token (<s>, </s>),后面要拼接instruction_tokenized = tokenizer(prompt, add_special_tokens=False)# 对期望的输出(回答)进行分词,在回答的末尾加上 `tokenizer.eos_token` (end-of-sentence)。告诉模型生成到这里就可以结束。response_tokenized = tokenizer(example["output"] + tokenizer.eos_token, add_special_tokens=False)# 将输入提示和回答的 token IDs 拼接起来,形成完整的输入序列 input_idsinput_ids = instruction_tokenized["input_ids"] + response_tokenized["input_ids"]# attention_mask 用于告诉模型哪些 token 是真实的、需要关注的,哪些是填充的(padding)。attention_mask = instruction_tokenized["attention_mask"] + response_tokenized["attention_mask"]# 创建标签 (labels):这是模型需要学习预测的目标,因为只希望模型学习预测 "Assistant:" 后面的回答部分,所以将输入提示部分的标签设置为 -100,损失函数自动忽略标签为 -100 的 token,不计算它们的损失。labels = [-100] * len(instruction_tokenized["input_ids"]) + response_tokenized["input_ids"]# 截断iflen(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]# 返回处理好的数据return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}
#  .map() 方法将处理函数应用到整个数据集的所有样本上。
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)    #  `remove_columns` 会移除原始的列,只保留 process_func 返回的新列。    
print("\n检查第2条数据处理结果:")
print("输入序列 (input_ids解码):", tokenizer.decode(tokenized_ds[1]["input_ids"]))
target_labels = list(filter(lambda x: x != -100, tokenized_ds[1]["labels"])) # 过滤掉 -100,看看模型真正需要预测的标签是什么
print("标签序列 (labels解码,过滤-100后):", tokenizer.decode(target_labels))

第四步: 加载预训练模型和配置LoRA

1. 配置LoRA(关键步骤):
  • • 选择在哪些层上应用LoRA,target_modules=".*\\.1.*query_key_value": 用来指定要在哪些模块(层)上应用LoRA适配器,.*\\.1.*query_key_value 匹配的是模型中名字包含 ".1." (通常指第一层Transformer块)并且是 "query_key_value" (在某些模型结构中,QKV是合并在一起的)的线性层。如果指定的是 ["query_key_value"],则表示适配所有层的QKV映射,从而调整模型注意力机制中的参数。

  • • r:低秩分解的秩,默认值通常是8或16。r 越小,引入的参数越少,但会牺牲一些性能;r 越大,参数越多,可能性能更好,但效率增益降低。

2. 使用PEFT将LoRA应用到模型:
  • • get_peft_model接收原始模型和LoRA配置:
    a. 冻结原始模型所有参数。
    b. 根据 target_modules 在指定层旁边添加LoRA适配器层(可训练的小矩阵A和B)。
    c. 如果指定了 modules_to_save,则会解冻这些模块的参数,使其也可训练。

  • • 返回的 model 是一个 PeftModel 对象,封装了原始模型和LoRA适配器。

model = AutoModelForCausalLM.from_pretrained("D:\\git\\transformers-code-master\\model-download\\bloom-389m-zh")# LoRA 配置
config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=".*\\.1.*query_key_value", # 适配第1层的QKV合并层(根据模型结构调整),或者更通用的写法,target_modules=["query_key_value"],适配所有层的QKVr=8,  # 显式设置LoRA的秩 (rank),可以调整,比如 8, 16, 32lora_alpha=32, # LoRA缩放因子,通常设为 r 的2倍或4倍lora_dropout=0.1, # LoRA层的dropout率modules_to_save=["word_embeddings"] # 除了LoRA参数外,有时需要训练(并保存)词嵌入层(`word_embeddings`)。有时调整词嵌入对适应新词汇或领域有帮助。如果不需要,可以去掉这个参数。
)
print("\nLoRA配置详情:", config)# 使用 PEFT 应用 LoRA 到模型
model = get_peft_model(model, config)print("\n应用LoRA后的模型可训练参数:")
# 打印模型中哪些参数是可训练的(主要是LoRA的A、B矩阵和word_embeddings)
for name, parameter in model.named_parameters():if parameter.requires_grad:print(name)print("\n可训练参数统计:")
model.print_trainable_parameters() # 关键:观察可训练参数占比!
3. 检查可训练参数:
  • • 可训练参数统计:原训练规模为3.9亿参数,LORA后,训练参数规模为0.43亿,训练参数规模大大降低。

trainable params: 43,815,936 || all params: 389,584,896 || trainable%: 11.2468,

第五步: 配置训练参数

args = TrainingArguments(output_dir="./chatbot_lora_tuned", # 输出目录per_device_train_batch_size=1,      # 每个GPU的批大小gradient_accumulation_steps=8,      # 梯度累积步数,实际批大小 = 1 * 8 = 8logging_steps=10,                   # 每10步记录一次日志num_train_epochs=1,                 # 训练轮数save_strategy="epoch",              # 每个epoch保存一次模型learning_rate=1e-4,                 # 学习率warmup_steps=100,                   # 预热步数# 可以添加更多参数,如 weight_decay, evaluation_strategy 等
)

第六步: 创建训练器

trainer = Trainer(model=model,args=args,tokenizer=tokenizer,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True), # 使用Seq2Seq的整理器,将批次内的数据动态填充(padding)到相同长度,确保张量形状一致。
)

第七步: 模型训练

只有在LoRA配置中指定的可训练参数
(LoRA的A、B矩阵,以及 modules_to_save 中的层)会被优化器更新,原始模型权重保持冻结。

trainer.train()

第八步: 模型推理 (使用微调后的模型)

推理时,PeftModel 会自动将 LoRA 适配器(B*A)的效果加到原始权重上,无需手动操作。

from peft import PeftModel
# 加载基础模型
model = AutoModelForCausalLM.from_pretrained("D:\\git\\transformers-code-master\\model-download\\bloom-389m-zh")
tokenizer = AutoTokenizer.from_pretrained("D:\\git\\transformers-code-master\\model-download\\bloom-389m-zh")
print("基础模型加载完成:", type(model))# 加载Lora模型
p_model = PeftModel.from_pretrained(model, model_id="./chatbot/checkpoint-3357/")
print("Lora模型加载结果:", p_model)# 生成对话
ipt = tokenizer("Human: {}\\n{}".format("考试有哪些技巧?", "").strip() + "\\n\\nAssistant: ", return_tensors="pt")
generated = p_model.generate(**ipt, do_sample=False)
response = tokenizer.decode(generated[0], skip_special_tokens=True)
print("生成的回答:", response)# 模型合并
merge_model = p_model.merge_and_unload()
print("合并后的模型结构:", merge_model)# 验证合并模型效果
ipt_merged = tokenizer("Human: {}\\n{}".format("考试有哪些技巧?", "").strip() + "\\n\\nAssistant: ", return_tensors="pt")
merged_response = tokenizer.decode(merge_model.generate(**ipt_merged,max_length=1024,        # 保持总长度限制max_new_tokens=500,     # 新增关键参数:控制新生成token数量do_sample=True,    #启用采样,让生成结果更多样化(否则可能总是生成最可能的词)。temperature=0.8,       # 提高随机性 (0.7-1.0)top_p=0.9,             # 核采样增加多样性repetition_penalty=1.2,# 抑制重复同时允许合理扩展early_stopping=False    # 防止过早停止)[0], skip_special_tokens=True
)
print("合并模型生成的回答:", merged_response)# 保存完整模型
merge_model.save_pretrained("./chatbot/merge_model")
print("模型已保存至:", "./chatbot/merge_model")

总结

  1. 1. LoRA的核心思想:
    冻结原始模型参数,只在特定层旁边添加两个小的矩阵(A和B)并进行训练,用 B*A 近似模拟所需的模型调整。

  2. 2. 代码体现

    • • LoraConfig :定义哪些层要加适配器 (target_modules),适配器的秩 r 是多少等。

    • • get_peft_model 把 LoRA 配置应用到原始模型上,返回 PeftModel

    • • model.print_trainable_parameters() :可训练参数大大减少。

    • • 训练时 (trainer.train()) 只更新这些少量参数。

    • • 推理时 (model.generate()):自动结合原始权重和LoRA适配器的效果。

  3. 3. 优势: 训练快、省显存、模型存储小、任务切换方便,效果有保障。


http://www.hkcw.cn/article/QOYfJclYgw.shtml

相关文章

越界检测算法AI智能分析网关V4打造多场景化的应用解决方案

一、方案概述 随着社会发展&#xff0c;传统安防系统在复杂环境下暴露出误报率高、响应慢等短板。AI智能分析网关V4依托先进算法与强大算力&#xff0c;实现周界区域精准监测与智能分析&#xff0c;显著提升入侵防范效能。本方案通过部署该网关及其越界检测功能&#xff0c;为…

使用SkiaSharp打造专业级12导联心电图查看器:性能与美观兼具的可视化实践

前言 欢迎关注dotnet研习社&#xff0c;今天我们研究的Google Skia图形库的.NET绑定SkiaSharp图形库。 在医疗软件开发领域&#xff0c;心电图(ECG)数据的可视化是一个既有挑战性又极其重要的任务。作为开发者&#xff0c;我们需要创建既专业又直观的界面来展示复杂的生物医学…

24位高精度数据采集卡NET8860音频振动信号采集监测满足自动化测试应用现场的多样化需求

NET8860 高分辨率数据采集卡技术解析 阿尔泰科技的NET8860是一款高性能数据采集卡&#xff0c;具备8路同步模拟输入通道和24bit分辨率&#xff0c;适用于高精度信号采集场景。其输入量程覆盖10V、5V、2V、1V&#xff0c;采样速率高达256KS/s&#xff0c;能够满足多种工业与科研…

2025年05月30日Github流行趋势

项目名称&#xff1a;agenticSeek 项目地址url&#xff1a;https://github.com/Fosowl/agenticSeek项目语言&#xff1a;Python历史star数&#xff1a;13040今日star数&#xff1a;1864项目维护者&#xff1a;Fosowl, steveh8758, klimentij, ganeshnikhil, apps/copilot-pull-…

PCB设计实践(三十一)PCB设计中机械孔的合理设计与应用指南

一、机械孔的基本概念与分类 机械孔是PCB设计中用于实现机械固定、结构支撑、散热及电气连接的关键结构元件&#xff0c;其分类基于功能特性、制造工艺和应用场景的差异&#xff0c;主要分为以下几类&#xff1a; 1. 金属化机械孔 通过电镀工艺在孔内壁形成导电层&#xff0c;…

TC/BC/OC P2P/E2E有啥区别?-PTP协议基础概念介绍

前言 时间同步网络中的每个节点&#xff0c;都被称为时钟&#xff0c;PTP协议定义了三种基本时钟节点。本文将介绍这三种类型的时钟&#xff0c;以及gPTP在同步机制上与其他机制的区别 本系列文章将由浅入深的带你了解gPTP&#xff0c;欢迎关注 时钟类型 在PTP中我们将各节…

五.MySQL表的约束

1.not null空属性 和 default缺省值 两个值&#xff1a;null&#xff08;默认的&#xff09;和not null(不为空) 元素可以分为两类 1.not null 不能为空的&#xff0c;这种没有默认default 要手动设定&#xff0c;我们必须插入数据而且不能为NULL。但我们插入数据有两种方式 1.…

4.Haproxy搭建Web群集

一.案例分析 1.案例概述 Haproxy是目前比较流行的一种群集调度工具&#xff0c;同类群集调度工具有很多&#xff0c;包括LVS、Nginx&#xff0c;LVS性能最好&#xff0c;但是搭建相对复杂&#xff1b;Nginx的upstream模块支持群集功能&#xff0c;但是对群集节点健康检查功能…

NewsNow:免费好用的实时新闻聚合平台,让信息获取更优雅(深度解析、部署攻略)

名人说&#xff1a;博观而约取&#xff0c;厚积而薄发。——苏轼《稼说送张琥》 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 一、NewsNow项目概览1. 项目核心亮点2. 技术架构特点 二、核心功能深度解析1. 智能新…

论文阅读笔记——FLOW MATCHING FOR GENERATIVE MODELING

Flow Matching 论文 扩散模型&#xff1a;根据中心极限定理&#xff0c;对原始图像不断加高斯噪声&#xff0c;最终将原始信号破坏为近似的标准正态分布。这其中每一步都构造为条件高斯分布&#xff0c;形成离散的马尔科夫链。再通过逐步去噪得到原始图像。 Flow matching 采取…

【leetcode】02.07. 链表相交

链表相交 题目代码1. 计算两个链表的长度2. 双指针 题目 02.07. 链表相交 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 代码 …

文字转图片的字符画生成工具

软件介绍 今天要介绍的这款软件可以将文字转换成图片的排列形式&#xff0c;非常适合需要将文字图形化的场景&#xff0c;建议有需要的朋友收藏。 软件名称与用途 这款软件名为《字符画大师》&#xff0c;是一款在网吧等场所非常流行的聊天辅助工具&#xff0c;其主要功能就…

Bitlocker密钥提取之SYSTEM劫持

该漏洞编号CVE-2024-20666&#xff0c;本文实现复现过程&#xff0c;Windows系统版本如下 简介 从Windows10&#xff08;th1&#xff09;开始&#xff0c;微软在winload模块中&#xff0c;增加了systemdatadevice字段值的获取&#xff0c;该字段值存储在BCD引导配置文件中。当…

明场检测与暗场检测的原理

知识星球里的学员问&#xff1a;明场检测与暗场检测原理上有什么区别&#xff1f; 如上图&#xff0c; 明场检测&#xff08;Bright-field Inspection&#xff09; 工作原理&#xff1a; 光线从近乎垂直照射到样品表面。 如果表面平整、无缺陷&#xff0c;光线会直接反射回镜…

STL解析——vector的使用及模拟实现

目录 1.使用篇 1.1默认成员函数 1.2其他常用接口 2.模拟实现 2.1源码逻辑参考 2.2基本函数实现 2.3增 2.4删 2.5迭代器失效 2.6拷贝构造级其他接口 2.7赋值运算符重载(现代写法) 2.8深层次拷贝优化 3.整体代码 在C中vector算正式STL容器&#xff0c;功能可以类比于…

day2实训

实训任务1 FTPASS wireshark打开 实训任务2 数据包中的线索 解码的图片 实训任务3 被嗅探的流量 过滤http&#xff0c;追踪post的http流 实训任务6 小明的保险箱 winhex打开

Window10+ 安装 go环境

一、 下载 golang 源码&#xff1a; 去官网下载&#xff1a; https://go.dev/dl/ &#xff0c;当前时间&#xff08;2025-05&#xff09;最新版本如下: 二、 首先在指定的磁盘下创建几个文件夹 比如在 E盘创建 software 文件夹 E:\SoftWare,然后在创建如下几个文件夹 E:\S…

8.5 Q1|广州医科大学CHARLS发文 甘油三酯葡萄糖指数累积变化与 0-3期心血管-肾脏-代谢综合征人群中风发生率的相关性

1.第一段-文章基本信息 文章题目&#xff1a;Association between cumulative changes of the triglyceride glucose index and incidence of stroke in a population with cardiovascular-kidney-metabolic syndrome stage 0-3: a nationwide prospective cohort study 中文标…

重读《人件》Peopleware -(13)Ⅱ 办公环境 Ⅵ 电话

当你开始收集有关工作时间质量的数据时&#xff0c;你的注意力自然会集中在主要的干扰源之一——打进来的电话。一天内接15个电话并不罕见。虽然这看似平常&#xff0c;但由于重新沉浸所需的时间&#xff0c;它可能会耗尽你几乎一整天的时间。当一天结束时&#xff0c;你会纳闷…

ARXML解析与可视化工具

随着汽车电子行业的快速发展,AUTOSAR标准在车辆软件架构中发挥着越来越重要的作用。然而,传统的ARXML文件处理工具往往存在高昂的许可费用、封闭的数据格式和复杂的使用门槛等问题。本文介绍一种基于TXT格式输出的ARXML解析方案,为开发团队提供了一个高效的替代解决方案。 …