|
- # -*- coding = utf-8 -*-
- '''
- # @time:2023/3/12 15:33
- # Author:DFTL
- # @File:train_0312.py
- '''
-
-
- import argparse
- import time
- import os
- import torch
- import json
- from datetime import datetime
- from transformers import BertTokenizer,BertModel
- #from Models.Duiqi import MyModule
- import torch.nn as nn
- #from tqdm import tqdm
- from utils.Dataset_df import MyDataset,MyDataset_v2
- from torch.utils.data import DataLoader
- from torch.optim import SGD,Adam
- from torch.nn import CrossEntropyLoss
- from Module import MyModule_v9,MyModule_v8,MyModuel_v3,MyModuel_v4,MyModule_v5,MyModule_v6,MyModule_v6_1,MyModule_v6_2
- import numpy as np
- #from sklearn.metrics import classification_report
- from Fusion_block import Fusion_Module
-
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Display names for the 12 classes: 'class 0' ... 'class 11'.
target_names = ['class {}'.format(idx) for idx in range(12)]
-
def pre_train_(model, train_dataset, optimizer, scheduler, tokenizer, args):
    """Train ``model`` on (image, label, event) batches and checkpoint it periodically.

    Each batch is unpacked as ``(image, label, event)``. The event text is
    tokenized with ``tokenizer`` (padded to the longest item in the batch,
    truncated to 25 tokens) and passed to the model together with the images
    and labels. The model returns three losses, combined as
    ``loss1 + 0.5 * loss2 + 0.5 * loss3``.

    Args:
        model: module called as ``model(image, event_input, label)`` that
            returns a 3-tuple of scalar losses.
        train_dataset: iterable of batches that supports ``len()`` (e.g. a
            DataLoader); the epoch loss is averaged over ``len(train_dataset)``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        tokenizer: HuggingFace-style tokenizer applied to the event text.
        args: namespace providing ``epoch``, ``save_perepoch`` and
            ``weights_path``.
    """
    model.train()
    num_batches = len(train_dataset)

    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        # Accumulate a plain float: summing the loss tensor itself would keep
        # every step's autograd graph alive until the end of the epoch.
        epoch_loss = 0.0

        for step, (image, label, event) in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            image = image.to(device)
            label = label.to(device)

            # Tokenize the event text for the model.
            event_input = tokenizer(event, padding='longest', truncation=True,
                                    max_length=25, return_tensors="pt").to(device)

            loss1, loss2, loss3 = model(image, event_input, label)
            loss = loss1 + 0.5 * loss2 + 0.5 * loss3

            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            print('\r', 'step: ', step, ' loss: {:.6f}'.format(step_loss), end='', flush=True)
            epoch_loss += step_loss

        # NOTE(review): stepped once per epoch; the caller configures the
        # warm-restart periods (T_0/T_mult) in epochs.
        scheduler.step()

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = epoch_loss / max(num_batches, 1)
        print('\n ====step_loss:{:.6f}==== \n'.format(avg_loss))

        # Save the weights every `save_perepoch` epochs (epoch 0 included).
        if epoch % args.save_perepoch == 0:
            weight_name = 'epoch_' + str(epoch) + '_loss_' + '{:.6f}'.format(avg_loss) + '.pt'
            torch.save(model.state_dict(), os.path.join(args.weights_path, weight_name))
            print('epoch: {} | loss: {:.6f} | Saving model... \n'.format(epoch, avg_loss))
-
-
def getKnowledge(knowledge_dict, labels):
    """Build three knowledge sentences per label from ``knowledge_dict``.

    ``knowledge_dict`` is keyed by the string form of the integer label
    (``str(label.int().item())``). Each value is a list of triples. For
    triples 0, 1 and 2 of that entry, the triple's first three elements are
    concatenated with ``+`` into one sentence.

    Args:
        knowledge_dict: mapping from label string to a list of triples.
        labels: iterable of tensor-like labels supporting ``.int().item()``.

    Returns:
        A tuple ``(kd1, kd2, kd3)`` of lists. ``kdN[j]`` is the sentence built
        from triple ``N-1`` of the j-th label.

    Raises:
        KeyError: a label has no entry in ``knowledge_dict``.
        IndexError: an entry has fewer than three triples, or a triple has
            fewer than three elements.
    """
    def _join(triple):
        # Concatenate the three parts of one triple into a single sentence.
        return triple[0] + triple[1] + triple[2]

    entries = [knowledge_dict[str(lab.int().item())] for lab in labels]
    return tuple([_join(entry[slot]) for entry in entries] for slot in range(3))
-
def getTokenizer(bert_base_chinese_path):
    """Load a BERT tokenizer and a BERT encoder from the same checkpoint path.

    Args:
        bert_base_chinese_path: directory or model id accepted by
            ``from_pretrained``.

    Returns:
        ``(tokenizer, bert)``, where ``bert`` has already been moved to the
        module-level ``device``.
    """
    loaded_tokenizer = BertTokenizer.from_pretrained(bert_base_chinese_path)
    encoder = BertModel.from_pretrained(bert_base_chinese_path).to(device)
    return loaded_tokenizer, encoder
-
def getEncodedtext(triples, max_len, tokenizer, bert):
    """Encode a single knowledge sentence with ``bert``.

    The text is tokenized without padding and truncated to ``max_len``
    tokens. It is wrapped into a batch of one, moved to the module-level
    ``device`` and run through ``bert``.

    Args:
        triples: the sentence to encode (a single string, as produced by
            ``getKnowledge``).
        max_len: maximum number of tokens kept after truncation.
        tokenizer: HuggingFace-style tokenizer returning ``input_ids``.
        bert: encoder model; element 1 of its output is returned (the pooled
            output when this is a HuggingFace ``BertModel``).

    Returns:
        The second element of ``bert``'s output for the one-item batch.
    """
    input_ids = tokenizer(triples, max_length=max_len, truncation=True)['input_ids']
    batch = torch.tensor([input_ids], dtype=torch.long).to(device)
    outputs = bert(batch)
    return outputs[1]
-
def pre_train(model, loss_, optimizer, train_dataset, args):
    """Older training loop that feeds BERT-encoded knowledge sentences to the model.

    For every ``(image, label)`` batch, ``getKnowledge`` builds three sentences
    per label from the JSON file at ``args.knowledge_path``. Each sentence is
    encoded separately with ``getEncodedtext``. The model is called as
    ``model(loss_, label, image, e1, e2, e3)`` and returns two losses, combined
    as ``0.7 * loss1 + 0.3 * loss2``.

    Whenever the epoch-average loss improves, the whole model object (not just
    its state dict) is saved to ``Weights/model.pt``.

    Args:
        model: module with the call signature described above.
        loss_: loss function forwarded unchanged to the model.
        optimizer: optimizer stepped once per batch.
        train_dataset: iterable of ``(image, label)`` batches supporting ``len()``.
        args: namespace providing ``pretrained_bert``, ``knowledge_path``,
            ``epoch`` and ``max_len``.
    """
    model.train()

    tokenizer, bert = getTokenizer(args.pretrained_bert)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    def _encode(sentences):
        # One BERT forward pass per sentence; results concatenated along dim 0.
        return torch.cat([getEncodedtext(s, args.max_len, tokenizer, bert) for s in sentences]).to(device)

    # Guard against an empty loader instead of dividing by zero.
    num_batches = max(len(train_dataset), 1)
    loss_min = 1000000

    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        # Accumulate a plain float: summing the loss tensor itself would keep
        # every step's autograd graph alive until the end of the epoch.
        epoch_loss = 0.0

        for step, (image, label) in enumerate(train_dataset, start=1):
            optimizer.zero_grad()

            image = image.to(device)
            label = label.to(device)

            t1, t2, t3 = getKnowledge(knowledge_dict, label)

            # Encode the three knowledge sentences with the pretrained BERT.
            time1 = time.time()
            encodedevent1 = _encode(t1)
            encodedevent2 = _encode(t2)
            encodedevent3 = _encode(t3)
            print('bert time: ', time.time() - time1)

            time3 = time.time()
            loss1, loss2 = model(loss_, label, image, encodedevent1, encodedevent2, encodedevent3)
            print('model time: ', time.time() - time3)

            loss = 0.7 * loss1 + 0.3 * loss2

            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            print('\r', 'step: ', step, ' loss: ', step_loss, end='', flush=True)
            epoch_loss += step_loss

        avg_loss = epoch_loss / num_batches
        print('step_loss:{} \n'.format(avg_loss))

        # Keep only the best (lowest average loss) model seen so far.
        if avg_loss < loss_min:
            loss_min = avg_loss
            # Nothing else creates this directory; torch.save would fail without it.
            os.makedirs("Weights", exist_ok=True)
            torch.save(model, os.path.join("Weights", "model.pt"))
            print('epoch: {} | loss: {} | Saving model... \n'.format(epoch, loss_min))
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Model Controller')

    parser.add_argument('--mode', type=str, default='pre_train', help='pre_train/test/train')

    parser.add_argument('--weights_path', type=str, default='/model/weights', help='the path saving weights')
    parser.add_argument('--image_path', type=str, default="/dataset/X_train.npy")
    parser.add_argument('--label_path', type=str, default="/dataset/Y_train.npy")
    parser.add_argument('--knowledge_path', type=str, default="/code/DataSet/Knowledge_v2.json")
    parser.add_argument('--config_path', type=str, default="/code/configs/config_bert.json")
    parser.add_argument('--ckpt_url', type=str, default='',
                        help='optional state_dict checkpoint to load before training')

    parser.add_argument('--val_image_path', type=str, default='', help='val_imageset path')
    parser.add_argument('--val_label_path', type=str, default='', help='val_labelset path')

    parser.add_argument('--max_len', type=int, default=300)
    parser.add_argument('--pretrained_bert', type=str, default='/dataset/bert-base-chinese')
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--num_class', type=int, default=12)

    parser.add_argument('--batch', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--epoch', type=int, default=300, help='train_epochs')
    # Previously carried --epoch's help text by copy-paste.
    parser.add_argument('--save_perepoch', type=int, default=8, help='save weights every N epochs')

    args = parser.parse_args()

    # pre_train_ computes `epoch % args.save_perepoch`; 0 would raise
    # ZeroDivisionError only after the whole first epoch had been trained.
    if args.save_perepoch < 1:
        parser.error('--save_perepoch must be a positive integer')

    if args.mode != 'pre_train':
        # Only the pre_train_ path below is wired up; say so instead of
        # silently ignoring the requested mode.
        print("warning: --mode '{}' is not implemented in this script; running pre_train_".format(args.mode))

    time_str = datetime.now().strftime('%m-%d_%H-%M-%S')
    print(time_str)

    os.makedirs(args.weights_path, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert)

    model = MyModule_v9(args).to(device)
    # Resume from a state_dict (the format pre_train_ writes) when one is given.
    if args.ckpt_url:
        model.load_state_dict(torch.load(args.ckpt_url, map_location=device))

    total = sum(param.nelement() for param in model.parameters())
    print("Number of parameter: %.2fM" % (total / 1e6))
    print('epoch:{}|batch_size:{}|lr:{}|save_perepoch:{}'.format(args.epoch, args.batch, args.lr, args.save_perepoch))

    # Only referenced by the disabled mode dispatch at the end of this block.
    criterion = CrossEntropyLoss()

    # Optimizer and learning-rate schedule.
    optimizer1 = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer1,
        T_0=2,        # number of epochs before the first restart
        T_mult=4,     # each restart multiplies the period: T_i = T_{i-1} * T_mult
        eta_min=1e-5  # minimum learning rate
    )

    mydataset = MyDataset_v2(args.image_path, args.label_path, args.knowledge_path)
    # Pinned host memory only helps (and only works without a warning) with CUDA.
    data_loader = DataLoader(dataset=mydataset, batch_size=args.batch, shuffle=True,
                             pin_memory=torch.cuda.is_available())

    pre_train_(model, data_loader, optimizer1, scheduler, tokenizer, args)

    # Disabled mode dispatch, kept for reference. It calls functions that are
    # not defined in this file (pre_train_v1, classifier_test, final_test,
    # train) and loads hard-coded checkpoints from "Weights/":
    #
    # if args.mode == 'pre_train':  # train the alignment part
    #     print("pre_train")
    #     # model.load_state_dict(torch.load(r"Weights/model_0221.pt",map_location=device).state_dict())
    #     pre_train_v1(model, criterion, optimizer1, data_loader, args)
    # elif args.mode == 'test':  # test the intermediate classification head
    #     print('test')
    #     model.load_state_dict(torch.load(r"Weights/model_0221.pt", map_location=device).state_dict())
    #     classifier_test(model, criterion, data_loader, args)
    # elif args.mode == 'final_test':
    #     print('final test')
    #     model.load_state_dict(torch.load(r"Weights/model_0221.pt", map_location=device).state_dict())
    #     model2 = Fusion_Module().to(device)
    #     model2.load_state_dict(torch.load(r"Weights\epoch_26_loss_0.0006.pt", map_location=device).state_dict())
    #     final_test(model, model2, data_loader, args)
    # else:
    #     print('train')
    #     new_model = Fusion_Module().to(device)
    #     model.load_state_dict(torch.load(r"Weights/model_0221.pt", map_location=device).state_dict())
    #     train(new_model, model, scheduler, data_loader, args)
|