|
# -*- coding: utf-8 -*-
'''
# @time: 2023/2/19 15:14
# Author: DFTL
# @File: train_0219.py
'''
-
- import argparse
- import time
- import os
- import torch
- import json
- from datetime import datetime
- from transformers import BertTokenizer,BertModel
- # from Models.Duiqi import MyModule
- import torch.nn as nn
- from tqdm import tqdm
- from utils import Dataset_df
- from torch.utils.data import DataLoader
- from torch.optim import SGD,Adam
- from torch.nn import CrossEntropyLoss
- from Module import MyModule_v1,MyModuel_v2,MyModuel_v3,MyModuel_v4
- import numpy as np
- # from sklearn.metrics import classification_report
- from Fusion_block import Fusion_Module
-
# Run on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Display names for the 12 classes; kept for the (currently commented-out)
# sklearn classification_report call in classifier_test.
target_names = ['class 0', 'class 1', 'class 2','class 3','class 4','class 5','class 6','class 7','class 8','class 9','class 10','class 11']
-
def getKnowledge(knowledge_dict, labels):
    """Look up the three knowledge triples for each label.

    Each triple (subject, relation, object) is flattened into a single
    string by concatenation.

    Args:
        knowledge_dict: mapping from class-id string to a list of at least
            three triples, each a sequence of three strings.
        labels: iterable of scalar tensors (class ids).

    Returns:
        Three lists of strings, one entry per label: the first, second and
        third flattened triple respectively.
    """
    first, second, third = [], [], []

    for lab in labels:
        # Class id tensor -> dict key string.
        triples = knowledge_dict[str(lab.int().item())]
        flattened = ["".join(tr[:3]) for tr in triples[:3]]
        first.append(flattened[0])
        second.append(flattened[1])
        third.append(flattened[2])

    return first, second, third
-
def getTokenizer(bert_base_chinese_path):
    """Load the HuggingFace tokenizer and BERT encoder from a local path.

    Returns:
        (tokenizer, bert) — the BERT model is moved to the module-level
        `device` before being returned.
    """
    tok = BertTokenizer.from_pretrained(bert_base_chinese_path)
    encoder = BertModel.from_pretrained(bert_base_chinese_path).to(device)
    return tok, encoder
-
def getEncodedtext(triples, max_len, tokenizer, bert):
    """Encode one triple string with BERT and return its pooled output.

    Args:
        triples: a single text string (flattened triple) to encode.
        max_len: truncation length for the tokenizer.
        tokenizer, bert: objects produced by getTokenizer().

    Returns:
        The pooler output of BERT for the text — presumably shape
        (1, hidden_size); TODO confirm against BertModel's return contract.
    """
    token_ids = tokenizer(triples, max_length=max_len, truncation=True)['input_ids']
    batch = torch.tensor([token_ids], dtype=torch.long).to(device)
    # Index [1] selects the pooled ([CLS]-based) representation.
    return bert(batch)[1]
-
def pre_train(model, loss_, optimizer, train_dataset, args):
    """Pre-train `model` on image batches plus three BERT-encoded triples.

    For each batch, the three knowledge triples of every label are encoded
    with a frozen-weights pretrained BERT; the model returns two losses
    which are combined 0.7/0.3 and back-propagated. The full model object
    is saved to Weights/model.pt whenever the epoch-average loss improves.

    Args:
        model: the network; called as model(loss_, label, image, e1, e2, e3)
            and expected to return (loss1, loss2).
        loss_: loss criterion forwarded to the model.
        optimizer: optimizer over the model's parameters.
        train_dataset: DataLoader yielding (image, label) batches.
        args: parsed CLI namespace (pretrained_bert, knowledge_path,
            max_len, epoch).
    """
    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    loss_min = float('inf')
    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0.0

        for i, data in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            # Image batch and its class labels.
            image, label = data
            image = image.to(device)
            label = label.to(device)

            # Three flattened knowledge triples per label.
            t1, t2, t3 = getKnowledge(knowledge_dict, label)

            # BERT-encode each triple string: one (1, hidden) vector per
            # sample, concatenated along the batch dimension.
            encodedevent1 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t1]).to(device)
            encodedevent2 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t2]).to(device)
            encodedevent3 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t3]).to(device)

            loss1, loss2 = model(loss_, label, image, encodedevent1, encodedevent2, encodedevent3)

            # Weighted combination of the model's two loss terms.
            loss = 0.7 * loss1 + 0.3 * loss2

            loss.backward()
            optimizer.step()

            # .item() detaches the scalar so the running total does not
            # keep every step's tensor (and device memory) alive.
            step_loss = loss.item()
            epoch_loss += step_loss
            print('\r', 'step: ', i, ' loss: ', step_loss, end='', flush=True)

        avg_loss = epoch_loss / len(train_dataset)
        print('step_loss:{} \n'.format(avg_loss))

        # Save the whole model whenever the epoch-average loss improves;
        # the save message is only printed when a save actually happens.
        if avg_loss < loss_min:
            loss_min = avg_loss
            torch.save(model, os.path.join("Weights", "model.pt"))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
-
def pre_train_v1(model, loss_, optimizer, train_dataset, args):
    """Pre-train `model` using ONE BERT-encoded knowledge triple per sample.

    v1 variant: the model returns a single loss. The full model object is
    saved to `args.weights_path/<epoch>_<loss>.pt` whenever the
    epoch-average loss improves.

    Args:
        model: called as model(loss_, label, image, encodedevent1) and
            expected to return a scalar loss tensor.
        loss_: loss criterion forwarded to the model.
        optimizer: optimizer over the model's parameters.
        train_dataset: DataLoader yielding (image, label) batches.
        args: parsed CLI namespace (pretrained_bert, knowledge_path,
            max_len, epoch, weights_path).
    """
    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    loss_min = float('inf')
    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0.0

        for i, data in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            # Image batch and its class labels.
            image, label = data
            image = image.to(device)
            label = label.to(device)

            # Only the first of the three triples is used in v1.
            t1, _, _ = getKnowledge(knowledge_dict, label)

            # BERT-encode each triple string: one (1, hidden) vector per
            # sample, concatenated along the batch dimension.
            encodedevent1 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t1]).to(device)

            loss = model(loss_, label, image, encodedevent1)

            loss.backward()
            optimizer.step()

            # .item() detaches the scalar so the running total does not
            # keep every step's tensor (and device memory) alive.
            step_loss = loss.item()
            epoch_loss += step_loss
            print('\r', 'step: ', i, ' loss: ', step_loss, end='', flush=True)

        avg_loss = epoch_loss / len(train_dataset)
        print('step_loss:{} \n'.format(avg_loss))

        # Save whenever the epoch-average loss improves. avg_loss is a
        # plain float now, so the checkpoint filename no longer contains
        # a "tensor(...)" repr (the original used str(loss_min.data)).
        if avg_loss < loss_min:
            loss_min = avg_loss
            weight_name = str(epoch) + '_' + str(loss_min) + '.pt'
            torch.save(model, os.path.join(args.weights_path, weight_name))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
-
def pre_train_v2(model, loss_, optimizer, train_dataset, args):
    """Pre-train `model` with one encoded triple; v2 returns two losses.

    Like pre_train_v1 but the model yields (loss1, loss2), combined
    0.7/0.3. The full model object is saved to
    `args.weights_path/<epoch>_<loss>.pt` whenever the epoch-average loss
    improves.

    Args:
        model: called as model(loss_, label, image, encodedevent1) and
            expected to return a pair of scalar loss tensors.
        loss_: loss criterion forwarded to the model.
        optimizer: optimizer over the model's parameters.
        train_dataset: DataLoader yielding (image, label) batches.
        args: parsed CLI namespace (pretrained_bert, knowledge_path,
            max_len, epoch, weights_path).
    """
    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    loss_min = float('inf')
    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0.0

        for i, data in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            # Image batch and its class labels.
            image, label = data
            image = image.to(device)
            label = label.to(device)

            # Only the first of the three triples is used in v2.
            t1, _, _ = getKnowledge(knowledge_dict, label)

            # BERT-encode each triple string: one (1, hidden) vector per
            # sample, concatenated along the batch dimension.
            encodedevent1 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t1]).to(device)

            loss1, loss2 = model(loss_, label, image, encodedevent1)

            # Weighted combination of the model's two loss terms.
            loss = 0.7 * loss1 + 0.3 * loss2

            loss.backward()
            optimizer.step()

            # .item() detaches the scalar so the running total does not
            # keep every step's tensor (and device memory) alive.
            step_loss = loss.item()
            epoch_loss += step_loss
            print('\r', 'step: ', i, ' loss: ', step_loss, end='', flush=True)

        avg_loss = epoch_loss / len(train_dataset)
        print('step_loss:{} \n'.format(avg_loss))

        # Save whenever the epoch-average loss improves. avg_loss is a
        # plain float, so the checkpoint filename no longer contains a
        # "tensor(...)" repr (the original formatted the raw tensor).
        if avg_loss < loss_min:
            loss_min = avg_loss
            weight_name = str(epoch) + '_' + str(loss_min) + '.pt'
            torch.save(model, os.path.join(args.weights_path, weight_name))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
-
def classifier_test(model, criterion, test_dataset, args):
    """Evaluate the classifier head over `test_dataset`.

    A forward hook on `model.classifier` captures the classification-head
    logits; predictions are the argmax over classes. The three text inputs
    are random placeholders (the model signature expects three event
    embeddings, but they are not evaluated here — TODO confirm this is the
    intended test protocol).

    Args:
        model: network with a `.classifier` submodule; called as
            model(criterion, label, image, text1, text2, text3).
        criterion: loss criterion forwarded to the model.
        test_dataset: DataLoader yielding (image, label) batches.
        args: parsed CLI namespace (knowledge_path).

    Returns:
        (y_true, y_pred) numpy arrays of true labels and predictions
        (None, None if no full batch was processed).
    """
    model.eval()

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    features_out = []

    def hook(module, inputs, output):
        # Capture the classification-head logits of the current forward.
        features_out.append(output)

    # Register the hook ONCE and keep its handle. The original registered
    # it inside the batch loop, stacking one extra hook per batch and
    # never removing any of them.
    handle = model.classifier.register_forward_hook(hook)

    y_pred_test = None
    y_true_test = None

    try:
        with torch.no_grad():
            for data in tqdm(test_dataset):
                image, label = data
                image = image.to(device)
                label = label.to(device)

                # The placeholder text embeddings are fixed at batch size
                # 8, so stop at any final partial batch (as the original
                # did).
                if image.shape[0] != 8:
                    break

                # Random stand-ins for the three encoded knowledge events,
                # moved to the same device as the model inputs.
                text = torch.randn(8, 768).to(device)
                text2 = torch.randn(8, 768).to(device)
                text3 = torch.randn(8, 768).to(device)

                model(criterion, label, image, text, text2, text3)

                # .cpu() before .numpy() so this also works on CUDA.
                logits = features_out.pop()
                pred = np.argmax(logits.detach().cpu().numpy(), axis=1)
                truth = label.detach().cpu().numpy()

                if y_pred_test is None:
                    y_pred_test = pred
                    y_true_test = truth
                else:
                    y_pred_test = np.concatenate((y_pred_test, pred))
                    y_true_test = np.concatenate((y_true_test, truth))
    finally:
        handle.remove()

    # print(classification_report(y_pred_test, y_true_test, target_names=target_names))
    return y_true_test, y_pred_test
-
def train(model, pretrained_model, optimizer, train_dataset, args):
    """Train the fusion `model` on features captured from `pretrained_model`.

    Forward hooks on the pretrained model's `semantic_token` and
    `event_embeding` layers capture the image semantic tokens and the
    three event embeddings of each batch; those are fed to the fusion
    model, trained with cross-entropy. Saves the fusion model to
    Weights/model.pt whenever the epoch-average loss improves.

    Args:
        model: fusion network; called as model(image_feat, e1, e2, e3).
        pretrained_model: pretrained network run only to trigger the hooks.
        optimizer: optimizer (NOTE(review): the caller builds it over the
            pretrained model's parameters — confirm which network it is
            meant to update).
        train_dataset: DataLoader yielding (image, label) batches.
        args: parsed CLI namespace (pretrained_bert, knowledge_path,
            max_len, epoch).
    """
    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)

    loss_ = CrossEntropyLoss()

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    # Hooks must be registered exactly ONCE (re-registering per batch
    # would stack duplicate hooks).
    features = []
    feature_event = []

    def hook_image(net, inputs, output):
        features.append(output)

    def hook_event(net, inputs, output):
        feature_event.append(output)

    pretrained_model.semantic_token.register_forward_hook(hook_image)
    pretrained_model.event_embeding.register_forward_hook(hook_event)

    loss_min = float('inf')
    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0.0

        for i, data in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            # Image batch and its class labels.
            image, label = data
            image = image.to(device)
            label = label.to(device)

            t1, t2, t3 = getKnowledge(knowledge_dict, label)

            # BERT-encode each triple string: one (1, hidden) vector per
            # sample, concatenated along the batch dimension.
            encodedevent1 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t1]).to(device)
            encodedevent2 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t2]).to(device)
            encodedevent3 = torch.cat([getEncodedtext(t, args.max_len, tokenizer, bert) for t in t3]).to(device)

            # Forward pass is run only to trigger the hooks; the returned
            # losses are not used here (the original bound them to an
            # unused `losses` variable).
            pretrained_model(loss_, label, image, encodedevent1, encodedevent2, encodedevent3)

            output = model(features[0], feature_event[0], feature_event[1], feature_event[2])
            loss = loss_(output, label.long())

            loss.backward()
            optimizer.step()

            # Clear the captured features so the next batch starts clean.
            features.clear()
            feature_event.clear()

            # .item() detaches the scalar so the running total does not
            # keep every step's tensor (and device memory) alive.
            step_loss = loss.item()
            epoch_loss += step_loss
            print('\r', 'step: ', i, ' loss: ', step_loss, end='', flush=True)

        avg_loss = epoch_loss / len(train_dataset)
        print('step_loss:{} \n'.format(avg_loss))

        # NOTE(review): this overwrites the same path pre_train uses —
        # confirm Weights/model.pt is the intended destination.
        if avg_loss < loss_min:
            loss_min = avg_loss
            torch.save(model, os.path.join("Weights", "model.pt"))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
-
def get_E(model, args):
    """Compute and cache an embedding for every knowledge event.

    Runs each of the three triples of every class through `model` with
    placeholder image/label inputs, capturing the output of
    `model.event_embeding` via a forward hook. The stacked, squeezed
    result is saved to Weights/all_event.pt and returned.

    Args:
        model: network with an `event_embeding` submodule; called as
            model(loss, label, image, e1, e2, e3).
        args: parsed CLI namespace (pretrained_bert, knowledge_path,
            max_len).

    Returns:
        Tensor of all event embeddings, squeezed — presumably shape
        (3 * num_classes, hidden); TODO confirm against the model.
    """
    tokenizer, bert = getTokenizer(args.pretrained_bert)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    # Dummy inputs: the forward pass is only needed to fire the hook.
    placeholder_image = torch.randn(1, 32, 27, 27).to(device)
    placeholder_label = torch.randint(0, 12, [1]).to(device)
    placeholder_loss = CrossEntropyLoss()

    # Renamed from `list` — the original shadowed the builtin.
    event_vectors = []
    features_event = []

    def hook1(net, inputs, output):
        features_event.append(output)

    handle = model.event_embeding.register_forward_hook(hook1)

    for k, v in knowledge_dict.items():
        # v holds three triples; flatten each into one string and encode.
        # (The original wrapped each single result in torch.cat, a no-op.)
        encodedevent1 = getEncodedtext(v[0][0] + v[0][1] + v[0][2], args.max_len, tokenizer, bert).to(device)
        encodedevent2 = getEncodedtext(v[1][0] + v[1][1] + v[1][2], args.max_len, tokenizer, bert).to(device)
        encodedevent3 = getEncodedtext(v[2][0] + v[2][1] + v[2][2], args.max_len, tokenizer, bert).to(device)

        model(placeholder_loss, placeholder_label, placeholder_image, encodedevent1, encodedevent2, encodedevent3)

        # One hook capture per event embedding of this class.
        event_vectors.append(features_event[0])
        event_vectors.append(features_event[1])
        event_vectors.append(features_event[2])

        features_event.clear()

    # Remove the hook once all events are cached (the original leaked it).
    handle.remove()

    encodedevent = torch.squeeze(torch.stack(event_vectors, dim=0))

    torch.save(encodedevent, "Weights/all_event.pt")
    return encodedevent
-
def search(E, I):
    """Retrieve, for each sample, the 3 event embeddings most similar to it.

    Args:
        E: (num_events, d) matrix of event embeddings.
        I: (batch, d) matrix of image/semantic embeddings.

    Returns:
        Three (batch, d) tensors: for every row of I, the most,
        second-most and third-most similar rows of E.
    """
    # Similarity scores; softmax keeps every weight positive as in the
    # original formulation (it is monotone, so the ranking is unchanged).
    W = torch.softmax(torch.mm(I, E.t()), dim=1)  # (batch, num_events)

    # torch.topk replaces the original's "argmax, zero the winner, argmax
    # again" loop — same result for distinct scores, one pass, and no
    # shadowing of the builtin `list`.
    _, idx = torch.topk(W, k=3, dim=1)  # (batch, 3)

    return E[idx[:, 0], :], E[idx[:, 1], :], E[idx[:, 2], :]
-
-
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Model Controller')

    # Run mode: 'pre_train' trains MyModule_v1; 'test' evaluates the
    # classifier head; any other value trains Fusion_Module on top of the
    # loaded pretrained model.
    parser.add_argument('--mode',type=str,default='pre_train',help='pre_train/test/train')

    parser.add_argument('--weights_path', type=str, default='/model', help='the path saving weights')
    parser.add_argument('--image_path',type=str,default="/dataset/X_train.npy")
    parser.add_argument('--label_path', type=str, default="/dataset/Y_train.npy")
    parser.add_argument('--knowledge_path',type=str, default="/code/DataSet/Knowledge_v2.json")

    parser.add_argument('--val_image_path',type=str, default='',help='val_imageset path')
    parser.add_argument('--val_label_path', type=str, default='', help='val_labelset path')

    # Max token length for BERT encoding of knowledge triples.
    parser.add_argument('--max_len', type=int, default=300)
    parser.add_argument('--pretrained_bert',type=str,default='/dataset/bert-base-chinese')
    # parser.add_argument('--out_dim',type=int,)

    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
    parser.add_argument('--epoch', type=int, default=200, help='train_epochs')

    args = parser.parse_args()

    now_time = datetime.now()
    time_str = datetime.strftime(now_time, '%m-%d_%H-%M-%S')
    print(time_str)

    # model = MyModule(args).to(device)
    model = MyModule_v1().to(device)

    # '''cosine-similarity loss (unused)'''
    # criterion = nn.CosineEmbeddingLoss(margin=0.3)
    # target = torch.tensor([[1,-1]], dtype=torch.float)#, reduction='sum'
    criterion = CrossEntropyLoss()


    '''optimizer for parameter updates'''
    # optimizer = SGD(model.parameters(),lr=args.lr)
    # optimizer1 = Adam(model.parameters(),lr=args.lr)
    # NOTE(review): optimizer1 is built over `model` (MyModule_v1)
    # parameters, but in 'train' mode the network being trained is
    # `new_model` (Fusion_Module), whose parameters are never handed to
    # this optimizer — confirm this is intended.
    optimizer1 = Adam(model.parameters(), lr=args.lr)

    mydataset = Dataset_df.MyDataset(args.image_path,args.label_path)
    data_loader = DataLoader(dataset=mydataset,batch_size=args.batch, shuffle=True, pin_memory=True)

    '''dispatch on run mode'''
    if args.mode == 'pre_train':

        print("pre_train")
        pre_train_v1(model,criterion,optimizer1,data_loader,args)
    elif args.mode == 'test':

        print('test')
        # Load the pretrained weights (the checkpoint stores a whole model
        # object, hence .state_dict() on the loaded object).
        model.load_state_dict(torch.load(r"Weights/model_0221.pt", map_location=device).state_dict())

        classifier_test(model, criterion, data_loader, args)
    else:

        print('train')
        new_model = Fusion_Module().to(device)
        model.load_state_dict(torch.load(r"Weights/model_0221.pt", map_location=device).state_dict())
        train(new_model,model,optimizer1,data_loader,args)
|