# -*- coding: utf-8 -*-
"""
@Time : 2022-05-22 22:04
@Author : Zhuxb

"""
import json
import os

import paddle
import paddle.nn as nn
from paddlenlp.transformers import BertTokenizer, RobertaTokenizer

from model.classifer.softmax_classifer import Softmax_Layer
from model.encoder.encoder import Encoder
from utils import BCELossForDuIE

# Two-layer fully connected (FC) classification head.
class MultiNonLinearClassifier(nn.Layer):
    def __init__(self, hidden_size, tag_size, dropout_rate):
        super(MultiNonLinearClassifier, self).__init__()
        self.tag_size = tag_size
        # Project down to half the hidden size, then map to the tag space.
        self.linear = nn.Linear(hidden_size, hidden_size // 2)
        self.hidden2tag = nn.Linear(hidden_size // 2, self.tag_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_features):
        # FC -> ReLU -> dropout -> tag projection.
        features_tmp = self.linear(input_features)
        features_tmp = self.relu(features_tmp)
        features_tmp = self.dropout(features_tmp)
        features_output = self.hidden2tag(features_tmp)
        return features_output
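
# Shape sketch for the head above (illustrative values, not from this repo):
# an input of [batch, seq_len, hidden_size] yields logits of
# [batch, seq_len, tag_size], e.g. MultiNonLinearClassifier(768, 97, 0.1)
# maps [8, 128, 768] to [8, 128, 97].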


class RE_Global_Model(nn.Layer):
    def __init__(self, args):
        super(RE_Global_Model, self).__init__()
        self.args = args
        self.loss_func_entity = BCELossForDuIE()

        # Load the predicate (relation) label map produced by preprocessing.
        label_map_path = os.path.join(args.data_path, "predicate2id.json")
        with open(label_map_path, 'r', encoding='utf8') as fp:
            self.label_map = json.load(fp)
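        # Assumed layout of predicate2id.json (the file itself is not shown
        # here): two non-predicate entries followed by one id per relation,
        # roughly {"O": 0, "I": 1, "founder": 2, "birthplace": 3, ...}.
        # The "- 2" in the classifier sizes below discounts those two entries.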

        # Pick encoder/tokenizer by language; "zn" is this repo's tag for Chinese.
        if args.lang == "zn":
            self.encoder = Encoder(model_name="roberta-wwm-ext-large", lang="zn")
            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-wwm-ext-large")
        elif args.lang == "en":
            self.encoder = Encoder(model_name="bert-base-cased", lang="en")
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        else:
            raise ValueError(f"unsupported lang: {args.lang!r} (expected 'zn' or 'en')")

        # Two tags per predicate (len(label_map) - 2 predicates after dropping
        # the two special entries) plus one extra class.
        num_classes = 2 * (len(self.label_map) - 2) + 1
        self.subject_classifier = Softmax_Layer(
            input_size=self.encoder.output_size, num_classes=num_classes, name="sub")
        self.object_classifier = Softmax_Layer(
            input_size=self.encoder.output_size, num_classes=num_classes, name="obj")
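        # Worked example (hypothetical counts): if predicate2id.json has 50
        # entries, there are 48 predicates, so num_classes = 2 * 48 + 1 = 97
        # for both the subject and the object classifier.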
    def forward(self, input_ids, subject_labels, object_labels):
        sequence_output = self.encoder(input_ids=input_ids, encode_type="encoder")
        # Mask out token ids 0, 1, and 2, which this code assumes are the
        # special tokens (e.g. [PAD]/[CLS]/[SEP]) of the vocabulary in use.
        attention_mask = (input_ids != 0).logical_and(input_ids != 1).logical_and(input_ids != 2)
        mask = attention_mask.astype("int64")

        # First stage: train the entity (subject/object) taggers.
        subject_logits = self.subject_classifier(sequence_output, mask)
        object_logits = self.object_classifier(sequence_output, mask)

        loss_subject_entity = self.loss_func_entity(subject_logits, subject_labels, mask)
        loss_object_entity = self.loss_func_entity(object_logits, object_labels, mask)

        loss_entity = loss_subject_entity + loss_object_entity

        return (loss_entity,
                nn.functional.sigmoid(subject_logits).numpy(),
                nn.functional.sigmoid(object_logits).numpy())
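

# Minimal usage sketch, not part of the original training pipeline. It
# assumes an `args` object exposing `data_path` and `lang`, a local
# `./data/predicate2id.json`, and that BCELossForDuIE expects labels shaped
# like the logits; all concrete values below are hypothetical.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(data_path="./data", lang="en")  # hypothetical values
    model = RE_Global_Model(args)

    # Encode a toy sentence; real training code would batch, pad and truncate.
    encoded = model.tokenizer("Steve Jobs founded Apple.")
    input_ids = paddle.to_tensor([encoded["input_ids"]])
    seq_len = input_ids.shape[1]
    num_classes = 2 * (len(model.label_map) - 2) + 1

    # All-zero dummy labels, shaped [batch, seq_len, num_classes] to match
    # the classifier logits (an assumption about Softmax_Layer's output).
    subject_labels = paddle.zeros([1, seq_len, num_classes], dtype="float32")
    object_labels = paddle.zeros([1, seq_len, num_classes], dtype="float32")

    loss_entity, sub_probs, obj_probs = model(input_ids, subject_labels, object_labels)
    print(float(loss_entity), sub_probs.shape, obj_probs.shape)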