|
- import os
- # with open("../code/requirements.txt", 'r') as f:
- # lines = f.readlines()
- # for line in lines:
- # cmd = f'pip install {line}'
- # print(cmd)
- # # cmd='pip install -r ./requirements.txt'
- # os.system(cmd)
- # print('安装结束')
- import torch
- import transformers
- from transformers import BartForConditionalGeneration
- from utils.config import init_argument
- from utils.tokenizer import T5PegasusTokenizer
- from utils.dataset import prepare_data, create_data, get_data_token
- from utils.train_model import train_model
- from utils.data_first_deal import add_key_word
- from sentence_transformers import SentenceTransformer, util
-
- from ltp import LTP
- print('导包结束')
-
def get_data_list(base_path='../dataset/data/rank_data/token/', term='train'):
    """Load and concatenate tokenized data from every file under *base_path*.

    Parameters
    ----------
    base_path : str
        Directory holding the tokenized data files. The default preserves
        the originally hard-coded location, so existing callers are unaffected.
    term : str
        Split name forwarded to ``get_data_token`` (default ``'train'``).

    Returns
    -------
    list
        All records from every file, concatenated in sorted-filename order.
    """
    # Sort the listing so the resulting data order is deterministic across
    # runs and machines — os.listdir() returns entries in arbitrary order.
    file_list = sorted(os.listdir(base_path))
    print('【文件列表】', file_list)

    data_list = []
    for file_name in file_list:
        # os.path.join tolerates a base_path with or without a trailing
        # slash, unlike the plain string concatenation it replaces.
        data_list.extend(get_data_token(os.path.join(base_path, file_name), term))

    return data_list
-
if __name__ == '__main__':
    # Load configuration / command-line arguments for the run.
    args = init_argument()
    # Pick the training device; args.device is presumably a GPU index
    # string like '0' — TODO confirm against init_argument().
    device = 'cuda:' + args.device if torch.cuda.is_available() else 'cpu'
    # Load the tokenizer from the project vocabulary file.
    tokenizer = T5PegasusTokenizer(vocab_file='../dataset/vocabulary/vocab.txt')
    # tokenizer = tokenizer_raw
    print('ltp模型')
    # LTP NLP toolkit model, used below for keyword extraction and passed
    # into data preparation / training.
    ltp = LTP('../dataset/LTP/base-tgz-extracted/')
    # Sentence-embedding model; NOTE(review): sen_model is overwritten with
    # an empty list further down, so it is only used up to prepare_data().
    sen_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
    # sen_model = ''
    add_key_word(args, ltp)

    # Load the training and validation datasets.
    # print('加载训练集...')
    # train_data = prepare_data(args, args.train_data, tokenizer, tokenizer_raw, ltp, sen_model, term='train')
    print('验证数据集')
    dev_data = prepare_data(args, args.dev_data, tokenizer, ltp, sen_model, term='dev')
    # quit()
    # Load the pretrained model: prefer a previously saved full-model
    # checkpoint; otherwise build a fresh BART from the config file.
    # base_model_path = os.path.join(args.model_dir, 'summary_model.bin')
    base_model_path = '../model/bart_model.bin'
    if os.path.exists(base_model_path):
        print('已存在预训练模型。')
        # NOTE(review): torch.load of a whole pickled model executes
        # arbitrary code from the file — only load trusted checkpoints.
        model = torch.load(base_model_path).to(device)
        # torch.save(model, base_model_path)
    else:
        # model = BartForConditionalGeneration.from_pretrained(args.pretrain_model).to(device)
        # MT5ForConditionalGeneration.generate
        model_config = transformers.models.bart.BartConfig.from_json_file(args.model_config)
        model = BartForConditionalGeneration(config=model_config).to(device)

    # NOTE(review): these overrides force single-device training, making
    # the DataParallel branch below unreachable — confirm this is intended.
    args.local_rank = 100
    args.data_parallel = False

    # Multi-GPU training (dead code while data_parallel is forced False above).
    if args.data_parallel and torch.cuda.is_available():
        device_ids = range(torch.cuda.device_count())
        print(device_ids)
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Train the model.
    # print("训练开始...")
    # print(train_data)
    # quit()
    # NOTE(review): this discards the SentenceTransformer loaded above;
    # sen_model is not used past this point, so presumably this frees it.
    sen_model = []

    # Gather the tokenized training data from disk.
    train_data_list = get_data_list()

    adam = torch.optim.AdamW(model.parameters(), lr=args.lr)
    train_model(model, adam, train_data_list, dev_data, tokenizer, ltp, device, args)
    print("训练结束!")
|