import os
import argparse
import json
import numpy as np
import torch
import torch.nn as nn
import copy
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn import BCEWithLogitsLoss
from transformers import AdamW, get_linear_schedule_with_warmup
from datetime import datetime
import random

from encoder import SparseEncoder, DenseEncoder
from model import BertNormalization, BertSelection, GCN_update
from utils import set_seed, read_data, read_dictionary, score_f1, marginal_nll, data_filter
from utils import get_score_matrix, retrieve_candidate, evaluate_candidate, build_surgery_graph, build_disease_graph
from utils import normalize_sp_adj, get_spAdj, normalize

TASK2DATADIR = {
    "surgery": "./dataset/surgery",
    "disease": "./dataset/disease_chip2020",
    "disease-online": "./dataset/disease_chip2020_online",
}

TASK2GRAPH = {
    "surgery": "./dataset/surgery_3.json",
    "disease": "./dataset/disease_yibao.json",
    "disease-online": "./dataset/disease_yibao.json",
}

TASK2STD_DICT_PATH = {
    "surgery": "./dataset/std_surgery.txt",
    "disease": "./dataset/std_disease.txt",
    "disease-online": "./dataset/disease_chip2020_online/std_disease.txt",
}

def parse_args():
    parser = argparse.ArgumentParser()

    # Dataset-related arguments
    parser.add_argument('--model_dir', required=True, help='directory containing the pretrained BERT weights')
    parser.add_argument('--bert_type', required=True, help='BERT variant, e.g. roberta-base/nezha-base')
    parser.add_argument('--task_name', help='dataset name')  # every dataset needs at least three files (train.txt, dev.txt, test.txt) plus a base dictionary
    # parser.add_argument('--extra_dictionary', type=str, help="supplementary standard-term sets, e.g. abbreviations and synonyms")
    parser.add_argument('--output_dir', type=str, required=True, help="output directory for saved models")
    # parser.add_argument('--task_type', type=str, default=None, help="type of the standard term set: disease/surgery")

    # Training-related arguments
    parser.add_argument('--seed', type=int, default=42, help="random seed")
    parser.add_argument('--topk', type=int, default=20, help="number of candidates")
    parser.add_argument('--bert_learning_rate', default=1e-5, type=float, help='learning rate for the BERT parameters')
    parser.add_argument('--other_learning_rate', default=0.001, type=float, help="learning rate for the non-BERT parameters")
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--warmup', help='warm-up ratio', default=0.1, type=float)
    parser.add_argument('--train_batch_size', default=16, type=int)
    parser.add_argument('--dev_batch_size', default=32, type=int)
    parser.add_argument('--epoch', default=10, type=int)
    parser.add_argument('--max_len', type=int, help="maximum sequence length; -1 pads to the longest sequence in each batch")

    # Dynamic candidates
    parser.add_argument('--dynamic', action="store_true", help="re-retrieve candidates dynamically every epoch")
    parser.add_argument('--dense_ratio', type=float, default=0.5, help="fraction of the topk candidates drawn from the dense representation")
    # Features used by the entity-normalization module
    parser.add_argument('--add_sparse', action="store_true", help="feed the sparse scores into the entity-normalization module")
    parser.add_argument('--add_dense2score', action="store_true", help="feed the dense scores into the entity-normalization module")
    parser.add_argument('--pair_weight', default=1, type=float, help="weight of the pair representation")

    # Shared-BERT arguments; without --share_bert the two modules use separate BERTs
    parser.add_argument('--share_bert', action="store_true", help="share one BERT between the two modules")
    parser.add_argument('--listwise_loss', action="store_true", help="use the list-wise loss")
    # parser.add_argument('--add_threshold', action="store_true", help="apply a threshold to the dynamic-candidate loss")
    parser.add_argument('--loss_combine', type=str, default="add", help="how to combine the multi-task losses: 'add' sums them directly, 'alpha' is an uncertainty-weighted sum")
    parser.add_argument('--normalization_loss_weight', type=float, default=1, help="weight of the normalization loss when the losses are summed")

    # Interaction between pair representations
    parser.add_argument('--add_match_atten', action="store_true", help="add interaction between the pair representations")
    parser.add_argument('--score_type', type=str, default="dot", help="scoring function for self-attention in the interaction module; defaults to scaled dot product")

    parser.add_argument('--do_train', action="store_true")
    parser.add_argument('--do_test', action="store_true")
    parser.add_argument('--add_test', action="store_true", help="train the final model with the test set merged into the training set")
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--do_augment', action="store_true")

    # Hierarchy-encoding arguments
    parser.add_argument('--add_graph', action="store_true", help="enable the hierarchy-encoding module")
    parser.add_argument('--graph_file', type=str, help="file used to build the hierarchy")

    args = parser.parse_args()
    return args

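# CandidateDataset pairs each mention with its retrieved top-k candidates.
# Each item yields the mention string, its gold standard entities, the
# candidates' names and CUIs, the precomputed sparse scores for those
# candidates, optional hierarchy (graph) embeddings, and one label per
# candidate. Labels are binary (candidate CUI in the gold set) unless
# name2code is given, in which case a graded scheme for the list-wise loss
# is used (see get_label below).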
class CandidateDataset(Dataset):
    def __init__(self, data, dictionary, candidates, sparse_score, graph_embedding=None, dict_idx2graph_idx=None, name2code=None):
        self.data = data
        self.candidates = candidates
        self.dict_names = dictionary[:, 0]
        self.dict_cuis = dictionary[:, 1]
        self.sparse_score = sparse_score
        self.graph_embedding = graph_embedding
        self.dict_idx2graph_idx = dict_idx2graph_idx
        self.name2code = name2code

    def __getitem__(self, idx):
        name, cuis = self.data[idx][0], self.data[idx][1]
        cand_idx = self.candidates[idx]
        cand_names = self.dict_names[cand_idx]
        cand_cuis = self.dict_cuis[cand_idx]

        cand_graph_emb = None
        if self.graph_embedding is not None:
            cand_graph_emb = self.graph_embedding[cand_idx]

        sparse_score = self.sparse_score[idx][cand_idx]

        label = self.get_label(cuis, cand_cuis, self.name2code)

        return name, cuis, cand_names, cand_cuis, sparse_score, cand_graph_emb, label

    def __len__(self):
        return len(self.data)

    @staticmethod
    def get_label(cuis, cand_cuis, name2code=None):
        label = []
        if name2code is None:
            for cand_cui in cand_cuis:
                if cand_cui in cuis:
                    label.append(1)
                else:
                    label.append(0)
        else:
            cuis_code = [name2code.get(cui, "--") for cui in cuis]
            for cand_cui in cand_cuis:
                cand_code = name2code.get(cand_cui, "--")
                tag = 0
                for cui_code in cuis_code:
                    if cand_code == cui_code and cand_code != "--":
                        tag = 2
                        break
                    elif cand_code in cui_code and cand_code != '--':
                        tag = 1
                label.append(tag)
        return label

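# The collate function flattens each batch into mention/candidate pairs:
# every mention is tokenized against each of its topk candidates, so the
# pair batch holds batch_size * topk rows. Downstream code relies on this
# layout when it reshapes logits and labels with .reshape(-1, topk).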
def collate_wrapper(tokenizer, max_len, mode="train"):
    def collate_fn(batch_data):
        pairs = []
        labels = []

        mentions = []
        candidates = []
        sparse_score_list = []
        cand_graph_emb_list = []

        for d in batch_data:
            name, cuis, cand_names, cand_cuis, sparse_score, cand_graph_emb, label = d
            mentions.append(name)
            sparse_score_list.append(sparse_score)

            if cand_graph_emb is not None:
                cand_graph_emb_list.append(cand_graph_emb)

            for cand_name in cand_names:
                pairs.append([name, cand_name])
                candidates.append(cand_name)

            labels.extend(label)

        sparse_scores = np.stack(sparse_score_list)

        if len(cand_graph_emb_list) > 0:
            cand_graph_embs = torch.stack(cand_graph_emb_list)
        else:
            cand_graph_embs = None

        if max_len != -1:
            mention_tokenized = tokenizer(mentions, padding="max_length", max_length=max_len, truncation=True)
            candidates_tokenized = tokenizer(candidates, padding="max_length", max_length=max_len, truncation=True)
            pair_tokenized = tokenizer(pairs, padding="max_length", max_length=max_len, truncation=True)
        else:
            mention_tokenized = tokenizer(mentions, padding="longest")
            candidates_tokenized = tokenizer(candidates, padding="longest")
            pair_tokenized = tokenizer(pairs, padding="longest")

        if mode == "train":
            return mention_tokenized, candidates_tokenized, pair_tokenized, sparse_scores, cand_graph_embs, labels
        else:
            return mention_tokenized, candidates_tokenized, pair_tokenized, sparse_scores, cand_graph_embs, labels, batch_data

    return collate_fn

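# evaluate runs the normalization model over a dev/test loader, applies a
# sigmoid and a 0.5 threshold to each mention-candidate logit (multi-label:
# several candidates may be predicted per mention), then reports (p, r, f1)
# from score_f1 over the predicted CUI sets plus the raw pairwise accuracy.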
def evaluate(model, loader, topk, output_file=None):
    model.eval()

    preds = []
    golds = []
    top1 = []

    data = []

    for batch in tqdm(loader, desc="evaluating"):
        mentions, candidates, pairs, sparse_scores, cand_graph_embs, label, batch_data = batch
        for k in mentions.keys():
            mentions[k] = torch.tensor(mentions[k]).cuda()
        for k in candidates.keys():
            candidates[k] = torch.tensor(candidates[k]).cuda()
        for k in pairs.keys():
            pairs[k] = torch.tensor(pairs[k]).cuda()
        sparse_scores = torch.from_numpy(sparse_scores).cuda()

        if cand_graph_embs is not None:
            cand_graph_embs = cand_graph_embs.cuda()

        with torch.no_grad():
            _, output = model(mentions, candidates, pairs, sparse_scores, cand_graph_embs)
            output = nn.Sigmoid()(output)
            pred = (output > 0.5).long().cpu().flatten().numpy()
        pred = pred.reshape(-1, topk)
        label = np.array(label).reshape(-1, topk)

        preds.append(pred)
        golds.append(label)

        data.extend(batch_data)

    preds = np.concatenate(preds, axis=0)
    golds = np.concatenate(golds, axis=0)

    total_match_pair = 0
    acc_match_pair = 0
    for pred, gold in zip(preds, golds):
        total_match_pair += len(pred)
        acc_match_pair += sum(pred == gold)
    acc = acc_match_pair / total_match_pair

    # each item of data: (name, cuis, cand_names, cand_cuis, sparse_score, cand_graph_emb, label)
    mentions = [k[0] for k in data]
    gold_cuis = [list(k[1]) for k in data]
    candidates_names = [k[2] for k in data]
    candidates_cuis = [k[3] for k in data]

    pred_cuis = []

    for candidate_cui, pred in zip(candidates_cuis, preds):
        pred_cui = candidate_cui[pred == 1]
        pred_cuis.append(pred_cui)

    if output_file is not None:
        with open(output_file + ".pred", "w", encoding="utf-8") as f:
            for pred_cui in pred_cuis:
                f.write("##".join(pred_cui) + "\n")
        with open(output_file + ".candidate", "w", encoding="utf-8") as f:
            for candidates_name in candidates_names:
                f.write("##".join(candidates_name) + "\n")

    return score_f1(gold_cuis, pred_cuis), acc

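# get_label_index maps each training mention's gold standard terms to their
# row indices in the dictionary (np.argmax over the equality mask returns the
# first matching row). These indices are later forced into the training
# candidates in get_hybrid so every training instance contains its positives.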
def get_label_index(train_data, cuis_in_dictionary):
    label_index = []
    for idx, (mention, label) in enumerate(train_data):
        temp = []
        for L in label:
            temp.append(np.argmax(cuis_in_dictionary == L).tolist())

        assert len(temp) == len(label)

        label_index.append(temp)

    return label_index

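# Example of the mixing below with topk=20 and dense_ratio=0.5: the first 10
# slots hold the top sparse candidates, the remaining slots are filled from
# the dense ranking while skipping duplicates, and at training time the tail
# is overwritten with the gold indices so positives are always present.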
def get_hybrid(args, sparse_candidate, dense_candidate, label_candidate):
    # Mix the topk candidates from the sparse representation with the topk
    # candidates from the dense representation into the final topk candidates.
    # At training time the gold answers are forced into the candidate set.
    hybrid_candidate = []
    n_dense = int(args.topk * args.dense_ratio)
    n_sparse = args.topk - n_dense
    for idx in range(len(sparse_candidate)):
        s_cand = sparse_candidate[idx]
        d_cand = dense_candidate[idx]

        cand = s_cand[:n_sparse]
        for k in d_cand:
            if len(cand) == args.topk:
                break
            if k not in cand:
                cand = np.append(cand, k)

        assert len(cand) == args.topk

        if label_candidate is not None:
            l_cand = label_candidate[idx]
            cand[-1 * len(l_cand):] = l_cand

        hybrid_candidate.append(cand)

    return hybrid_candidate

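# Overall flow of the main script: read the task data and the standard-term
# dictionary, fit a sparse encoder and retrieve sparse candidates once, then
# per epoch (every epoch with --dynamic, otherwise only at epoch 0) re-encode
# the dictionary with the dense encoder, merge sparse and dense candidates,
# and train the candidate-selection and entity-normalization modules jointly.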
if __name__ == "__main__":
    args = parse_args()

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)
    args.n_gpu = n_gpu

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, "args.json"), "w", encoding="utf-8") as f:
        print(json.dumps(args.__dict__, indent=4))
        json.dump(args.__dict__, f, indent=4)

    ################################################# Load data
    # Training data format: (mention, standard entities)
    # data_dirs = [TASK2DATADIR[task_name] for task_name in args.task_names]  # multiple datasets
    data_dirs = [TASK2DATADIR[args.task_name]]
    train_data, dev_data, test_data = [], [], []
    for data_dir in data_dirs:
        train_file = os.path.join(data_dir, "train.txt")
        dev_file = os.path.join(data_dir, "dev.txt")
        test_file = os.path.join(data_dir, "test.txt")

        train_temp = read_data(train_file)
        dev_temp = read_data(dev_file)
        test_temp = read_data(test_file)

        train_data.extend(train_temp)
        dev_data.extend(dev_temp)
        test_data.extend(test_temp)

    # Use the standard terms seen in the training set as the base dictionary,
    # so the training set's gold labels can always be found.
    train_std_term = []
    for t in train_data:
        train_std_term.extend(list(t[1]))
    train_std_term = sorted(list(set(train_std_term)))  # keep the term order stable across runs

    print("Raw data sizes: train={} dev={} test={}".format(len(train_data), len(dev_data), len(test_data)))
    print("Raw base dictionary size: base_dictionary={}".format(len(train_std_term)))
    print("Data examples:", train_data[0], dev_data[0], test_data[0])

    std_file = TASK2STD_DICT_PATH[args.task_name]
    base_dictionary = read_dictionary(std_file)

    std_term_set = set([k[0] for k in base_dictionary])
    for e in train_std_term:
        if e not in std_term_set:
            base_dictionary.append((e, e))

    print("Dictionary size after merging the standard term set: base_dictionary={}".format(len(base_dictionary)))
    print("Dictionary examples:", base_dictionary[0], base_dictionary[-1])

    # Expand the dictionary with abbreviations
    # if args.extra_dictionary is not None:
    #     # add synonyms and abbreviations to the dictionary
    #     dictionary_temp = read_dictionary(args.extra_dictionary)
    #     base_dictionary.extend(dictionary_temp)
    #     print("Dictionary size after adding abbreviations: base_dictionary={}".format(len(base_dictionary)))
    #     print("Dictionary examples after expansion:", base_dictionary[0], base_dictionary[-1])

    # Once the dictionary is fixed, save it for use at test time.
    with open(os.path.join(args.output_dir, "dictionary.txt"), "w", encoding="utf-8") as f:
        for d in base_dictionary:
            f.write(d[0] + "\t" + d[1] + "\n")

    # Augment the training set: build one-to-many samples from one-to-one ones.
    if args.do_augment:
        data_augmented = []
        one2one = []
        # Add every standard term from the training set as a trivial one-to-one sample.
        for d in train_std_term:
            one2one.append((d, tuple([d])))
        print(one2one[:5])
        for d in train_data:
            if len(d[1]) == 1:
                one2one.append(d)
            data_augmented.append(d)
        aug_size = 2000
        while aug_size > 0:
            k = random.choice([2, 3])
            concated_mention = ""
            concated_entities = tuple([])
            while k > 0:  # concatenate k one-to-one samples into one one-to-many sample
                mention, entity = random.choice(one2one)
                concated_mention = concated_mention + mention + ","
                concated_entities += entity
                k -= 1
            data_augmented.append((concated_mention[:-1], concated_entities))
            aug_size -= 1
        train_data = data_augmented
        print("Data sizes after augmentation: train={} dev={} test={}".format(len(train_data), len(dev_data), len(test_data)))

    # Data augmentation from the hierarchy
    # if args.do_augment:
    #     df = pd.read_excel("./dataset/std_files/标准术语集——国际疾病分类 ICD-10北京临床版v601.xlsx", header=None)

    #     entry_list = []
    #     for code, name, _, _, _, _ in df.values:
    #         entry_list.append((code, name))

    #     root = Node("疾病")
    #     parent_nodes = []

    #     data_augment = []

    #     for i in range(len(entry_list)):
    #         entry = entry_list[i]
    #         if "/" not in entry[0]:  # filter out supplementary entries
    #             while len(parent_nodes) > 0 and parent_nodes[-1].name.split("-")[0] not in entry[0]:
    #                 parent_nodes.pop(-1)

    #             parent = root
    #             if len(parent_nodes) >= 1:
    #                 parent = parent_nodes[-1]

    #             new_node = Node("---".join(entry), parent=parent)
    #             if "---" in parent.name:
    #                 data_augment.append((entry[1], tuple([parent.name.split("---")[1]])))
    #             if i + 1 < len(entry_list) and entry_list[i][0] in entry_list[i+1][0]:
    #                 parent_nodes.append(new_node)

    #     def filter_data(data):
    #         new_data = []
    #         for d in data:
    #             if "其他" in d[0] or "其他" in d[1][0] or "不可归类在他处者" in d[1][0] or "未特指" in d[1][0]:
    #                 continue
    #             if d[1][0] in train_std_term and d[1][0] != d[0]:
    #                 new_data.append(d)
    #         return new_data

    #     data_augment = filter_data(data_augment)

    #     train_data.extend(data_augment)

    #     print("Data sizes after augmentation: train={} dev={} test={}".format(len(train_data), len(dev_data), len(test_data)))

    # list-wise loss: map standard names to ICD codes
    name2code = None
    if args.listwise_loss:
        df = pd.read_excel("./dataset/std_files/标准术语集——国际疾病分类 ICD-10北京临床版v601.xlsx", header=None)
        name2code = {}
        for code, name, _, _, _, _ in df.values:
            name2code[name] = code

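    # With name2code available, get_label grades each candidate: 2 for an
    # exact ICD-code match, 1 for a code-prefix (partial) match, 0 otherwise.
    # The training loop later maps grade 1 to a 0.5 BCE weight so partial
    # matches are penalized less than outright mismatches.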
    ################################################# Encoding
    train_data = np.array(train_data, dtype=object)
    dev_data = np.array(dev_data, dtype=object)
    test_data = np.array(test_data, dtype=object)
    base_dictionary = np.array(base_dictionary, dtype=object)

    names_in_train_data = train_data[:, 0]
    names_in_dev_data = dev_data[:, 0]
    names_in_test_data = test_data[:, 0]

    names_in_base_dictionary = base_dictionary[:, 0]

    print("Computing sparse encodings")
    sparse_encoder = SparseEncoder()
    sparse_encoder.fit(names_in_train_data.tolist() + names_in_base_dictionary.tolist())  # fit the sparse encoder on the training data plus the dictionary

    # Sparse representations of the data and the standard term set
    dict_sparse = sparse_encoder.transform(names_in_base_dictionary)

    train_sparse = sparse_encoder.transform(names_in_train_data)
    train_sparse_score = get_score_matrix(train_sparse, dict_sparse)
    train_label_candidate = get_label_index(train_data, base_dictionary[:, 1])  # gold indices, merged into the training candidates later
    train_sparse_candidate = retrieve_candidate(train_sparse_score, topk=args.topk)
    train_sparse_recall = evaluate_candidate(train_data, base_dictionary, train_sparse_candidate)

    dev_sparse = sparse_encoder.transform(names_in_dev_data)
    dev_sparse_score = get_score_matrix(dev_sparse, dict_sparse)
    dev_sparse_candidate = retrieve_candidate(dev_sparse_score, topk=args.topk)
    dev_sparse_recall = evaluate_candidate(dev_data, base_dictionary, dev_sparse_candidate)
    print("Sparse-representation recall on train/dev: {:.4f}/{:.4f}".format(train_sparse_recall, dev_sparse_recall))

    if args.do_test:
        test_sparse = sparse_encoder.transform(names_in_test_data)
        # test_dict_sparse = sparse_encoder.transform(names_in_test_dictionary)
        test_sparse_score = get_score_matrix(test_sparse, dict_sparse)
        test_sparse_candidate = retrieve_candidate(test_sparse_score, topk=args.topk)
        test_sparse_recall = evaluate_candidate(test_data, base_dictionary, test_sparse_candidate)
        print("Sparse-representation recall on test: {:.4f}".format(test_sparse_recall))

    if "nezha" in args.bert_type:
        bert_type = "nezha"
    else:
        bert_type = "bert"

    print("Initializing the dense encoder")
    dense_encoder = DenseEncoder(bert_path=args.model_dir, bert_type=bert_type)

    ################################################# Build the hierarchy graph
    graph = None
    dict_idx2graph_idx = None
    if args.add_graph:
        # Build the hierarchy from the codes of the standard term set
        graph_file = TASK2GRAPH[args.task_name]
        if "surgery" == args.task_name:
            dict_idx2graph_idx, sp_adj, graph_names = build_surgery_graph(graph_file, base_dictionary)
        elif "disease" == args.task_name:
            dict_idx2graph_idx, sp_adj, graph_names = build_disease_graph(graph_file, base_dictionary)
        else:
            raise ValueError("no graph definition file found for this task")
        print("Graph definition file:", graph_file)
        graph = GCN_update(in_features=768, out_features=768, bias=True).cuda()
        adj = normalize_sp_adj(get_spAdj(sp_adj))

    candidate_selector = None
    if not args.share_bert:
        print("Not sharing BERT; copying a second BERT...")  # without sharing, a separate BERT copy is used to predict the standard entities
        bert4normalization = copy.deepcopy(dense_encoder.encoder).cuda()
        # model dedicated to dynamic candidate selection
        candidate_selector = BertSelection(dense_encoder.encoder, args.topk)

    # Model that predicts the standard entities; with --share_bert it reuses the
    # dense encoder's BERT, otherwise it uses the fresh copy created above.
    model = BertNormalization(dense_encoder.encoder if args.share_bert else bert4normalization,
                              dropout_prob=0.3,
                              topk=args.topk,
                              score_type=args.score_type,
                              pair_weight=args.pair_weight,
                              add_sparse=args.add_sparse,
                              calculate_dense=args.share_bert,  # with a shared BERT, the normalization model also computes the dense dot products
                              add_match_atten=args.add_match_atten,
                              add_graph=args.add_graph,
                              graph_model=graph)
    model.cuda()
    # if args.n_gpu > 1:
    #     model = nn.DataParallel(model)

    # Optimizer for the entity-normalization model
    no_decay = ["bias", "LayerNorm.weight"]
    model_param = list(model.named_parameters())

    bert_param_optimizer, other_param_optimizer = [], []
    for name, para in model_param:
        space = name.split('.')
        if space[0] == 'bert':
            bert_param_optimizer.append((name, para))
        else:
            other_param_optimizer.append((name, para))

    optimizer_grouped_parameters = [
        # BERT parameters
        {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, 'lr': args.bert_learning_rate},
        {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': args.bert_learning_rate},

        # remaining parameters, with a differential learning rate
        {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, 'lr': args.other_learning_rate},
        {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': args.other_learning_rate},
    ]

    t_total = int(args.epoch * (len(train_data) // args.train_batch_size + 1))
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.bert_learning_rate)
    scheduler = None
    if args.warmup > 0:
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup * t_total), num_training_steps=t_total)

    # Without a shared BERT, the dynamic-candidate module and the
    # entity-normalization module are optimized separately.
    optimizer4selector = None
    if not args.share_bert:
        no_decay = ["bias", "LayerNorm.weight"]
        model_param = list(candidate_selector.named_parameters())
        optimizer_grouped_parameters = [
            {"params": [p for n, p in model_param if not any(nd in n for nd in no_decay)],
             "weight_decay": args.weight_decay, 'lr': args.bert_learning_rate},
            {"params": [p for n, p in model_param if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0, 'lr': args.bert_learning_rate}
        ]
        optimizer4selector = AdamW(optimizer_grouped_parameters, lr=args.bert_learning_rate)

    # Loss function for candidate selection (the normalization loss is built per batch below)
    selection_loss_func = marginal_nll

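    # marginal_nll comes from utils; presumably it maximizes the marginal
    # likelihood of the positive candidates under a softmax over the dense
    # scores (BioSyn-style candidate-selection training). Its exact form is
    # defined in utils, not here.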
    dev_result_output_dir = os.path.join(args.output_dir, "dev_results")
    if not os.path.exists(dev_result_output_dir):
        os.makedirs(dev_result_output_dir)

    if args.do_test:
        test_result_output_dir = os.path.join(args.output_dir, "test_results")
        if not os.path.exists(test_result_output_dir):
            os.makedirs(test_result_output_dir)

    ################################################# Training
    for ep in range(args.epoch):
        print("************** ep: %d *****************" % ep)

        epoch_output_dir = os.path.join(args.output_dir, "ep-%d" % ep)
        if not os.path.exists(epoch_output_dir):
            os.makedirs(epoch_output_dir)

        # Save the encoders up front
        torch.save(dense_encoder.state_dict(), os.path.join(epoch_output_dir, "dense_encoder.pt"))
        sparse_encoder.save_encoder(path=os.path.join(epoch_output_dir, "sparse_encoder.pt"))
        if graph is not None:
            torch.save(graph.state_dict(), os.path.join(epoch_output_dir, "graph.pt"))

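        # Candidate retrieval: with --dynamic the dictionary and data are
        # re-encoded with the current dense encoder every epoch so the
        # candidate lists track the model; otherwise they are built once at
        # epoch 0 and reused.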
        if args.dynamic or ep == 0:

            dict_dense = dense_encoder.transform(names_in_base_dictionary, max_len=args.max_len, desc="train dict dense")
            train_dense = dense_encoder.transform(names_in_train_data, max_len=args.max_len, desc="train dense")
            dev_dense = dense_encoder.transform(names_in_dev_data, max_len=args.max_len, desc="dev dense")

            graph_embedding = None
            if args.add_graph:
                feature_feat = dense_encoder.transform(graph_names, max_len=args.max_len, desc="graph node")
                feature_feat = normalize(feature_feat)
                feature_feat = torch.tensor(feature_feat).cuda()
                adj = adj.cuda()
                graph_embedding = graph(feature_feat, adj)

            train_dense_score = get_score_matrix(train_dense, dict_dense)

            train_dense_candidate = retrieve_candidate(train_dense_score, topk=args.topk)
            train_dense_recall = evaluate_candidate(train_data, base_dictionary, train_dense_candidate)

            dev_dense_score = get_score_matrix(dev_dense, dict_dense)
            dev_dense_candidate = retrieve_candidate(dev_dense_score, topk=args.topk)
            dev_dense_recall = evaluate_candidate(dev_data, base_dictionary, dev_dense_candidate)
            print("Dense-representation recall on train/dev: {:.8f}/{:.8f}".format(train_dense_recall, dev_dense_recall))

            if args.do_test:
                test_dense = dense_encoder.transform(names_in_test_data, max_len=args.max_len, desc="test dense")
                test_dense_score = get_score_matrix(test_dense, dict_dense)
                test_dense_candidate = retrieve_candidate(test_dense_score, topk=args.topk)
                test_dense_recall = evaluate_candidate(test_data, base_dictionary, test_dense_candidate)
                print("Dense-representation recall on test: {:.8f}".format(test_dense_recall))

            train_hybrid_candidate = get_hybrid(args, train_sparse_candidate, train_dense_candidate, train_label_candidate)
            train_hybrid_recall = evaluate_candidate(train_data, base_dictionary, train_hybrid_candidate)
            dev_hybrid_candidate = get_hybrid(args, dev_sparse_candidate, dev_dense_candidate, None)
            dev_hybrid_recall = evaluate_candidate(dev_data, base_dictionary, dev_hybrid_candidate)
            print("Hybrid recall on train/dev: {:.8f}/{:.8f}".format(train_hybrid_recall, dev_hybrid_recall))

            if args.do_test:
                test_hybrid_candidate = get_hybrid(args, test_sparse_candidate, test_dense_candidate, None)
                test_hybrid_recall = evaluate_candidate(test_data, base_dictionary, test_hybrid_candidate)
                print("Hybrid recall on test: {:.8f}".format(test_hybrid_recall))

            train_dataset = CandidateDataset(train_data, base_dictionary, train_hybrid_candidate, train_sparse_score, graph_embedding=graph_embedding, dict_idx2graph_idx=dict_idx2graph_idx, name2code=name2code)
            train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_wrapper(dense_encoder.tokenizer, args.max_len, mode="train"))

            dev_dataset = CandidateDataset(dev_data, base_dictionary, dev_hybrid_candidate, dev_sparse_score, graph_embedding=graph_embedding, dict_idx2graph_idx=dict_idx2graph_idx, name2code=name2code)
            dev_loader = DataLoader(dev_dataset, batch_size=args.dev_batch_size, shuffle=False, collate_fn=collate_wrapper(dense_encoder.tokenizer, args.max_len, mode="dev"))

            if args.do_test:
                test_dataset = CandidateDataset(test_data, base_dictionary, test_hybrid_candidate, test_sparse_score, graph_embedding=graph_embedding, name2code=name2code)
                test_loader = DataLoader(test_dataset, batch_size=args.dev_batch_size, shuffle=False, collate_fn=collate_wrapper(dense_encoder.tokenizer, args.max_len, mode="dev"))

        epoch_loss = 0
        epoch_normalization_loss = 0
        epoch_selection_loss = 0

        for idx, d in enumerate(tqdm(train_loader, desc="ep: %d" % ep, ncols=50)):
            model.train()

            mentions, candidates, pairs, sparse_scores, cand_graph_embs, label = d

            for k in mentions.keys():
                mentions[k] = torch.tensor(mentions[k]).cuda()
            for k in candidates.keys():
                candidates[k] = torch.tensor(candidates[k]).cuda()
            for k in pairs.keys():
                pairs[k] = torch.tensor(pairs[k]).cuda()

            sparse_scores = torch.from_numpy(sparse_scores).cuda()
            label = torch.tensor(label).float().cuda()

            if cand_graph_embs is not None:
                cand_graph_embs = cand_graph_embs.cuda()

            dense_scores, output = model(mentions, candidates, pairs, sparse_scores, cand_graph_embs)

            # Entity-normalization loss
            label = label.reshape(-1, args.topk)
            # print(output.shape, label.shape)
            if args.listwise_loss:
                # graded labels: exact matches (2) and negatives (0) keep weight 1, partial code matches (1) get weight 0.5
                weight = label.clone()
                weight[weight == 0] = 1
                weight[weight == 2] = 1
                weight[weight == 1] = 0.5
                normalization_loss_func = BCEWithLogitsLoss(weight=weight)
                normalization_loss = normalization_loss_func(output, (label != 0).float().cuda())
            else:
                normalization_loss_func = BCEWithLogitsLoss()
                normalization_loss = normalization_loss_func(output, label)

            if args.share_bert:
                # With a shared BERT, the dense scores from the forward pass above
                # also feed the dynamic-selection loss.
                selection_loss = selection_loss_func(dense_scores, label)
                # The two ways of combining the losses
                if args.loss_combine == "alpha":
                    loss = normalization_loss / model.alpha1 ** 2 + selection_loss / model.alpha2 ** 2 + torch.log(model.alpha1) + torch.log(model.alpha2)
                elif args.loss_combine == "add":
                    loss = selection_loss + normalization_loss * args.normalization_loss_weight

                loss.backward(retain_graph=True)

            else:
                # Without a shared BERT, recompute the dense scores with the selector
                dense_scores = candidate_selector(mentions, candidates)
                selection_loss = selection_loss_func(dense_scores, label)
                selection_loss.backward()
                normalization_loss.backward(retain_graph=True)

            epoch_normalization_loss += normalization_loss.item()
            epoch_selection_loss += selection_loss.item()

            optimizer.step()
            if optimizer4selector:
                optimizer4selector.step()

            if scheduler:
                scheduler.step()

            optimizer.zero_grad()
            if optimizer4selector:
                optimizer4selector.zero_grad()

        epoch_loss += epoch_normalization_loss + epoch_selection_loss

        torch.cuda.empty_cache()

        print("[average] selector loss:{:.4f} normalization loss:{:.4f} total loss:{:.4f}".format(
            epoch_selection_loss / len(train_loader), epoch_normalization_loss / len(train_loader), epoch_loss / len(train_loader)
        ))
        if args.share_bert:
            print("alpha1:{:.4f} alpha2:{:.4f}".format(model.alpha1.item(), model.alpha2.item()))

        (p, r, f1), match_acc = evaluate(model, dev_loader, args.topk, os.path.join(dev_result_output_dir, "dev-ep%d" % ep))
        dev_str = "dev p={:.4f} r={:.4f} f1={:.4f}, acc={:.4f}".format(p, r, f1, match_acc)
        print(dev_str)

        test_str = ""  # stays empty when --do_test is not set
        if args.do_test:
            (p, r, f1), match_acc = evaluate(model, test_loader, args.topk, os.path.join(test_result_output_dir, "test-ep%d" % ep))
            test_str = "test p={:.4f} r={:.4f} f1={:.4f}, acc={:.4f}".format(p, r, f1, match_acc)
            print(test_str)

        torch.save(model.state_dict(), os.path.join(epoch_output_dir, "normalization_model.pt"))
        print("Model saved to: {}".format(epoch_output_dir))

        with open(os.path.join(args.output_dir, "result.txt"), "a+", encoding="utf-8") as f:
            f.write(str(datetime.now()) + "\t" + "ep: %3d\t" % ep + dev_str + "\t" + test_str + "\n")