# -*- coding: utf-8 -*-
'''
# @time: 2023/2/19 15:14
# Author: DFTL
# @File: train_0219.py
'''

import argparse
import time
import os
import torch
import json
from datetime import datetime
from transformers import BertTokenizer, BertModel
from Models.Duiqi import MyModule
import torch.nn as nn
from tqdm import tqdm
from utils import Dataset_df
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss
from Module import MyModuel_v2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def getKnowledge(knowledge_dict, labels):
    """Look up three knowledge triples for each label in the batch."""
    kd1 = []
    kd2 = []
    kd3 = []

    for label in labels:
        # fetch the triples associated with this label
        triples = knowledge_dict[str(label.int().item())]

        # concatenate (head, relation, tail) into one string per triple
        kd1.append(triples[0][0] + triples[0][1] + triples[0][2])
        kd2.append(triples[1][0] + triples[1][1] + triples[1][2])
        kd3.append(triples[2][0] + triples[2][1] + triples[2][2])

    return kd1, kd2, kd3
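
# Illustrative only: Knowledge_v2.json itself is not part of this snippet. From the
# indexing in getKnowledge(), the file is assumed to map stringified label ids to a
# list of at least three [head, relation, tail] triples, e.g. (hypothetical values):
#   { "0": [["subject", "relation", "object"], ["...", "...", "..."], ["...", "...", "..."]], ... }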

def getTokenizer(bert_base_chinese_path):
    """Load the tokenizer and the pretrained BERT encoder from a local path."""
    tokenizer = BertTokenizer.from_pretrained(bert_base_chinese_path)
    bert = BertModel.from_pretrained(bert_base_chinese_path)
    # encoded_text = get_encoded_text(text, max_len, tokenizer, bert)
    return tokenizer, bert
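
# Note: from_pretrained() accepts either a hub id ('bert-base-chinese') or, as used
# here, a local directory; that directory is expected to contain the usual
# HuggingFace files (config.json, vocab.txt, pytorch_model.bin / model.safetensors).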

def getEncodedtext(triples, max_len, tokenizer, bert):
    """Encode one concatenated triple string with BERT and return its pooled vector."""
    # tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    tokenized = tokenizer(triples, max_length=max_len, truncation=True)
    tokens = tokenized['input_ids']
    # masks = tokenized['attention_mask']
    # text_len = len(tokens)

    token_ids = torch.tensor([tokens], dtype=torch.long).to(device)
    # masks = torch.tensor(masks, dtype=torch.bool)

    encoded_text = bert(token_ids)[1]

    return encoded_text
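
# Note: indexing the BertModel output with [1] gives the pooled [CLS] representation,
# shape (1, hidden_size), i.e. 768 for bert-base-chinese. The BERT parameters are
# never handed to the optimizer below, so wrapping the bert(...) call in
# torch.no_grad() would save memory; it is left unchanged here to keep the original
# behaviour.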

def train(model, loss_, optimizer, train_dataset, args):

    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)
    # tokenizer = tokenizer.to(device)
    bert = bert.to(device)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    loss_min = float('inf')
    for epoch in range(args.epoch):

        epoch_start = time.time()
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0

        i = 0

        for data in train_dataset:  # could be wrapped with tqdm for a progress bar

            optimizer.zero_grad()

            # image data
            image, label = data
            image = image.to(device)
            label = label.to(device)

            # three knowledge triples per sample, looked up by label
            kd1, kd2, kd3 = getKnowledge(knowledge_dict, label)

            '''encode the triple texts with the pretrained BERT'''
            encodedevent1 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in kd1]).to(device)
            encodedevent2 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in kd2]).to(device)
            encodedevent3 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in kd3]).to(device)

            loss1, loss2 = model(loss_, label, image, encodedevent1, encodedevent2, encodedevent3)

            # weighted combination of the two objectives
            loss = 0.7 * loss1 + 0.3 * loss2

            loss.backward()
            optimizer.step()
            i += 1
            print('\r', 'step: ', i, ' loss: ', loss.item(), end='', flush=True)

            # accumulate a plain float so the computation graph is not kept alive
            epoch_loss += loss.item()

        epoch_loss /= len(train_dataset)
        print('\nepoch_loss: {} | time: {:.1f}s'.format(epoch_loss, time.time() - epoch_start))

        '''save the model whenever the average epoch loss improves'''
        if epoch_loss < loss_min:
            loss_min = epoch_loss
            torch.save(model, os.path.join("Weights", "model_0221.pt"))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
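
# Since torch.save() above is called on the full module, the checkpoint can later be
# restored with torch.load() (a sketch, assuming the Weights/ directory exists and
# the MyModuel_v2 class definition is importable at load time):
#   model = torch.load(os.path.join("Weights", "model_0221.pt"), map_location=device)
#   model.eval()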

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Model Controller')

    parser.add_argument('--weights_path', type=str, default='', help='the path for saving weights')
    parser.add_argument('--image_path', type=str, default="DataSet/YRD_N12/X_train.npy")
    parser.add_argument('--label_path', type=str, default="DataSet/YRD_N12/Y_train.npy")
    parser.add_argument('--knowledge_path', type=str, default='DataSet/Knowledge_v2.json')

    parser.add_argument('--max_len', type=int, default=300)
    parser.add_argument('--pretrained_bert', type=str, default='/sunhuan/Knowledge_Graph/CasRelPyTorch-master/pre_trained/bert-base-chinese')
    # parser.add_argument('--out_dim', type=int)

    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--epoch', type=int, default=200, help='number of training epochs')

    args = parser.parse_args()

    now_time = datetime.now()
    time_str = datetime.strftime(now_time, '%m-%d_%H-%M-%S')
    print(time_str)

    # model = MyModule(args).to(device)
    model = MyModuel_v2().to(device)

    # '''cosine-similarity loss (earlier variant)'''
    # criterion = nn.CosineEmbeddingLoss(margin=0.3)
    # target = torch.tensor([[1, -1]], dtype=torch.float)  # , reduction='sum'
    criterion = CrossEntropyLoss()

    '''the optimizer that updates the model parameters'''
    # optimizer = SGD(model.parameters(), lr=args.lr)
    optimizer1 = Adam(model.parameters(), lr=args.lr)

    mydataset = Dataset_df.MyDataset(args.image_path, args.label_path)
    data_loader = DataLoader(dataset=mydataset, batch_size=args.batch, shuffle=True, pin_memory=True)

    train(model, criterion, optimizer1, data_loader, args)
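
# Example invocation (the paths shown are just the argparse defaults above; point
# --image_path / --label_path / --knowledge_path / --pretrained_bert at your own data):
#   python train_0219.py --batch 16 --lr 1e-4 --epoch 200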