Using RLHF with the trlx library to train a Pangu 2.6B Chinese dialogue model

Following the reinforcement learning from human feedback (RLHF) process behind ChatGPT, we developed an RLHF pipeline based on the Pangu-alpha 2.6B (GPU version) model. Our pipeline is modified from trlx, a reproduction of the code behind OpenAI's paper "Learning to Summarize from Human Feedback".

We use the Pangu-alpha 2.6B model as the base model and fine-tune it on dialogue corpora such as webtext via supervised fine-tuning (SFT) to obtain a dialogue version of Pangu-alpha. Annotators designed questions from 15 common domains to query the Pangu dialogue model, evaluated its outputs along four dimensions (applicability, specificity, correctness, and safety), and the collected human feedback was used to train a reward model (RM) that stands in for human evaluation. Finally, the classic RL algorithm PPO is used together with the RM to further train the SFT-stage Pangu model.
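The reward-model stage described above is typically trained with a pairwise ranking loss: for each prompt, the RM should score the human-preferred response higher than the rejected one. A minimal scalar sketch in plain Python (the function name and scalar rewards are illustrative, not this repo's API; real training operates on batched model outputs):

```python
import math

def pairwise_rm_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Negative log-sigmoid of the reward margin: the loss shrinks as the
    chosen response is scored increasingly higher than the rejected one."""
    margin = reward_chosen - reward_rejected
    # -log(sigmoid(margin)) == log(1 + exp(-margin)); log1p is numerically stable
    return math.log1p(math.exp(-margin))

# A correctly ordered pair (chosen scored higher) yields a small loss,
# a mis-ordered pair a large one.
good = pairwise_rm_loss(2.0, -1.0)
bad = pairwise_rm_loss(-1.0, 2.0)
```

Averaging this loss over many annotated comparison pairs teaches the RM to rank responses the way the annotators did.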
1). Set up the environment for the trlx library (see "trlx"):
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
pip install -e .
2). Download the Pangu-2.6B model:
https://huggingface.co/imone/pangu_2_6B
Save the model's .bin files to ./Pangu_chk
3). Prepare the SFT dataset (webtext as an example):
https://paperswithcode.com/dataset/webtext
A sample of the dialogue data is saved at: ./dialogue_dir/demo.json
A sample of the reward data is saved at: ./reward_data_dir/processed/demo.json
The figure below shows the annotation interface; for details on data annotation, see: PanGu-Dialog-HFDataset
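The demo files above define the expected data layout. As an illustration only, a loader for an assumed list-of-records JSON layout might look like the sketch below; the "prompt"/"response" keys are hypothetical, so check ./dialogue_dir/demo.json for the actual field names before adapting it:

```python
import json
from pathlib import Path

def load_samples(path: str):
    """Load dialogue samples from a JSON file.
    Assumes a list of objects with 'prompt'/'response' keys (hypothetical
    schema for illustration; verify against the repo's demo.json)."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    return [(r["prompt"], r["response"]) for r in records]

# Self-contained demo: write a tiny sample file, then load it back.
demo = [{"prompt": "你好", "response": "你好，请问有什么可以帮您？"}]
Path("demo_sample.json").write_text(json.dumps(demo, ensure_ascii=False), encoding="utf-8")
pairs = load_samples("demo_sample.json")
```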
Replace the ppo_models.py file under trlx/trainer/nn with the ppo_models.py in the replace folder.
The main modification is the loading of the Pangu model:
if "pangu" in config.lower():
    self.config = transformers.AutoConfig.from_pretrained(config, trust_remote_code=True)
    self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config, trust_remote_code=True)

gpt_branch_supported_archs = [
    "GPTJForCausalLM",
    "GPT2LMHeadModel",
    "GPTNeoForCausalLM",
    "GPTNeoXForCausalLM",
    "GPTPanguForCausalLM",
]
For compatibility with trlx, we modified the tokenizer to expose the same interface as the CPM tokenizer, which differs from the original Pangu tokenizer; the SPM model file is Pangu's. The tokenizer can be loaded with the following commands:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("./PanguTokenizer")
Or download it from the Hugging Face Hub:
tokenizer = AutoTokenizer.from_pretrained("Hanlard/Pangu_alpha")
1). Supervised fine-tuning (SFT):
cd sft/ && deepspeed train_SFT.py
2). Train the reward model:
cd reward_model/ && deepspeed train_reward_model.py
3). Reinforcement learning with PPO:
accelerate launch --config_file configs/default_accelerate_config.yaml trlx_pangu.py
Note: at least one V100 GPU is required.
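For reference, the heart of the PPO update in this last stage is the clipped surrogate objective; a minimal scalar sketch in plain Python (real training inside trlx operates on batched tensors per token, but the clipping logic is the same):

```python
import math

def ppo_clip_objective(logp_new: float, logp_old: float,
                       advantage: float, clip_eps: float = 0.2) -> float:
    """Clipped PPO surrogate for a single action.
    The probability ratio is clipped to [1-eps, 1+eps] so that one update
    cannot move the policy too far from the old (SFT-anchored) policy."""
    ratio = math.exp(logp_new - logp_old)
    clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    # Objective to maximize: the pessimistic (min) of the two terms.
    return min(ratio * advantage, clipped * advantage)
```

In RLHF, the advantage is derived from the RM's score (often with a KL penalty against the SFT policy), so maximizing this objective nudges the model toward responses the reward model prefers.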