- """
- Trilinear attention
- This code is written by Huy Tran.
- """
- import torch
- import torch.nn as nn
- from torch.nn.utils.weight_norm import weight_norm
- from src.fc import FCNet
- from src.bc import BCNet
- from src.tc import TCNet
- import time
-
-
- class BiAttention(nn.Module):
- def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2,.5]):
- super(BiAttention, self).__init__()
-
- self.glimpse = glimpse
- self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \
- name='h_mat', dim=None)
-
- def forward(self, v, q, v_mask=True):
- """
- v: [batch, k, vdim]
- q: [batch, qdim]
- """
- p, logits = self.forward_all(v, q, v_mask)
- return p, logits
-
- def forward_all(self, v, q, v_mask=True):
- v_num = v.size(1)
- q_num = q.size(1)
- logits = self.logits(v,q) # b x g x v x q
-
- if v_mask:
- mask = (0 == v.abs().sum(2)).unsqueeze(1).unsqueeze(3).expand(logits.size())
- logits.data.masked_fill_(mask.data, -float('inf'))
-
- p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2)
- return p.view(-1, self.glimpse, v_num, q_num), logits
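
# Hedged usage sketch for BiAttention (illustrative sizes only; this assumes
# the repo's BCNet produces logits shaped [batch, glimpse, v_num, q_num], as
# annotated in forward_all above):
#
#   att = BiAttention(x_dim=2048, y_dim=1024, z_dim=1024, glimpse=2)
#   v = torch.randn(8, 36, 2048)   # 36 region features per image
#   q = torch.randn(8, 14, 1024)   # 14 word features per question
#   p, _ = att(v, q)               # p: [8, 2, 36, 14]; each glimpse sums to 1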


class TriAttention(nn.Module):
    def __init__(self, v_dim, q_dim, a_dim, h_dim, h_out, rank, glimpse, k, dropout=[.2, .5]):
        super(TriAttention, self).__init__()
        self.glimpse = glimpse
        # Trilinear connect network scoring every (region, word, answer)
        # triple, one attention map per glimpse.
        self.TriAtt = TCNet(v_dim, q_dim, a_dim, h_dim, h_out, rank, glimpse, dropout=dropout, k=k)

    def forward(self, v, q, a):
        v_num = v.size(1)
        q_num = q.size(1)
        a_num = a.size(1)
        logits = self.TriAtt(v, q, a)  # b x v x q x a x g

        # Mask padded (all-zero) visual features before the softmax.
        mask = (0 == v.abs().sum(2)).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand(logits.size())
        logits.data.masked_fill_(mask.data, -float('inf'))

        # Normalise each glimpse jointly over the v x q x a grid.
        p = torch.softmax(logits.contiguous().view(-1, v_num * q_num * a_num, self.glimpse), dim=1)
        return p.view(-1, v_num, q_num, a_num, self.glimpse), logits
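
# Hedged usage sketch for TriAttention (illustrative sizes and hyperparameters
# only; this assumes the repo's TCNet produces logits shaped
# [batch, v_num, q_num, a_num, glimpse], which is what the masking above
# relies on):
#
#   att = TriAttention(v_dim=2048, q_dim=1024, a_dim=1024, h_dim=512,
#                      h_out=1, rank=32, glimpse=2, k=1)
#   v = torch.randn(8, 36, 2048)   # image region features
#   q = torch.randn(8, 14, 1024)   # question word features
#   a = torch.randn(8, 4, 1024)    # answer candidate features
#   p, _ = att(v, q, a)            # p: [8, 36, 14, 4, 2]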


class StackedAttention(nn.Module):
    def __init__(self, num_stacks, img_feat_size, ques_feat_size, att_size, output_size, drop_ratio):
        super(StackedAttention, self).__init__()

        self.img_feat_size = img_feat_size
        self.ques_feat_size = ques_feat_size
        self.att_size = att_size
        self.output_size = output_size
        self.drop_ratio = drop_ratio
        self.num_stacks = num_stacks
        self.layers = nn.ModuleList()

        self.dropout = nn.Dropout(drop_ratio)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

        # First stack: project question and image features into a common
        # attention space (fc11/fc12), score each region (fc13), and project
        # both again for the attention-weighted combination (fc14/fc15).
        self.fc11 = nn.Linear(ques_feat_size, att_size, bias=True)
        self.fc12 = nn.Linear(img_feat_size, att_size, bias=False)
        self.fc13 = nn.Linear(att_size, 1, bias=True)

        self.fc14 = nn.Linear(ques_feat_size, att_size, bias=True)
        self.fc15 = nn.Linear(img_feat_size, att_size, bias=False)

        # Remaining stacks: three layers each (query projection, image
        # projection, attention scorer), stored flat in self.layers.
        for stack in range(num_stacks - 1):
            self.layers.append(nn.Linear(att_size, att_size, bias=True))
            self.layers.append(nn.Linear(img_feat_size, att_size, bias=False))
            self.layers.append(nn.Linear(att_size, 1, bias=True))

    def forward(self, img_feat, ques_feat, v_mask=True):

        # Batch size
        B = ques_feat.size(0)

        # Stack 1
        ques_emb = self.fc11(ques_feat)
        img_emb = self.fc12(img_feat)

        # Compute attention distribution
        h1 = self.tanh(ques_emb.view(B, 1, self.att_size) + img_emb)
        h1_emb = self.fc13(self.dropout(h1))
        # Mask padded (all-zero) regions before the softmax so that only
        # actual bounding boxes receive attention.
        if v_mask:
            mask = (0 == img_emb.abs().sum(2)).unsqueeze(2).expand(h1_emb.size())
            h1_emb.data.masked_fill_(mask.data, -float('inf'))

        p1 = self.softmax(h1_emb)

        ques_emb_1 = self.fc14(ques_feat)
        img_emb_1 = self.fc15(img_feat)
        # Compute attention-weighted sum over the image regions
        img_att_1 = img_emb_1 * p1
        weight_sum_1 = torch.sum(img_att_1, dim=1)

        # Combine with question vector
        u1 = ques_emb_1 + weight_sum_1

        # Other stacks
        us = []
        ques_embs = []
        img_embs = []
        hs = []
        h_embs = []
        ps = []
        img_atts = []
        weight_sums = []

        us.append(u1)
        for stack in range(self.num_stacks - 1):
            ques_embs.append(self.layers[3 * stack + 0](us[-1]))
            img_embs.append(self.layers[3 * stack + 1](img_feat))

            # Compute attention distribution
            hs.append(self.tanh(ques_embs[-1].view(B, -1, self.att_size) + img_embs[-1]))
            h_embs.append(self.layers[3 * stack + 2](self.dropout(hs[-1])))
            # Mask padded (all-zero) regions before the softmax
            if v_mask:
                mask = (0 == img_embs[-1].abs().sum(2)).unsqueeze(2).expand(h_embs[-1].size())
                h_embs[-1].data.masked_fill_(mask.data, -float('inf'))
            ps.append(self.softmax(h_embs[-1]))

            # Compute attention-weighted sum over the image regions
            img_atts.append(img_embs[-1] * ps[-1])
            weight_sums.append(torch.sum(img_atts[-1], dim=1))

            # Refine the query: add the new visual summary to the previous
            # stack's query vector and feed it to the next stack.
            ux = us[-1] + weight_sums[-1]
            us.append(ux)

        return us[-1]
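

if __name__ == '__main__':
    # Minimal smoke test for StackedAttention only (BiAttention/TriAttention
    # need the repo's BCNet/TCNet modules to run); run from the repo root so
    # the src imports resolve. All sizes here are illustrative, not the
    # values used in training.
    torch.manual_seed(0)
    B, num_objs = 4, 36
    img_feat = torch.randn(B, num_objs, 2048)
    img_feat[:, 30:, :] = 0          # simulate padded (empty) regions
    ques_feat = torch.randn(B, 1024)
    san = StackedAttention(num_stacks=2, img_feat_size=2048, ques_feat_size=1024,
                           att_size=512, output_size=1024, drop_ratio=0.2)
    u = san(img_feat, ques_feat)
    print(u.shape)                   # expected: torch.Size([4, 512])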