本仓库实现了对于 ChatGLM2-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
下面以 ADGEN (广告生成) 数据集为例介绍代码的使用方法。
运行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
{
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
从 Google Drive 或者 Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen
目录放到本目录下。
运行以下指令进行训练:
bash train.sh
train.sh
中的 PRE_SEQ_LEN
和 LR
分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit
来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
在默认配置 quantization_bit=4
、per_device_train_batch_size=1
、gradient_accumulation_steps=16
下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size
的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
如果你想要从本地加载模型,可以将 train.sh
中的 THUDM/chatglm2-6b
改为你本地的模型路径。
如果需要进行全参数的 Finetune,需要安装 Deepspeed,然后运行以下指令:
bash ds_train_finetune.sh
在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM2-6B 模型以及 PrefixEncoder 的权重,因此需要指定 evaluate.sh
中的参数:
--model_name_or_path THUDM/chatglm2-6b
--ptuning_checkpoint $CHECKPOINT_PATH
如果是,只需要跟之前一样设定 model_name_or_path
:
--model_name_or_path $CHECKPOINT_PATH
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
./output/adgen-chatglm2-6b-pt-128-2e-2/generated_predictions.txt
。
首先载入Tokenizer:
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
注意你可能需要将 pre_seq_len
改成你训练时的实际值。如果你是从本地加载模型的话,需要将 THUDM/chatglm2-6b
改成本地的模型路径(注意不是checkpoint路径)。
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
之后根据需求可以进行量化,也可以直接使用:
# Comment out the following line if you don't use quantization
model = model.quantize(4)
model = model.cuda()
model = model.eval()
response, history = model.chat(tokenizer, "你好", history=[])
你也可以直接运行支持加载 P-Tuning v2 checkpoint 的 web demo
bash web_demo.sh
可能需要修改 web_demo.sh 的内容以符合你实际的 checkpoint 情况。
修改 train.sh
和 evaluate.sh
中的 train_file
、validation_file
和test_file
为你自己的 JSON 格式数据集路径,并将 prompt_column
和 response_column
改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 max_source_length
和 max_target_length
来匹配你自己的数据集中的最大输入输出长度。
如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如以下是一个三轮对话的训练数据:
{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
训练时需要指定 --history_column
为数据中聊天历史的 key(在此例子中是 history
),将自动把聊天历史拼接。要注意超过输入长度 max_source_length
的内容会被截断。
可以参考以下指令:
bash train_chat.sh
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
pages={61--68},
year={2022}
}
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》