# -*- coding: utf-8 -*-
# Triple extraction task: a TPLinker-style design based on GlobalPointer
# Write-up: https://kexue.fm/archives/8888
# Dataset: http://ai.baidu.com/broad/download?dataset=sked
# Best f1 = 0.827
# Note: since EMA is used, it needs enough steps (5000+) to take effect.
# If your dataset is small, make sure to run enough epochs, or remove the EMA.

import json
import numpy as np
import tensorflow as tf
from bert4keras.backend import keras, K
from bert4keras.backend import sparse_multilabel_categorical_crossentropy
from bert4keras.tokenizers import Tokenizer
from bert4keras.layers import EfficientGlobalPointer as GlobalPointer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from tqdm import tqdm

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of reserving it all
sess = tf.Session(config=config)
K.set_session(sess)

maxlen = 128
batch_size = 16
config_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
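
# Note: the paths above point at the Chinese RoBERTa-wwm-ext checkpoint the
# original script was written for (released in the ymcui/Chinese-BERT-wwm
# repository); any BERT-style checkpoint loadable by bert4keras should work.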


def load_data(filename):
    """Load data.
    Single-item format: {'text': text, 'spo_list': [(s, p, o)]}
    """
    D = []
    id2predicate = {}
    predicate2id = {}
    with open(filename, encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            # Map each entity id to its surface form in the text
            entities_mapping = {}
            for i in l['entities']:
                entities_mapping[i['id']] = l['text'][i['start_offset']:i['end_offset']]
            D.append({
                'text': l['text'],
                'spo_list': [(entities_mapping[spo['from_id']], spo['type'],
                              entities_mapping[spo['to_id']])
                             for spo in l['relations']]
            })
            # Build the predicate vocabulary on the fly
            for spo in l['relations']:
                if spo['type'] not in predicate2id:
                    id2predicate[len(predicate2id)] = spo['type']
                    predicate2id[spo['type']] = len(predicate2id)
    return D, id2predicate, predicate2id
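
# A sketch of the expected input: one JSON object per line, in what appears to
# be a doccano-style relation export. The sample values below are illustrative;
# only the field names are taken from the code above:
# {"text": "...",
#  "entities": [{"id": 0, "start_offset": 0, "end_offset": 2}, ...],
#  "relations": [{"from_id": 0, "to_id": 1, "type": "..."}, ...]}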


# Load the dataset and split it 80/20 into train/validation
all_data, id2predicate, predicate2id = load_data('train.jsonl')
train_data = all_data[:int(len(all_data) * 0.8)]
valid_data = all_data[int(len(all_data) * 0.8):]

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def search(pattern, sequence):
    """Find the sub-list `pattern` inside `sequence`.
    Return the first index if found, otherwise -1.
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1
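
# For example, search([3, 4], [1, 2, 3, 4, 5]) == 2; below it is used to locate
# an entity's token ids inside the token ids of the full sentence.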


class data_generator(DataGenerator):
    """Data generator.
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(d['text'], maxlen=maxlen)
            # Collect triples as (subject head, subject tail, predicate id,
            # object head, object tail) token positions
            spoes = set()
            for s, p, o in d['spo_list']:
                s = tokenizer.encode(s)[0][1:-1]
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                sh = search(s, token_ids)
                oh = search(o, token_ids)
                if sh != -1 and oh != -1:
                    spoes.add((sh, sh + len(s) - 1, p, oh, oh + len(o) - 1))
            # Build the labels
            entity_labels = [set() for _ in range(2)]
            head_labels = [set() for _ in range(len(predicate2id))]
            tail_labels = [set() for _ in range(len(predicate2id))]
            for sh, st, p, oh, ot in spoes:
                entity_labels[0].add((sh, st))  # subject spans
                entity_labels[1].add((oh, ot))  # object spans
                head_labels[p].add((sh, oh))  # subject-head/object-head pairs
                tail_labels[p].add((st, ot))  # subject-tail/object-tail pairs
            for label in entity_labels + head_labels + tail_labels:
                if not label:  # each head needs at least one label;
                    label.add((0, 0))  # pad with (0, 0) if there is none
            entity_labels = sequence_padding([list(l) for l in entity_labels])
            head_labels = sequence_padding([list(l) for l in head_labels])
            tail_labels = sequence_padding([list(l) for l in tail_labels])
            # Assemble the batch
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_entity_labels.append(entity_labels)
            batch_head_labels.append(head_labels)
            batch_tail_labels.append(tail_labels)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_entity_labels = sequence_padding(
                    batch_entity_labels, seq_dims=2
                )
                batch_head_labels = sequence_padding(
                    batch_head_labels, seq_dims=2
                )
                batch_tail_labels = sequence_padding(
                    batch_tail_labels, seq_dims=2
                )
                yield [batch_token_ids, batch_segment_ids], [
                    batch_entity_labels, batch_head_labels, batch_tail_labels
                ]
                batch_token_ids, batch_segment_ids = [], []
                batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
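
# Shape note (inferred from the code above): after padding, the label arrays are
#   entity_labels: (batch, 2, max_spans, 2)              -- subject/object spans
#   head_labels:   (batch, num_predicates, max_pairs, 2)
#   tail_labels:   (batch, num_predicates, max_pairs, 2)
# where each innermost pair is a (row, col) coordinate into a scoring matrix.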


def globalpointer_crossentropy(y_true, y_pred):
    """Crossentropy designed for GlobalPointer.
    """
    shape = K.shape(y_pred)
    # Turn each (row, col) label into a flat index: row * seq_len + col
    y_true = y_true[..., 0] * K.cast(shape[2], K.floatx()) + y_true[..., 1]
    # Flatten each scoring matrix to (batch, heads, seq_len * seq_len)
    y_pred = K.reshape(y_pred, (shape[0], -1, K.prod(shape[2:])))
    loss = sparse_multilabel_categorical_crossentropy(y_true, y_pred, True)
    return K.mean(K.sum(loss, axis=1))
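
# Worked example: with seq_len = 128, a gold pair (5, 9) becomes the flat index
# 5 * 128 + 9 = 649 into the flattened 128 * 128 scoring matrix. The final
# argument True enables mask_zero, so the (0, 0) padding labels are ignored.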


# Load the pre-trained model
base = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False
)

# Prediction heads: one GlobalPointer scoring entity spans (2 channels, for
# subjects and objects), plus two RoPE-free GlobalPointers that align
# subject/object heads and tails for every predicate
entity_output = GlobalPointer(heads=2, head_size=64)(base.model.output)
head_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
tail_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
outputs = [entity_output, head_output, tail_output]

# Build the model
AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(learning_rate=1e-5)
model = keras.models.Model(base.model.inputs, outputs)
model.compile(loss=globalpointer_crossentropy, optimizer=optimizer)
# model.summary()
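

# Decoding sketch (TPLinker-style, following https://kexue.fm/archives/8888):
# channels 0/1 of the entity output propose subject/object spans, and a triple
# (s, p, o) is emitted only when both the head-alignment and the tail-alignment
# matrices of predicate p fire for that subject-object pair.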
def extract_spoes(text, threshold=0):
    """Extract the (subject, predicate, object) triples contained in `text`.
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    outputs = model.predict([token_ids, segment_ids])
    outputs = [o[0] for o in outputs]
    # Extract subjects and objects
    subjects, objects = set(), set()
    outputs[0][:, [0, -1]] -= np.inf  # mask spans starting on [CLS]/[SEP]
    outputs[0][:, :, [0, -1]] -= np.inf  # mask spans ending on [CLS]/[SEP]
    for l, h, t in zip(*np.where(outputs[0] > threshold)):
        if l == 0:
            subjects.add((h, t))
        else:
            objects.add((h, t))
    # Identify the corresponding predicates
    spoes = set()
    for sh, st in subjects:
        for oh, ot in objects:
            p1s = np.where(outputs[1][:, sh, oh] > threshold)[0]
            p2s = np.where(outputs[2][:, st, ot] > threshold)[0]
            ps = set(p1s) & set(p2s)
            for p in ps:
                spoes.add((
                    text[mapping[sh][0]:mapping[st][-1] + 1], id2predicate[p],
                    text[mapping[oh][0]:mapping[ot][-1] + 1]
                ))
    return list(spoes)


class SPO(tuple):
    """Class for storing a triple.
    Behaves like a plain tuple, but overrides __hash__ and __eq__ so that
    checking whether two triples are equivalent is more tolerant (token-level
    rather than raw-string comparison).
    """
    def __init__(self, spo):
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return self.spox.__hash__()

    def __eq__(self, spo):
        return self.spox == spo.spox
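
# For example, with do_lower_case=True the triples ('ABC', 'rel', 'x') and
# ('abc', 'rel', 'x') compare equal, since equality is defined on the
# tokenized (lower-cased) strings rather than on the raw surface forms.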


def evaluate(data):
    """Evaluation function: computes f1, precision and recall.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f1, precision, recall = 0., 0., 0.  # guards against empty `data`
    f = open('dev_pred.json', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        R = set([SPO(spo) for spo in extract_spoes(d['text'])])
        T = set([SPO(spo) for spo in d['spo_list']])
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d['text'],
            'spo_list': list(T),
            'spo_list_pred': list(R),
            'new': list(R - T),
            'lack': list(T - R),
        },
                       ensure_ascii=False,
                       indent=4)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall
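
# X, Y and Z accumulate over the whole dataset, so the reported numbers are
# micro-averaged; dev_pred.json additionally records, per sample, the spurious
# predictions ('new' = R - T) and the missed gold triples ('lack' = T - R).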


class Evaluator(keras.callbacks.Callback):
    """Evaluate and save the model.
    """
    def __init__(self):
        self.best_val_f1 = 0.

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate and save with the EMA-averaged weights, then restore
        # the raw training weights before the next epoch
        optimizer.apply_ema_weights()
        f1, precision, recall = evaluate(valid_data)
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            model.save_weights('best_model.weights')
        optimizer.reset_old_weights()
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )


if __name__ == '__main__':
    epochs = 100
    result = []
    # Simple grid search over sequence length, batch size and learning rate
    # (maxlen and batch_size are module-level globals read by the generator)
    for maxlen in [64, 128, 256]:
        for batch_size in [4, 8, 16, 32]:
            for learning_rate in [1e-5, 2e-5, 5e-5, 1e-4, 5e-4]:
                train_generator = data_generator(train_data, batch_size)
                evaluator = Evaluator()
                # Load the pre-trained model
                base = build_transformer_model(
                    config_path=config_path,
                    checkpoint_path=checkpoint_path,
                    return_keras_model=False
                )
                # Prediction heads (same architecture as above)
                entity_output = GlobalPointer(heads=2, head_size=64)(base.model.output)
                head_output = GlobalPointer(
                    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
                )(base.model.output)
                tail_output = GlobalPointer(
                    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
                )(base.model.output)
                outputs = [entity_output, head_output, tail_output]
                # Build the model
                AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
                optimizer = AdamEMA(learning_rate=learning_rate)
                model = keras.models.Model(base.model.inputs, outputs)
                model.compile(loss=globalpointer_crossentropy, optimizer=optimizer)
                # Stop as soon as the training loss stops improving
                early_stopping = keras.callbacks.EarlyStopping(
                    monitor='loss', min_delta=0, patience=0, verbose=0,
                    mode='auto', baseline=None, restore_best_weights=False
                )
                model.fit(
                    train_generator.forfit(),
                    steps_per_epoch=len(train_generator),
                    epochs=epochs,
                    callbacks=[evaluator, early_stopping]
                )
                model.load_weights('best_model.weights')
                f1, precision, recall = evaluate(valid_data)
                result.append({
                    'max_length': maxlen,
                    'batch_size': batch_size,
                    'learning_rate': learning_rate,
                    'f1': f1,
                    'precision': precision,
                    'recall': recall
                })
                with open('re_result_second.json', 'w') as fw:
                    json.dump(result, fw)
                # Free the graph between runs before rebuilding the model
                K.clear_session()
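
# Minimal inference sketch (an illustration, not part of the original script):
# after training, reload the saved best weights into a model built as above,
# then call extract_spoes() directly:
#
#     model.load_weights('best_model.weights')
#     print(extract_spoes(u'input text here'))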