import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

# TODO: ChineseMEN wrapper (a selector plus a discriminator) is not yet
# implemented.

class BertSelection(nn.Module):
    """Bi-encoder selector: scores top-k candidates against a mention via
    [CLS] dot products."""

    def __init__(self, bert_model, topk):
        super(BertSelection, self).__init__()
        self.bert = bert_model
        self.topk = topk

    def forward(self, mentions, candidates):
        # [CLS] embedding of each mention: (batch, 1, hidden).
        mention_emb = self.bert(**mentions)[0][:, 0, :].unsqueeze(1)
        # [CLS] embeddings of the candidates, regrouped per mention:
        # (batch, topk, 768), where 768 is the hidden size of bert-base.
        candidates_emb = self.bert(**candidates)[0][:, 0, :].reshape(-1, self.topk, 768)
        # Dot-product similarity of each mention with its k candidates: (batch, topk).
        dense_scores = torch.bmm(mention_emb, candidates_emb.permute(0, 2, 1)).squeeze(1)
        return dense_scores
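

# The following is a minimal, hypothetical smoke test, not part of the model
# code: it assumes the encoder behaves like a HuggingFace-style BertModel,
# i.e. bert(**inputs) returns a tuple whose first element is the last hidden
# state of shape (batch, seq_len, hidden). _StubBert is a stand-in used only
# so the tensor shapes can be exercised without real weights.
class _StubBert(nn.Module):
    def __init__(self, vocab_size=100, hidden=768):
        super(_StubBert, self).__init__()
        self.emb = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        return (self.emb(input_ids),)


def _demo_bert_selection():
    batch, topk, seq_len = 2, 4, 8
    selector = BertSelection(_StubBert(), topk=topk)
    mentions = {"input_ids": torch.randint(0, 100, (batch, seq_len))}
    candidates = {"input_ids": torch.randint(0, 100, (batch * topk, seq_len))}
    scores = selector(mentions, candidates)  # -> (batch, topk)
    assert scores.shape == (batch, topk)
    return scores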


class BertNormalization(nn.Module):
    """Cross-encoder re-ranker: scores each mention-candidate pair with a BERT
    [CLS] classifier and optionally mixes in sparse, dense, and graph scores."""

    def __init__(self, bert_model, dropout_prob, topk, score_type, pair_weight=1,
                 add_sparse=False, calculate_dense=False, add_dense2score=False,
                 add_match_atten=False, add_graph=False, graph_model=None):
        super(BertNormalization, self).__init__()

        self.bert = bert_model
        # Learnable weight on the sparse score, initialized to 0 so training
        # starts from the pair score alone. Created on CPU; move the module
        # with .to(device) instead of hard-coding .cuda().
        self.sparse_weight = nn.Parameter(torch.zeros(1))

        self.topk = topk

        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(768, 1)

        self.add_sparse = add_sparse
        self.calculate_dense = calculate_dense
        self.add_dense2score = add_dense2score
        self.score_type = score_type
        self.pair_weight = pair_weight

        # Mixing coefficients reserved for weighting the score components
        # (currently unused in forward).
        self.alpha1 = nn.Parameter(torch.ones(1))
        self.alpha2 = nn.Parameter(torch.ones(1))

        self.add_match_atten = add_match_atten
        self.match_atten = MatchAttention(score_type)

        self.add_graph = add_graph
        self.graph_model = graph_model
    def forward(self, mentions, candidates, pairs, sparse_scores, cand_graph_embs):
        # [CLS] embedding of each mention-candidate pair, regrouped per
        # mention: (batch, topk, 768).
        pairs_emb = self.bert(**pairs)[0][:, 0, :].reshape(-1, self.topk, 768)

        if self.add_match_atten:
            # Let each pair representation attend over the other candidates.
            pairs_emb = self.match_atten(pairs_emb)

        pair_score = self.classifier(self.dropout(pairs_emb)).squeeze(-1)

        output = pair_score * self.pair_weight

        if self.add_sparse:
            # L2-normalize the sparse scores before mixing them in, without
            # modifying the caller's tensor in place.
            sparse_scores = sparse_scores / torch.norm(sparse_scores, 2)
            output += self.sparse_weight * sparse_scores

        dense_scores = None
        if self.calculate_dense:
            mention_emb = self.bert(**mentions)[0][:, 0, :].unsqueeze(1)
            candidates_emb = self.bert(**candidates)[0][:, 0, :].reshape(-1, self.topk, 768)
            dense_scores = torch.bmm(mention_emb, candidates_emb.permute(0, 2, 1)).squeeze(1)

            if self.add_dense2score:
                output += dense_scores / torch.norm(dense_scores, 2)

        if self.add_graph:
            if not self.calculate_dense:  # mention_emb was not computed above
                mention_emb = self.bert(**mentions)[0][:, 0, :].unsqueeze(1)

            graph_scores = torch.bmm(mention_emb, cand_graph_embs.permute(0, 2, 1)).squeeze(1)
            output += graph_scores / torch.norm(graph_scores, 2)

        return dense_scores, output
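

# Hypothetical usage sketch, reusing the _StubBert stand-in defined above; it
# exercises the pair score plus the sparse, dense, and match-attention
# branches on CPU as a shape-level check only.
def _demo_bert_normalization():
    batch, topk, seq_len = 2, 4, 8
    model = BertNormalization(
        _StubBert(), dropout_prob=0.1, topk=topk, score_type="dot",
        add_sparse=True, calculate_dense=True, add_dense2score=True,
        add_match_atten=True,
    )
    mentions = {"input_ids": torch.randint(0, 100, (batch, seq_len))}
    candidates = {"input_ids": torch.randint(0, 100, (batch * topk, seq_len))}
    pairs = {"input_ids": torch.randint(0, 100, (batch * topk, seq_len))}
    sparse_scores = torch.rand(batch, topk)
    dense_scores, output = model(mentions, candidates, pairs, sparse_scores, None)
    assert dense_scores.shape == (batch, topk) and output.shape == (batch, topk)
    return dense_scores, output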


class MatchAttention(nn.Module):
    """Self-attention over the top-k candidate representations, letting each
    candidate's representation account for its competitors."""

    def __init__(self, score_type="dot"):
        super(MatchAttention, self).__init__()
        self.score_type = score_type

        if score_type == "bilinear":
            self.W = nn.Linear(768, 768)

    def forward(self, options):
        batch_size, option_num, hidden = options.size()

        if self.score_type == "dot":
            # Scaled dot-product attention among the candidates; the constant
            # 28 approximates sqrt(768), the usual 1/sqrt(d) temperature.
            att_matrix = options.bmm(torch.transpose(options, 1, 2))
            att_matrix /= 28
            att_weight = F.softmax(att_matrix, dim=-1)

            att_vec = att_weight.bmm(options)
            att_options = att_vec.view(batch_size, option_num, -1)
        elif self.score_type == "bilinear":
            # Bilinear attention with a residual connection back to the input.
            linear_options = self.W(options)
            att_matrix = options.bmm(torch.transpose(linear_options, 1, 2))
            att_weight = F.softmax(att_matrix, dim=-1)

            att_vec = att_weight.bmm(options)
            att_options = att_vec.view(batch_size, option_num, -1)

            att_options = att_options + options

        return att_options
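

# A small, hypothetical shape check for MatchAttention; the bilinear variant
# is used so the learned W matrix is exercised as well.
def _demo_match_attention():
    options = torch.randn(2, 4, 768)  # (batch, topk, hidden)
    atten = MatchAttention(score_type="bilinear")
    att_options = atten(options)
    assert att_options.shape == options.shape
    return att_options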


class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        # support = XW, then aggregate over the (normalized) adjacency:
        # output = A(XW). Operates on a single graph, so input is
        # (num_nodes, in_features), adj is (num_nodes, num_nodes), and
        # torch.mm (not bmm) is the right call.
        support = torch.matmul(input, self.weight)  # (num_nodes, out_features)
        output = torch.mm(adj, support)             # (num_nodes, out_features)
        if self.bias is not None:
            return output + self.bias
        return output
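

# A toy usage sketch for GraphConvolution, assuming the standard renormalized
# adjacency A_hat = D^{-1/2}(A + I)D^{-1/2} from Kipf & Welling
# (arXiv:1609.02907). _toy_graph is a hypothetical helper, not part of the
# original model code.
def _toy_graph(num_nodes=4, in_dim=16):
    adj = torch.zeros(num_nodes, num_nodes)
    for i in range(num_nodes - 1):       # a simple path graph 0-1-2-...
        adj[i, i + 1] = adj[i + 1, i] = 1.0
    adj = adj + torch.eye(num_nodes)     # add self-loops
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    adj = d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)
    feats = torch.randn(num_nodes, in_dim)
    return adj, feats


def _demo_graph_convolution():
    adj, feats = _toy_graph(num_nodes=4, in_dim=16)
    layer = GraphConvolution(in_features=16, out_features=8)
    out = layer(feats, adj)              # -> (4, 8)
    assert out.shape == (4, 8)
    return out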


class GCN_update(nn.Module):
    """Two-layer GCN: GraphConvolution -> ReLU -> GraphConvolution."""

    def __init__(self, in_features, out_features, bias=False):
        super(GCN_update, self).__init__()
        # bias is forwarded to both layers; each GraphConvolution resets its
        # own parameters in its constructor.
        self.gcn1 = GraphConvolution(in_features, out_features, bias=bias)
        self.gcn2 = GraphConvolution(out_features, out_features, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, input, adj):
        """
        :param input: node features, (num_nodes, in_features)
        :param adj: normalized adjacency matrix, (num_nodes, num_nodes)
        :returns: updated node embeddings, (num_nodes, out_features)
        """
        nodeEmb1 = self.gcn1(input, adj)
        nodeEmb2 = self.gcn2(self.relu(nodeEmb1), adj)
        return nodeEmb2
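

# Optional entry point tying the sketches above together; everything runs on
# CPU with the hypothetical _StubBert, so this is a shape-level smoke test
# rather than a training script.
if __name__ == "__main__":
    _demo_bert_selection()
    _demo_bert_normalization()
    _demo_match_attention()
    _demo_graph_convolution()
    # GCN_update stacks two GraphConvolution layers over the same adjacency.
    adj, feats = _toy_graph(num_nodes=4, in_dim=16)
    out = GCN_update(in_features=16, out_features=8)(feats, adj)
    assert out.shape == (4, 8)
    print("all shape checks passed")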