This repository contains code for finetuning ChatGLM-6B using low-rank adaptation (LoRA). We also provide finetuned weights.
The minimum required GPU memory is 24 GB; an RTX 3090 is sufficient for training.
train.py

```python
import loralib as lora
import lora_utils.insert_lora
import dataset.GLM as GLM_Data
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

device = 'cuda'
checkpoint = "THUDM/chatglm-6b"

# Load the base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

# Insert LoRA layers into the base model
lora_config = {
    'r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'enable_lora': [True, True, True],
}
model = lora_utils.insert_lora.get_lora_model(model, lora_config)
### trainable_params: 22020096 (0.35%), non_trainable_params: 6255206400

# Build the DataLoader from prompt-completion pairs
pairs = [{'prompt': 'Hello!', 'completion': 'Hi! This is ChatGLM.'}]
pairs_encoded = GLM_Data.encode_pairs(pairs, tokenizer)
train_dataset = GLM_Data.GLMDataset(pairs_encoded)
train_dataloader = DataLoader(dataset=train_dataset, collate_fn=GLM_Data.collate_fn, shuffle=True, batch_size=1)

# Run a single forward/backward pass
model.half().to(device)
batch = {k: v.to(device) for k, v in next(iter(train_dataloader)).items()}
outputs = model(**batch)
outputs.loss.backward()
```
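The snippet above runs only a single forward/backward pass. A complete single-GPU training loop might look like the following sketch; the optimizer choice, learning rate, and epoch count are illustrative assumptions, not the repository's exact settings (see train.py / train_new.py for the actual loop):

```python
import torch

# Only the LoRA parameters remain trainable after get_lora_model,
# so filtering on requires_grad gives the optimizer just those weights.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-5  # lr is an assumption
)

for epoch in range(3):  # epoch count is an assumption
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```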
Use the accelerate CLI tool to launch multiprocess / distributed training:

```bash
accelerate launch --config_file config/default_config.yaml train_new.py
```
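The contents of config/default_config.yaml are repository-specific; if you need to generate your own, run `accelerate config` interactively. It produces a YAML along these lines (the values below, such as 4 processes with fp16, are illustrative assumptions, not the repo's file):

```yaml
# Illustrative accelerate config; values are assumptions
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: fp16
num_machines: 1
num_processes: 4  # one process per GPU
```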
Like OpenAI's fine-tuning API, the data should be in the following structure:

```
[
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
    ...,
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
]
```
It is a list of prompt-completion pairs.
Here we use the Stanford Alpaca dataset as an example for finetuning, and we provide the corresponding finetuned weights.
Example line:

```python
{'prompt': 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the movie genres from the given context.\n\n### Input:\nThis movie tells the story of two brothers who were both born with magical powers.\n\n### Response:', 'completion': 'Fantasy'}
```
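For reference, Alpaca records (which use instruction / input / output fields) can be mapped into this prompt-completion format along the lines below; the helper is a sketch based on the standard Alpaca prompt templates, not the repository's exact preprocessing code:

```python
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. Write a response that appropriately "
    "completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
)

def alpaca_to_pairs(records):
    """Map Alpaca {'instruction', 'input', 'output'} records to prompt-completion pairs."""
    pairs = []
    for r in records:
        template = PROMPT_WITH_INPUT if r.get('input') else PROMPT_NO_INPUT
        pairs.append({'prompt': template.format(**r), 'completion': r['output']})
    return pairs
```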
Training on the Stanford Alpaca dataset should take less than 30 minutes per epoch on 4×V100 GPUs.
You may observe a typical training loss curve (figure in the repository; the exact curve varies with the dataset):
```python
lora_config = {
    'r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'enable_lora': [True, True, True],
}
```

Using the above LoRA config, we get trainable_params: 22020096 (0.35%) and non_trainable_params: 6255206400.
```python
# Save only the LoRA parameters, not the full 6B-parameter model
torch.save(lora.lora_state_dict(model), 'path/to/lora_checkpoint.pt')
# Load them back into a model that already has LoRA layers inserted;
# strict=False is required because the checkpoint contains only LoRA weights
model.load_state_dict(torch.load('path/to/lora_checkpoint.pt'), strict=False)
```
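After loading the LoRA weights, inference works as with the stock model. ChatGLM-6B's remote code exposes a chat() helper, so a quick sanity check might look like this sketch (the query string is arbitrary):

```python
model.half().to(device).eval()

# chat() returns the model's reply plus the updated dialogue history
response, history = model.chat(tokenizer, "Hello!", history=[])
print(response)
```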