|
import os

import torch
import transformers
from transformers import MBartForConditionalGeneration

from utils.config_mbart import init_argument
from utils.dataset_de import prepare_data, create_data, get_data_token
from utils.tokenizer import T5PegasusTokenizer
from utils.train_model_de import train_model

if __name__ == '__main__':
    # Load the training configuration / command-line arguments.
    args = init_argument()

    # Select the training device: the configured CUDA device when available,
    # otherwise fall back to CPU.
    device = 'cuda:' + args.device if torch.cuda.is_available() else 'cpu'

    # Load the tokenizer from the shared vocabulary file.
    tokenizer = T5PegasusTokenizer(vocab_file='../code/vocabulary/vocab.txt')

    # Load the validation dataset. The training data itself is streamed from
    # pre-tokenized files under `base_path` inside train_model.
    dev_data = prepare_data(args, args.dev_data, tokenizer, term='dev')

    # Load the model: resume from a saved checkpoint when one exists,
    # otherwise build a fresh MBart model from the JSON config file.
    base_model_path = os.path.join(args.model_dir, 'model_0.bin')
    if os.path.exists(base_model_path):
        print('已存在预训练模型。')
        # map_location keeps the load working even when the checkpoint was
        # saved on a GPU that is absent on this machine.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        model = torch.load(base_model_path, map_location=device).to(device)
    else:
        model_config = transformers.models.mbart.MBartConfig.from_json_file(args.model_config)
        model = MBartForConditionalGeneration(config=model_config).to(device)

    # NOTE(review): these hard-coded overrides force single-GPU training and
    # make the DataParallel branch below unreachable regardless of the
    # configured arguments — presumably a debugging leftover; confirm before
    # attempting multi-GPU runs.
    args.local_rank = 100
    args.data_parallel = False

    # Multi-GPU training (currently dead code while the override above stands).
    if args.data_parallel and torch.cuda.is_available():
        device_ids = list(range(torch.cuda.device_count()))
        print(device_ids)
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Train the model; tokenized training shards are read from this directory.
    base_path = '../dataset/defult_data/token/'
    adam = torch.optim.AdamW(model.parameters(), lr=args.lr)
    train_model(model, adam, base_path, dev_data, tokenizer, device, args)
    print("训练结束!")
|