- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.autograd import Variable
- import numpy as np
- import math
- import dgl
- import random
- import dgl.nn as dglnn
- from . import BaseModel, register_model
- from .CompGCN import CompGraphConvLayer
- import os
-
- def get_norm_id(id_map, some_id):
- # if some_id has not been seen yet, assign it the next available id
- if some_id not in id_map:
- id_map[some_id] = len(id_map)
- return id_map[some_id]
-
- def norm_graph(node_id_map, edge_id_map, edge_list):
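- """Remap the (src, dst, edge_type) triples in edge_list to contiguous ids using the given node and edge-type maps."""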
- norm_edge_list = []
- for e in edge_list:
- norm_edge_list.append(
- (
- get_norm_id(node_id_map, e[0]),
- get_norm_id(node_id_map, e[1]),
- get_norm_id(edge_id_map, e[2]),
- )
- )
- return norm_edge_list
- class NodeEncoder(torch.nn.Module):
- def __init__(
- self,
- base_embedding_dim,
- num_nodes,
- pretrained_node_embedding_tensor,
- is_pre_trained,
- ):
-
- super().__init__()
- self.pretrained_node_embedding_tensor = pretrained_node_embedding_tensor
- self.base_embedding_dim = base_embedding_dim
-
- if not is_pre_trained:
- self.base_embedding_layer = torch.nn.Embedding(
- num_nodes, base_embedding_dim
- ).cuda()
- self.base_embedding_layer.weight.data.uniform_(-1, 1)
- else:
- self.base_embedding_layer = torch.nn.Embedding.from_pretrained(
- pretrained_node_embedding_tensor
- ).cuda()
-
- def forward(self, node_id):
- node_id = torch.LongTensor([int(node_id)]).cuda()
- x_base = self.base_embedding_layer(node_id)
-
- return x_base
-
- class GCNGraphEncoder(torch.nn.Module):
- def __init__(
- self,
- G,
- pretrained_node_embedding_tensor,
- is_pre_trained,
- base_embedding_dim,
- max_length,
- ):
-
- super().__init__()
- self.g = G
- self.base_embedding_dim = base_embedding_dim
- self.max_length = max_length
- self.no_nodes = self.g.num_nodes()  # taken from the DGL graph representation
- self.no_relations = self.g.num_edges()
- # print('check *************', self.no_relations)
-
- self.node_embedding = NodeEncoder(
- base_embedding_dim,
- self.no_nodes,
- pretrained_node_embedding_tensor,
- is_pre_trained,
- )
-
- self.special_tokens = {"[PAD]": 0, "[MASK]": 1}
- self.special_embed = torch.nn.Embedding(
- len(self.special_tokens), base_embedding_dim
- )
- self.special_embed.weight.data.uniform_(-1, 1)
-
- def forward(self, subgraphs_list, masked_nodes):
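- """Return a [num_subgraphs, max_length + 1, base_embedding_dim] tensor of node embeddings; masked nodes are left as zero vectors."""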
- num_subgraphs = len(subgraphs_list)
-
- node_emb = torch.zeros(
- num_subgraphs, self.max_length + 1, self.base_embedding_dim  # +1 because a walk of max_length steps covers max_length + 1 nodes
- )
-
- for ii,subgraph in enumerate(subgraphs_list):
- #node_id_map = batch_id_maps[ii][0]
- #edge_type_map = batch_id_maps[ii][1]
- masked_set = masked_nodes[ii]
- for node in subgraph.nodes():
- node_id=subgraph.ndata[dgl.NID][int(node)]
- if node_id not in masked_set: # used to ignore the masked nodes
- node_emb[ii][node] = self.node_embedding(int(node_id))
-
- # get embeddings for special tokens
- # will be used for masking and padding before bert layer
- special_tokens_embed = {}
- for token in self.special_tokens:
- node_id = Variable(torch.LongTensor([self.special_tokens[token]]))
- tmp_embed = self.special_embed(node_id)
- special_tokens_embed[self.special_tokens[token] + self.no_nodes] = {
- "token": token,
- "embed": tmp_embed,
- }
-
- return node_emb
-
- def get_attn_pad_mask(subgraph_list, pad_id, max_len):
- # subgraph_list is a list of DGL subgraphs; positions whose original node id equals pad_id are padding
- batch_size = len(subgraph_list)
- len_q = max_len
- # print(batch_size, len_q, len_k)
- pad_attn_mask = []
- for itm in subgraph_list:
- tmp_mask = []
- for sub in itm.ndata[dgl.NID]:
- if sub == pad_id:
- tmp_mask.append(True)
- else:
- tmp_mask.append(False)
- if len(tmp_mask)<max_len:
- tmp_mask=tmp_mask+[True]*(max_len-len(tmp_mask))
- pad_attn_mask.append(tmp_mask)
- # print(tmp_mask)
- # print('mask', len(pad_attn_mask), len(pad_attn_mask[0]))
- pad_attn_mask = Variable(torch.ByteTensor(pad_attn_mask)).unsqueeze(1)
- pad_attn_mask = pad_attn_mask.cuda()
-
- return pad_attn_mask.expand(batch_size, len_q, len_q) # batch_size x len_q x len_k
-
-
- def gelu(x):
- """"Implementation of the gelu activation function by Hugging Face."""
- return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
- class ScaledDotProductAttention(torch.nn.Module):
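- """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, with masked positions suppressed."""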
- def __init__(self, d_k):
-
- super(ScaledDotProductAttention, self).__init__()
- self.d_k = d_k
-
- def forward(self, Q, K, V, attn_mask=None):
- # print('mask', attn_mask.size())
- scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
- scores.masked_fill_(attn_mask.bool(), -1e9)  # set masked positions to a large negative value so softmax assigns them ~0 weight
- attn = torch.nn.Softmax(dim=-1)(scores)
- context = torch.matmul(attn, V)
-
- return context, attn
-
-
- class MultiHeadAttention(torch.nn.Module):
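- """Multi-head self-attention followed by a residual connection and layer normalization."""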
- def __init__(self, d_model, d_k, d_v, n_heads):
-
- super(MultiHeadAttention, self).__init__()
- self.n_heads = n_heads
- self.d_k = d_k #dimension of K and Q
- self.d_v = d_v #dimension of V
- self.d_model = d_model
-
- self.W_Q = torch.nn.Linear(d_model, d_k * n_heads)
- self.W_K = torch.nn.Linear(d_model, d_k * n_heads)
- self.W_V = torch.nn.Linear(d_model, d_v * n_heads)
- self.scaled_dot_prod_attn = ScaledDotProductAttention(d_k)
- self.wrap = torch.nn.Linear(self.n_heads * self.d_v, self.d_model)
- self.layerNorm = torch.nn.LayerNorm(self.d_model)
-
- def forward(self, Q, K, V, attn_mask=None):
- # Q, K, V here are the raw input sequences; the projections below build the per-head query/key/value matrices
- residual, batch_size = Q, Q.size(0)
- q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_k], e.g. 128 x 4 x 7 x 64
- k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
- v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
-
- if attn_mask is not None:
- attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
- context, attn = self.scaled_dot_prod_attn(q_s, k_s, v_s, attn_mask=attn_mask)#context is H*A
- context = (
- context.transpose(1, 2)
- .contiguous()
- .view(batch_size, -1, self.n_heads * self.d_v)
- )
- output = self.wrap(context)
-
- return self.layerNorm(output + residual), attn
-
- # position-wise feed-forward network (FFN in the paper)
- class PoswiseFeedForwardNet(torch.nn.Module):
- def __init__(self, d_model, d_ff):
-
- super(PoswiseFeedForwardNet, self).__init__()
- self.fc1 = torch.nn.Linear(d_model, d_ff)
- self.fc2 = torch.nn.Linear(d_ff, d_model)
-
- def forward(self, x):
-
- return self.fc2(gelu(self.fc1(x)))
-
-
- class EncoderLayer(torch.nn.Module):
- def __init__(self, d_model, d_k, d_v, d_ff, n_heads):
- super(EncoderLayer, self).__init__()
- self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
- self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)
-
- def forward(self, enc_inputs, enc_self_attn_mask):
- enc_outputs, attn = self.enc_self_attn(
- enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
- ) # enc_inputs to same Q,K,V
- enc_outputs = self.pos_ffn(
- enc_outputs
- ) # enc_outputs: [batch_size x len_q x d_model]
-
- return enc_outputs, attn
-
-
- @register_model('SLiCE')
- class SLiCE(BaseModel):
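- """SLiCE: learns contextual node embeddings by running a BERT-style encoder over sampled subgraph contexts and predicting masked nodes."""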
- @classmethod
- def build_model_from_args(cls, args, hg):
- # if args.embed_dir:
- # pretrained_node_embedding_tensor=load_pickle(args.embed_dir)
-
- return cls(G=hg, pretrained_node_embedding_tensor=None, args=args)  # TODO: command-line argument parsing
- def load_pretrained_node2vec(self,filename, base_emb_dim):
- """
- loads embeddings from node2vec style file, where each line is
- nodeid node_embedding
- returns tensor containing node_embeddings
- for graph nodes 0 to n-1
- """
- node_embeddings = torch.empty(self.g.num_nodes(), base_emb_dim)
- with open(filename, "r") as f:
- header = f.readline()
- emb_dim = int(header.strip().split()[1])
- for line in f:
- arr = line.strip().split()
- graph_node_id = arr[0]
- node_emb = [float(x) for x in arr[1:]]
- vocab_id = int(graph_node_id)
- if vocab_id >= 0:
- node_embeddings[vocab_id] = torch.tensor(node_emb)
- # print(torch.tensor(node_emb).size())
- out = node_embeddings
- print("node2vec tensor", out.size())
- return out
- # default hyperparameters follow the original paper
- def __init__(self,
- G,  # G is a DGLGraph
- args,
- pretrained_node_embedding_tensor,
- num_layers=6,
- d_model=200,
- d_k=64,
- d_v=64,
- d_ff=200 * 4,
- n_heads=4,
- is_pre_trained=False,
- base_embedding_dim=200,#dimension of base embedding
- max_length=6,#max length of walks
- num_gcn_layers=2,#number of gcn layers before bert
- node_edge_composition_func="mult",#options for node and edge compostion, sub|circ_conv|mult|no_rel
- get_embeddings=False,#indicate if need to get node vectors from BERT encoder output
- fine_tuning_layer=False,):
-
- super().__init__()
- #initialize
- self.g=G
- self.num_layers = num_layers
- self.d_model = d_model
- self.max_length = max_length
- self.get_embeddings = get_embeddings
- self.node_edge_composition_func = node_edge_composition_func
- self.fine_tuning_layer = fine_tuning_layer
- self.no_nodes = G.num_nodes()
- self.n_pred=args.n_pred
- # if no pre-trained embeddings exist, run node2vec to generate them
- if not os.path.exists(args.pretrained_embeddings):
- print("Run Node2vec to obtain pre-trained node embeddings ...")
- walks=[]
- for _ in range(10):
- nodes=list(G.nodes())
- random.shuffle(nodes)
- walk = dgl.sampling.node2vec_random_walk(G, torch.tensor(nodes), 1, 1, walk_length=80-1).tolist()#len=walk_length+1
- walks.extend(walk)
- walks = [list(map(str, walk)) for walk in walks]
- from gensim.models import Word2Vec
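- # note: the arguments below follow gensim < 4.0; gensim >= 4.0 renames size -> vector_size and iter -> epochs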
- model = Word2Vec(
- walks,
- size=base_embedding_dim,
- window=10,
- min_count=0,
- sg=1,
- workers=8,
- iter=1,
- )
- model.wv.save_word2vec_format(args.pretrained_embeddings)
-
- pretrained_node_embedding_tensor = self.load_pretrained_node2vec(
- args.pretrained_embeddings, base_embedding_dim
- )# (n_nodes*d_model)
- # FIXME: temporarily uses random initialization; the pretrained tensor is None
- self.gcn_graph_encoder = GCNGraphEncoder(
- G,
- pretrained_node_embedding_tensor,
- is_pre_trained,
- base_embedding_dim,
- max_length,
- )
-
- self.layers = torch.nn.ModuleList(
- [EncoderLayer(d_model, d_k, d_v, d_ff, n_heads) for _ in range(num_layers)]
- ).cuda()
- self.linear = torch.nn.Linear(d_model, d_model).cuda()
- self.norm = torch.nn.LayerNorm(d_model).cuda()
-
- # decoder
- self.decoder = torch.nn.Linear(self.d_model, self.no_nodes).cuda()
- def set_fine_tuning(self):
- self.fine_tuning_layer = True
- def GCN_MaskGeneration(self,subgraph_sequences):
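- """Randomly pick n_pred nodes per subgraph to mask; return their global node ids and their positions within each subgraph."""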
- n_pred=self.n_pred
- masked_nodes = []  # global ids of the masked nodes
- masked_position = []  # positions of the masked nodes within each subgraph
- for subgraph in subgraph_sequences:
- num_nodes = subgraph.num_nodes()
- mask_index = random.sample(range(num_nodes), n_pred)
- subgraph_masked_nodes = []
- subgraph_masked_position = []
- for i in range(num_nodes):
- if i in mask_index:
- subgraph_masked_nodes.append(subgraph.ndata[dgl.NID][i])
- subgraph_masked_position.append(i)
- masked_nodes.append(subgraph_masked_nodes)
- masked_position.append(subgraph_masked_position)
-
- return torch.tensor(masked_nodes), torch.tensor(masked_position)
- def forward(self, subgraph_list):
- #subgraph list is a list of node subgraphs sampled by slice_sampler
- if self.fine_tuning_layer:
- masked_nodes=Variable(torch.LongTensor([[] for ii in range(len(subgraph_list))]))
- else:
- masked_nodes,masked_pos=self.GCN_MaskGeneration(subgraph_list)
- # initialize node (and relation) embeddings for the sampled subgraphs
- # context generation
- node_emb = self.gcn_graph_encoder(subgraph_list, masked_nodes)
- output = node_emb.cuda()
- enc_self_attn_mask = get_attn_pad_mask(subgraph_list,self.no_nodes,self.max_length+1)
- # contextual translation
- for layer in self.layers:
- output, enc_self_attn = layer(output, enc_self_attn_mask)
- try:
- layer_output = torch.cat((layer_output, output.unsqueeze(1)), 1)#output embedding of each layer
- except NameError: # FIXME - replaced bare except
- layer_output = output.unsqueeze(1).cuda()
-
- if self.fine_tuning_layer:
- try:
- att_output = torch.cat((att_output, enc_self_attn.unsqueeze(0)), 0)#output attention of each layer
- except NameError: # FIXME - replaced bare except
- att_output = enc_self_attn.unsqueeze(0)
-
- # added for the ablation study
- if self.num_layers == 0:
- layer_output = output.unsqueeze(1)
- att_output = "NA"
-
- if self.fine_tuning_layer:
- # print(output.size(), layer_output.size(), att_output.size())
- return output, layer_output, att_output
- else:
- masked_pos = masked_pos[:,:,None].expand(
- -1, -1, output.size(-1)
- ) # [batch_size, maxlen, d_model]
- h_masked = torch.gather(
- output, 1, masked_pos.cuda()
- ) # masking position [batch_size, len, d_model]
- h_masked = self.norm(gelu(self.linear(h_masked)))
- pred_score = self.decoder(h_masked) # [batch_size, maxlen, n_vocab]
- # print('check====', pred_score.size())
-
- if self.get_embeddings:
- return pred_score, masked_nodes, output
- else:
- return pred_score, masked_nodes
-
- class SLiCEFinetuneLayer(torch.nn.Module):
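- """Fine-tuning head: scores (source, destination) node pairs for link prediction from the encoder's per-layer outputs."""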
- @classmethod
- def build_model_from_args(cls, args):
- return cls(d_model=args.d_model, ft_d_ff=args.ft_d_ff,
- ft_layer=args.ft_layer, ft_drop_rate=args.ft_drop_rate,
- ft_input_option=args.ft_input_option, num_layers=args.num_layers)
- def __init__(
- self,
- d_model,
- ft_d_ff,
- ft_layer,
- ft_drop_rate,
- ft_input_option,
- num_layers,
- ):
-
- super().__init__()
- self.d_model = d_model
- self.ft_layer = ft_layer
- self.ft_input_option = ft_input_option
- self.num_layers = num_layers
-
- if ft_input_option in ["last", "last4_sum"]:
- cnt_layers = 1
- elif ft_input_option in ["last4_cat"]:
- cnt_layers = 4
-
- if self.num_layers == 0:
- cnt_layers = 1
-
- if self.ft_layer == "linear":
- self.ft_decoder = torch.nn.Linear(d_model * cnt_layers, d_model).cuda()
- elif self.ft_layer == "ffn":
- self.ffn1 = torch.nn.Linear(d_model * cnt_layers, ft_d_ff).cuda()
- print(self.num_layers, cnt_layers, self.ffn1)
- self.dropout = torch.nn.Dropout(ft_drop_rate).cuda()
- self.ffn2 = torch.nn.Linear(ft_d_ff, d_model).cuda()
-
- def forward(self, graphbert_layer_output):
- """
- graphbert_output = batch_sz * [CLS, source, target, relation, SEP] *
- [emb_size]
- """
- if self.ft_input_option == "last":
- # use the output from the last layer of graphbert
- graphbert_output = graphbert_layer_output[:, -1, :, :].squeeze(1)
- source_embedding = graphbert_output[:, 0, :].unsqueeze(1)
- destination_embedding = graphbert_output[:, 1, :].unsqueeze(1)
- else:
- # combine the output from the last four layers
- # added for the ablation study
- no_layers = graphbert_layer_output.size(1)
- if no_layers == 1:
- start_layer = 0
- else:
- start_layer = no_layers - 4
- for ii in range(start_layer, no_layers):
- source_embed = graphbert_layer_output[:, ii, 0, :].unsqueeze(1)
- destination_embed = graphbert_layer_output[:, ii, 1, :].unsqueeze(1)
- if self.ft_input_option == "last4_cat":
- try:
- source_embedding = torch.cat(
- (source_embedding, source_embed), 2
- )
- destination_embedding = torch.cat(
- (destination_embedding, destination_embed), 2
- )
- except NameError:  # first iteration of the loop
- source_embedding = source_embed
- destination_embedding = destination_embed
- elif self.ft_input_option == "last4_sum":
- try:
- source_embedding = source_embedding + source_embed
- destination_embedding = destination_embedding + destination_embed
- except NameError:  # first iteration of the loop
- source_embedding = source_embed
- destination_embedding = destination_embed
- # print(source_embedding.size(), destination_embedding.size())
-
- if self.ft_layer == "linear":
- src_embedding = self.ft_decoder(source_embedding)
- dst_embedding = self.ft_decoder(destination_embedding)
- elif self.ft_layer == "ffn":
- src_embedding = torch.relu(self.dropout(self.ffn1(source_embedding)))
- src_embedding = self.ffn2(src_embedding)
- dst_embedding = torch.relu(self.dropout(self.ffn1(destination_embedding)))
- dst_embedding = self.ffn2(dst_embedding)
-
- dst_embedding = dst_embedding.transpose(1, 2)
- pred_score = torch.bmm(src_embedding, dst_embedding).squeeze(1)
- pred_score = torch.sigmoid(pred_score)
- # print('check+++++', pred_score.size())
-
- return pred_score, src_embedding, dst_embedding.transpose(1, 2)