新闻中心

EEPW首页>智能计算>设计应用> 拿LoRA代码来微调大模型

拿LoRA代码来微调大模型

作者:高焕堂 时间:2024-01-18 来源:电子产品世界 收藏


本文引用地址://m.amcfsurvey.com/article/202401/454967.htm

1 简介LoRA

上一期介绍了如何复用免费的源代码,来搭配企业的专有数据而训练出形形色色的自用小模型。免费代码既省成本、可靠、省算力、又自有IP,可谓取之不尽、用之不竭的资源,岂不美哉!

重头开始训练自己的小模型,是一条鸟语花香之路。然而,基于别人的预训练(Pre-trained),搭配自有数据而进行微调(Fine-tuning),常常更是一条康庄大道。

随着LLM 等日益繁荣发展,基于这些大模型的迁移学习(Transfer learning),将其预训练好的模型加以微调(Fine tune),来适应到下游的各项新任务,已经成为热门的议题。关于微调技术,其中LoRA 是一种资源消耗较小的训练方法,它能在较少训练参数时就得到比较稳定的效果。

由于LoRA 的外挂模型参数非常轻量,对于各个下游任务来说,只需要搭配特定的训练数据,并独立维护自身的LoRA 参数即可。在训练时可以冻结原模型( 如ResNet50 或MT5) 的既有参数,只需要更新较轻量的LoRA 参数即可,因而微调训练的效率很高。

LoRA 的全名是:Low-Rank Adaptation of Large Language Models( 及大语言模型的低阶适应)。使用这种LoRA 微调方法进行训练时,并不需要调整原( 大)模型的参数值( 图1 里的蓝色部分),而只需要训练LoRA 模型的参数( 图1 里的棕色部分)。

image.png

图1 LoRA的架构

(引自:https://heidloff.net/article/efficient-fine-tuning-lora/)

典型的LoRA 微调途径是,使用下游任务的数据来对< 原模型 + LoRA> 进行重新训练,让该协同模型的性能在该下游任务上表现出最佳效果。

2 简介ResNet50

ResNet50是很通用的AI模型,他擅长于图像的特征提取(Feature extraction),然后依据特征来进行分类(Classification)。所以,它能帮您瞬间探索任何一张图像的特征,然后帮您识别出图片里的人或物的种类。目前的ResNet50 可以准确地识别出1000 种人或物,如日常生活中常遇到的狗、猫、食物、汽车和各种家居物品等。

3 下载LoRA源代码

首先访问这个cccntu 网页,从Github 下载minLoRA源码 ( 图2)。

1705572670833303.png

图2 Github上的免费LoRA源码

然后,按下 就自动把minLoRA 源码下载到本机里了。接着,把所下载的源代码压缩檔解开,放置于Wibdows 本机的Python 工作区里,例如 /Python310/目录区里( 图3)。

image.png

图3 放置于本机的Python环境里

这样,就能先在本机里做简单的测试,例如创建模型并拿简单数据( 或假数据) 来测试,有助于提升成功的自信心。

4 展开微调训练

Step 1:准备训练&测试数据

首先,准备了/ox_lora_data/train/ 训练图像集,包含2 个类--- 水母(Jellyfish) 和蘑菇(Mushroom),各有12 张图像,如图4。

image.png

image.png

图4 12张图像实例

此外,也准备了/ox_lora_data/test/ 测试图像集,也是水母和蘑菇,各有8 张图像。

Step 2:准备ResNet50预训练模型

本范例从torchvision.models 里加载resnet50 预训练模型。这ResNet50 属于大模型,其泛化能力很好。然而,然而对于本范例的较少类的预测( 推论) 准确度就常显得不足。现在,就拿本范例的测试图像集,来检测一下。程序码如下:

# Lora_ResNet50_001_test.py

import torch

import torch.nn as nn

from torchvision import transforms

from torchvision.datasets import ImageFolder

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

path = ‘c:/ox_lora_data/’

#----------------------------------

model = models.resnet50(

w e i g h t s = m o d e l s . R e s N e t 5 0 _We i g h t s .

IMAGENET1K_V1)

#----------------------------------

def process_lx(labels, batch_size):

lx = labels.clone()

for i in range(batch_size):

if(labels[i]==0): lx[i]=107

elif(labels[i]==1): lx[i]=947

return lx

#----------------------------------

T = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor()

])

#----------------------------------

test_ds = ImageFolder(path + ‘test/’, transform=T)

test_dl = DataLoader(test_ds, batch_size=1)

model.eval()

with torch.no_grad():

j, m = (0, 0)

for idx, (image, la) in enumerate(test_dl):

labels = process_lx(la, 1)

pred = model(image)

k = torch.argmax(pred[0])

if(la[0]==0 and k==107):

j += 1

elif(la[0] == 1 and k==947):

m += 1

print(“n 水母(Jellyfish) 的正确辨识率:”, j / 8)

print(“n 蘑菇(Mushroom) 的正确辨识率:”, m / 8)

#------------------

#END

在本范例里,其图像分为2 个类:水母和蘑菇。所以在此程序里,其< 水母、蘑菇> 的类标签(Label)分别为:[0, 1]。而在ResNet50 预训练模型里,其<水母、蘑菇> 类标签分别为:[107, 947]。于是,使用process_lx() 函数,来把此程序里的类标签,转换为ResNet50 的类别标签值。在此范例里,我们拿测试数据集里的< 水母、蘑菇> 各8 张图像来给ResNet50 进行分类预测。执行时,输出如下:

-2

这显示出:蘑菇的预测准确度为:0.125,并不理想。亦即,可以观察到了,大模型ResNet50 在这范例里的下游任务上,其预测的准确度并不美好。于是,LoRA微调方法就派上用场了。

Step 3:定义LoRA模型,并展开协同训练兹回顾LoRA 的架构( 图1)。在刚才的范例里,我们加载的ResNet50 模型,就是上图里的Pretrained Weights( 即蓝色) 部分。现在,就准备添加LoRA 模型,也就是上图里的A 和B( 即棕色) 部分。

当我们把A&B 部分添加上去了,就能展开协同训练了。在协同训练时,我们会先冻结Pretrained Weights部分的参数,不去更改它;而只更新LoRA 的A&B 参数。一旦协同训练完成了,就会把LoRA 与ResNet50 的参数合并起来( 即上图右方的橘色部分。请来看看程序码:

# Lora_ResNet50_002_train.py

import numpy as np

import torch

import torch.nn as nn

from torchvision import transforms

from torchvision.datasets import ImageFolder

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

from functools import partial

import min_lora_model as Min_LoRA

import min_lora_utils as Min_LoRA_Util

path = ‘c:/ox_lora_data/’

#----------------------------------

# 把图片转换成Tensor

T = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor()

])

def process_lx(labels, batch_size):

lx = labels.clone()

for i in range(batch_size):

if(labels[i]==0): lx[i]=107

elif(labels[i]==1): lx[i]=947

return lx

#----------------------------------

model = models.resnet50(

weight s =mode l s .ResNet50_We ight s .

IMAGENET1K_V1)

#-------- 添加LoRA --------

my_lora_config = { nn.Linear: { “weight”: partial(

Min_LoRA.LoRAParametrization.from_linear,

rank=16),

}, }

#---- 把LoRA 参数添加到原模型 ------

Min_LoRA.add_lora(model, lora_config=my_lora_

config)

parameters = [

{ “params”: list(Min_LoRA_Util.get_lora_

params(model))}, ]

# 只更新LoRA 的Weights

optimizer = torch.optim.Adam(parameters, lr=1e-3)

loss_fn = nn.CrossEntropyLoss()

model.train()

bz = 4

train_ds = ImageFolder(path + ‘ train/ ’ ,

transform=T)

train_dl = DataLoader(train_ds, batch_size=bz,

shuffle=True)

length = len(train_ds)

#----------------------------------

print(‘n------ 外挂LoRA 模型, 协同训

练 ------’)

epochs = 25

for ep in range(epochs+1):

total_loss = 0

for idx, (images, la) in enumerate(train_dl):

labels = process_lx(la, bz)

pred = model(images)

loss = loss_fn(pred, labels)

loss.backward()

optimizer.step()

optimizer.zero_grad()

total_loss += loss.item() * bz

if(ep%5 == 0):

print(‘ ep=’, ep, ‘, loss=’, total_loss /

length )

#-------------- testing ---------------

test_ds = ImageFolder(path + ‘test/’, transform=T)

test_dl = DataLoader(test_ds, batch_size=1)

model.eval()

with torch.no_grad():

j, m = (0, 0)

for idx, (image, la) in enumerate(test_dl):

labels = process_lx(la, 1)

pred = model(image)

k = torch.argmax(pred[0])

if(la[0]==0 and k==107): j += 1

elif(la[0] == 1 and k==947): m += 1

print(“n 水母(Jellyfish) 的正确辨识率:”, j / 8)

print(“n 蘑菇(Mushroom) 的正确辨识率:”, m / 8)

#END

在此范例程序里, 把minLoRA 的源代码, 与ResNet50预训练模型结合,展开100 回合的微调协同训练。并输出如下:

1705572969519996.png

从上述的输出结果,于是我们可以观察到,当ResNet50 在未加挂LoRA 时,其< 蘑菇> 测试的预测准确率是:0.125。当我们完成协同训练100 回合之后,其预测准确度提升到:0.75,达到微调的目的了。

5 结束语

本文就ResNet50 为例,说明如何拿LoRA 源代码,来对ResNet50 进行微调。您已经发现到了,微调可以让ResNet50 更加贴心,满足您的需求。这种途径可以适用于各种大模型,例如MT5 大语言模型、以及StableDiffusion绘图大模型等。

(本文来源于《EEPW》2024.1-2)



评论


技术专区

关闭