|
- import torch.nn.functional as F
- import math
- import torch
- import torch.nn as nn
- import torchtext
- import numpy as np
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- import logging
-
-
- logger = logging.getLogger(__name__)
-
-
class Grouping(nn.Module):
    """Partition image region features into ``n_group`` region subsets.

    Raw 2054-d region features are projected to a 1024-d space, refined by a
    stack of four Rs_GCN relation-reasoning layers, and an attention matrix is
    built from the refined (L2-normalized) features.  For each image,
    ``n_group`` anchor regions are picked at random and each anchor keeps its
    ``n_region`` most attended regions, yielding one group of raw features per
    anchor.
    """

    def __init__(self):
        super(Grouping, self).__init__()
        # project raw 2054-d region features into the 1024-d embedding space
        self.fc = nn.Linear(2054, 1024)
        self.init_weights()
        embed_size = 1024
        # GSR (kept for checkpoint compatibility; not used in forward)
        self.img_rnn = nn.GRU(embed_size, embed_size, 1, batch_first=True)

        # GCN reasoning
        self.Rs_GCN_1 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size)
        self.Rs_GCN_2 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size)
        self.Rs_GCN_3 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size)
        self.Rs_GCN_4 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size)

        # kept for checkpoint compatibility; not used in forward
        self.bn = nn.BatchNorm1d(embed_size)

    def init_weights(self):
        """Xavier initialization for the fully connected layer."""
        r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
                                  self.fc.out_features)
        self.fc.weight.data.uniform_(-r, r)
        self.fc.bias.data.fill_(0)

    def forward(self, images, n_group=4, n_region=30):
        """Group regions by attention.

        :param images: (B, N, 2054) raw region features
        :param n_group: number of groups (random anchors) per image
        :param n_region: number of regions kept per group
        :return: list of ``n_group`` tensors, each (B, n_region, 2054)
        """
        fc_img_emd = self.fc(images)

        # GCN reasoning: (B, N, D) -> (B, D, N) for the Conv1d-based layers
        gcn_emb = fc_img_emd.permute(0, 2, 1)
        gcn_emb = self.Rs_GCN_1(gcn_emb)
        gcn_emb = self.Rs_GCN_2(gcn_emb)
        gcn_emb = self.Rs_GCN_3(gcn_emb)
        gcn_emb = self.Rs_GCN_4(gcn_emb)
        # back to (B, N, D)
        gcn_emb = gcn_emb.permute(0, 2, 1)

        features = l2norm(gcn_emb, dim=-1)

        # region-to-region attention; a tiny index-dependent noise term breaks
        # ties so that top-k selection below is unambiguous.
        # (devices follow `images` instead of hard-coded .cuda(), so the module
        # also runs on CPU; on GPU inputs the tensors land on CUDA as before)
        A = (features @ features.transpose(-1, -2)) * 32
        tie_break = (1e-8 * torch.randn(A.shape, device=images.device) *
                     torch.arange(0, A.numel(), device=images.device).reshape(A.shape))
        A = A.softmax(dim=-1) + tie_break

        b, n = images.shape[0], images.shape[1]

        # pick n_group random anchor regions per image
        seed = torch.randn(b, n, device=images.device)
        _, sort_ind = seed.sort(-1, True)
        mask = torch.zeros_like(sort_ind)
        mask[:, :n_group] = 1
        _, inv_ind = sort_ind.sort()
        mask = mask.gather(-1, inv_ind).bool()
        # attention rows of the selected anchors: (B, n_group, N)
        att = A.masked_select(mask.unsqueeze(-1)).reshape(b, n_group, -1)

        # keep each anchor's n_region most attended regions
        _, sort_ind = att.sort(-1, True)
        att_mask = torch.zeros_like(sort_ind)
        att_mask[:, :, :n_region] = 1
        _, inv_ind = sort_ind.sort()
        att_mask_ = att_mask.gather(-1, inv_ind).bool()

        grouped_images = []
        for i in range(n_group):
            # Bug fix: the original indexed att_mask_[:, 0] on every iteration,
            # so all n_group groups were identical; select group i's mask.
            grouped_images.append(
                images.masked_select(att_mask_[:, i].unsqueeze(-1)).reshape(b, -1, 2054))

        return grouped_images
-
-
class Rs_GCN(nn.Module):
    """Relation-reasoning GCN unit over region features (non-local-block style).

    Operates on (B, D, N) inputs.  theta/phi form a pairwise affinity matrix,
    g projects values, and the output projection W is zero-initialized so the
    unit starts as an identity mapping through the residual connection.
    """

    def __init__(self, in_channels, inter_channels, bn_layer=True):
        """
        :param in_channels: channel dimension D of the input
        :param inter_channels: bottleneck dimension (None -> in_channels // 2)
        :param bn_layer: wrap the output projection with BatchNorm1d
        """
        super(Rs_GCN, self).__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        conv_nd = nn.Conv1d
        bn = nn.BatchNorm1d

        # value projection
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Bug fix: nn.init.constant is the deprecated alias (removed in
            # recent PyTorch); use the in-place nn.init.constant_.
            # Zero init makes the block an identity at the start of training.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        # query / key projections for the affinity matrix
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, v):
        """
        :param v: (B, D, N) region features
        :return: (B, D, N) refined features (residual output)
        """
        batch_size = v.size(0)

        g_v = self.g(v).view(batch_size, self.inter_channels, -1)
        g_v = g_v.permute(0, 2, 1)

        theta_v = self.theta(v).view(batch_size, self.inter_channels, -1)
        theta_v = theta_v.permute(0, 2, 1)
        phi_v = self.phi(v).view(batch_size, self.inter_channels, -1)
        # (B, N, N) pairwise affinities, normalized by N
        R = torch.matmul(theta_v, phi_v)
        N = R.size(-1)
        R_div_C = R / N

        y = torch.matmul(R_div_C, g_v)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *v.size()[2:])
        W_y = self.W(y)
        # residual connection
        v_star = W_y + v

        return v_star
-
-
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Applies ``num_layers`` Linear layers with BatchNorm + ReLU between them
    (no activation after the last layer), operating per token of a
    (B, N, D) input.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.output_dim = output_dim
        self.num_layers = num_layers
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))
        self.bns = nn.ModuleList(nn.BatchNorm1d(d) for d in dims[1:])

    def forward(self, x):
        batch, n_tokens, dim = x.size()
        # flatten tokens so BatchNorm1d sees a (B*N, C) batch
        x = x.reshape(batch * n_tokens, dim)
        last = self.num_layers - 1
        for idx, (norm, fc) in enumerate(zip(self.bns, self.layers)):
            x = fc(x) if idx == last else F.relu(norm(fc(x)))
        return x.view(batch, n_tokens, self.output_dim)
-
-
def positional_encoding_1d(d_model, length):
    """Build a sinusoidal positional-encoding table.

    :param d_model: dimension of the model (must be even)
    :param length: number of positions
    :return: (length, d_model) tensor; even columns hold sines, odd cosines
    :raises ValueError: if ``d_model`` is odd
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    # frequency term for each even dimension index
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                      (-math.log(10000.0) / d_model))
    positions = torch.arange(0, length).unsqueeze(1).float()
    angles = positions * freqs
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe
-
-
class GPO(nn.Module):
    """Generalized Pooling Operator.

    Pools a variable-length set of features into one vector: a BiGRU over
    positional encodings predicts a per-rank weight, the features are sorted
    descending per dimension, and the weighted sum over ranks is returned.
    """

    def __init__(self, d_pe, d_hidden):
        super(GPO, self).__init__()
        self.d_pe = d_pe
        self.d_hidden = d_hidden

        # cache of positional-encoding tables keyed by sequence length
        self.pe_database = {}
        self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(self.d_hidden, 1, bias=False)

    def compute_pool_weights(self, lengths, features):
        """Return (weights, mask): rank weights (B, L, 1) and validity mask."""
        longest = int(lengths.max())
        pe_table = self.get_pe(longest)
        pe_batch = pe_table.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
        # mask[b, t] is True for positions within sample b's length
        positions = torch.arange(longest).expand(lengths.size(0), longest).to(lengths.device)
        mask = (positions < lengths.long().unsqueeze(1)).unsqueeze(-1)
        pe_batch = pe_batch.masked_fill(mask == 0, 0)

        self.gru.flatten_parameters()
        packed = pack_padded_sequence(pe_batch, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        rnn_out, _ = self.gru(packed)
        unpacked, _ = pad_packed_sequence(rnn_out, batch_first=True)
        # average the two directions of the BiGRU
        half = unpacked.size(2) // 2
        fused = (unpacked[:, :, :half] + unpacked[:, :, half:]) / 2
        scores = self.linear(fused)
        # padded positions get a large negative score before the softmax
        scores[torch.where(mask == 0)] = -10000

        weights = torch.softmax(scores / 0.1, 1)
        return weights, mask

    def forward(self, features, lengths):
        """
        :param features: features with shape B x K x D
        :param lengths: B x 1, specify the length of each data sample.
        :return: pooled feature with shape B x D
        """
        weights, mask = self.compute_pool_weights(lengths, features)

        features = features[:, :int(lengths.max()), :]
        # sort each dimension descending; padding pushed to the bottom ...
        ranked = features.masked_fill(mask == 0, -10000)
        ranked, _ = ranked.sort(dim=1, descending=True)
        # ... then zeroed so it cannot contribute to the weighted sum
        ranked = ranked.masked_fill(mask == 0, 0)

        pooled = (ranked * weights).sum(1)
        return pooled, weights

    def get_pe(self, length):
        """Return the positional encoding of the given length (memoized)."""
        length = int(length)
        if length not in self.pe_database:
            self.pe_database[length] = positional_encoding_1d(self.d_pe, length)
        return self.pe_database[length]
-
-
def l1norm(X, dim, eps=1e-8):
    """L1-normalize columns of X along ``dim`` (eps avoids division by zero)."""
    denom = X.abs().sum(dim=dim, keepdim=True) + eps
    return X / denom
-
-
def l2norm(X, dim, eps=1e-8):
    """L2-normalize columns of X along ``dim`` (eps avoids division by zero)."""
    denom = X.pow(2).sum(dim=dim, keepdim=True).sqrt() + eps
    return X / denom
-
-
def maxk_pool1d_var(x, dim, k, lengths):
    """Mean of the top-k values along ``dim``, restricted to each sample's
    valid length.

    Bug fix: the original reassigned ``k = min(k, length)`` inside the loop,
    so once a short sample appeared, every later (longer) sample was pooled
    with the smaller k instead of the requested one.  The cap is now computed
    per sample without clobbering ``k``.

    :param x: (B, L, D) batch of padded sequences
    :param dim: dimension to pool over, counted including the batch dim
                (callers pass dim=1; per-sample pooling uses dim - 1)
    :param k: number of top values to average (capped by each length)
    :param lengths: (B,) valid length of each sample
    :return: (B, D) pooled features
    """
    results = []
    for idx, length in enumerate(int(v) for v in lengths.cpu().numpy()):
        k_i = min(k, length)  # per-sample cap; do not overwrite k
        valid = x[idx, :length, :]
        # topk values equal maxk's gather-by-topk-indices, inlined here
        top_vals = valid.topk(k_i, dim=dim - 1)[0]
        results.append(top_vals.mean(dim - 1))
    return torch.stack(results, dim=0)
-
-
def maxk_pool1d(x, dim, k):
    """Average of the k largest entries of ``x`` along ``dim``."""
    return maxk(x, dim, k).mean(dim)
-
-
def maxk(x, dim, k):
    """Top-k values of ``x`` along ``dim`` (descending, via topk indices)."""
    _, idx = x.topk(k, dim=dim)
    return x.gather(dim, idx)
-
-
def get_text_encoder(vocab_size, embed_size, word_dim, num_layers, opt, use_bi_gru=True, no_txtnorm=False):
    """Factory returning the caption encoder."""
    return EncoderText(
        vocab_size, embed_size, word_dim, num_layers, opt,
        use_bi_gru=use_bi_gru, no_txtnorm=no_txtnorm)
-
-
def get_sim_encoder(opt):
    """Factory returning the similarity encoder."""
    return EncoderSimilarityGraph(opt)
-
-
def get_image_encoder(img_dim, embed_size, precomp_enc_type='basic',
                      backbone_source=None, backbone_path=None, no_imgnorm=False):
    """A wrapper to image encoders. Chooses between an different encoders
    that uses precomputed image features.
    """
    if precomp_enc_type == 'basic':
        return EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm)
    if precomp_enc_type == 'backbone':
        backbone_cnn = ResnetFeatureExtractor(backbone_source, backbone_path, fixed_blocks=2)
        return EncoderImageFull(backbone_cnn, img_dim, embed_size, precomp_enc_type, no_imgnorm)
    raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type))
-
-
class EncoderImageAggr(nn.Module):
    """Aggregate region features into one image embedding via GPO pooling."""

    def __init__(self, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
        super(EncoderImageAggr, self).__init__()
        self.embed_size = embed_size
        self.no_imgnorm = no_imgnorm
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(img_dim, embed_size)

        self.precomp_enc_type = precomp_enc_type
        if precomp_enc_type == 'basic':
            # extra MLP transformation only for pre-extracted region features
            self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2)
        # Bug fix: gpool and init_weights() were previously inside the 'basic'
        # branch, but forward() uses self.gpool unconditionally — so
        # precomp_enc_type='backbone' (used by EncoderImageFull) crashed with
        # AttributeError.  Create the pooler and initialize fc in all modes.
        self.gpool = GPO(32, 32)
        self.init_weights()

    def init_weights(self):
        """Xavier initialization for the fully connected layer."""
        r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
                                  self.fc.out_features)
        self.fc.weight.data.uniform_(-r, r)
        self.fc.bias.data.fill_(0)

    def forward(self, images, image_lengths):
        """Extract image feature vectors.

        :param images: (B, N, img_dim) region features
        :param image_lengths: (B,) number of valid regions per image
        :return: (B, embed_size) pooled image embedding (L2-normalized unless
                 ``no_imgnorm``)
        """
        images = self.dropout(images)
        reg_emb = self.fc(images)
        if self.precomp_enc_type == 'basic':
            # When using pre-extracted region features, add an extra MLP for
            # the embedding transformation (residual on top of fc)
            reg_emb = self.mlp(images) + reg_emb

        img_emb, _ = self.gpool(reg_emb, image_lengths)

        if not self.no_imgnorm:
            img_emb = l2norm(img_emb, dim=-1)

        return img_emb
-
-
class EncoderImageFull(nn.Module):
    """End-to-end image encoder: CNN backbone followed by region aggregation.

    Wraps a backbone feature extractor and an EncoderImageAggr; during
    training it randomly drops spatial grids as a size augmentation.
    """

    def __init__(self, backbone_cnn, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
        super(EncoderImageFull, self).__init__()
        self.backbone = backbone_cnn
        self.image_encoder = EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm)
        self.backbone_freezed = False

    def forward(self, images):
        """Extract image feature vectors.

        Returns the aggregated embedding produced by ``self.image_encoder``.
        NOTE(review): assumes the backbone returns (B, N, C) grid features —
        confirm against the backbone implementation.
        """
        base_features = self.backbone(images)

        if self.training:
            # Size Augmentation during training, randomly drop grids
            base_length = base_features.size(1)
            features = []
            feat_lengths = []
            # per-grid and per-image random draws; rand_list_2 gates whether a
            # sample is augmented at all, rand_list_1 picks which grids survive
            rand_list_1 = np.random.rand(base_features.size(0), base_features.size(1))
            rand_list_2 = np.random.rand(base_features.size(0))
            for i in range(base_features.size(0)):
                if rand_list_2[i] > 0.2:
                    # keep grids whose draw exceeds a sample-dependent threshold
                    feat_i = base_features[i][np.where(rand_list_1[i] > 0.20 * rand_list_2[i])]
                    len_i = len(feat_i)
                    # zero-pad back to the full grid count so samples stack
                    pads_i = torch.zeros(base_length - len_i, base_features.size(-1)).to(base_features.device)
                    feat_i = torch.cat([feat_i, pads_i], dim=0)
                else:
                    # 20% of samples keep all grids
                    feat_i = base_features[i]
                    len_i = base_length
                feat_lengths.append(len_i)
                features.append(feat_i)
            base_features = torch.stack(features, dim=0)
            # trim shared padding down to the longest surviving length
            base_features = base_features[:, :max(feat_lengths), :]
            feat_lengths = torch.tensor(feat_lengths).to(base_features.device)
        else:
            # evaluation: every grid is valid
            feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device)
            feat_lengths[:] = base_features.size(1)

        features = self.image_encoder(base_features, feat_lengths)

        return features

    def freeze_backbone(self):
        """Stop gradient flow through the backbone."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print('Backbone freezed.')

    def unfreeze_backbone(self, fixed_blocks):
        """Re-enable backbone training except for ``fixed_blocks`` base blocks."""
        for param in self.backbone.parameters():  # open up all params first, then adjust the base parameters
            param.requires_grad = True
        self.backbone.set_fixed_blocks(fixed_blocks)
        self.backbone.unfreeze_base()
        print('Backbone unfreezed, fixed blocks {}'.format(self.backbone.get_fixed_blocks()))
-
-
# Language Model with BiGRU
class EncoderText(nn.Module):
    """Caption encoder: GloVe-initialized embedding -> BiGRU -> MLP -> GPO pool."""

    def __init__(self, vocab_size, embed_size, word_dim, num_layers, opt, use_bi_gru=True, no_txtnorm=False):
        super(EncoderText, self).__init__()
        self.opt = opt
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm
        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)
        self.dropout = nn.Dropout(0.2)
        self.gpool = GPO(32, 32)
        self.ln = nn.LayerNorm(embed_size)
        # caption embedding
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)

        # residual MLP on top of the GRU outputs (same dim in and out)
        embed_size1 = embed_size
        self.fc = nn.Linear(embed_size, embed_size1)
        self.mlp = MLP(embed_size, embed_size1 // 2, embed_size1, 2)
        self.init_weights()

    def init_weights(self):
        """Initialize the embedding table from pretrained GloVe vectors.

        NOTE(review): the GloVe cache path is hard-coded; confirm it exists in
        the deployment environment or make it configurable.
        """
        # self.embed.weight.data.uniform_(-0.1, 0.1)
        wemb = torchtext.vocab.GloVe(cache="/tmp/dataset/vector_cache/vector_cache")
        word2idx = self.opt.word2idx
        # quick-and-dirty trick to improve word-hit rate
        missing_words = []
        for word, idx in word2idx.items():
            if word not in wemb.stoi:
                # normalize punctuation / slashes to improve hit rate
                # word = word.replace('-', '').replace('.', '').replace("'", '')
                word = word.replace('.', '').replace("'", '')
                if '/' in word:
                    word = word.split('/')[0]
            if word in wemb.stoi:
                self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
            else:
                missing_words.append(word)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
        print(missing_words)

    def forward(self, x, lengths):
        """Handles variable size captions

        :param x: (B, L) padded word-id tensor
        :param lengths: (B,) caption lengths
        :return: (B, embed_size) pooled, layer-normalized caption embedding
        """
        # Embed word ids to vectors
        x_emb = self.embed(x)
        x_emb = self.dropout(x_emb)
        self.rnn.flatten_parameters()
        # NOTE(review): no enforce_sorted=False here — assumes captions arrive
        # sorted by decreasing length; confirm against the data loader.
        packed = pack_padded_sequence(x_emb, lengths.cpu(), batch_first=True)

        # Forward propagate RNN
        out, _ = self.rnn(packed)

        # Reshape *final* output to (batch_size, hidden_size)
        padded = pad_packed_sequence(out, batch_first=True)
        cap_emb, cap_len = padded
        # average the two BiGRU directions
        cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2

        # residual MLP transform, then GPO pooling over the sequence
        features = self.fc(cap_emb)
        features = self.mlp(cap_emb) + features

        pooled_features, pool_weights = self.gpool(features, cap_len.to(features.device))

        pooled_features = self.ln(pooled_features)

        return pooled_features
-
-
def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    log_u = torch.zeros_like(log_mu)
    log_v = torch.zeros_like(log_nu)
    # alternately project onto the row (mu) and column (nu) marginals
    for _ in range(iters):
        log_u = log_mu - torch.logsumexp(Z + log_v.unsqueeze(1), dim=2)
        log_v = log_nu - torch.logsumexp(Z + log_u.unsqueeze(2), dim=1)
    return Z + log_u.unsqueeze(2) + log_v.unsqueeze(1)
-
-
def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    ms = (m * one).to(scores)
    ns = (n * one).to(scores)

    # augment the score matrix with dustbin row/column scored by alpha
    row_bin = alpha.expand(b, m, 1)
    col_bin = alpha.expand(b, 1, n)
    corner = alpha.expand(b, 1, 1)
    top = torch.cat([scores, row_bin], -1)
    bottom = torch.cat([col_bin, corner], -1)
    couplings = torch.cat([top, bottom], 1)

    # uniform marginals; dustbins absorb the unmatched mass
    norm = - (ms + ns).log()
    log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])[None].expand(b, -1)
    log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])[None].expand(b, -1)

    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    # multiply probabilities by M+N
    return Z - norm
-
-
class EncoderSimilarityGraph(nn.Module):
    """Block-wise optimal-transport similarity between image and caption
    embeddings.

    Each embedding is chunked into blocks per configured block size, blocks
    are L2-normalized, and an optimal-transport matching over block-pair
    logits yields a pooled similarity score per (image, caption) pair.
    """

    def __init__(self, opt):
        super(EncoderSimilarityGraph, self).__init__()
        self.opt = opt
        self.block_dim = opt.block_dim
        # NOTE(review): hard-coded .cuda() makes this module GPU-only; alpha is
        # not read in forward() here — confirm whether it is used elsewhere.
        self.alpha = torch.Tensor(opt.alpha).cuda()
        # learnable dustbin score for the OT matching
        self.register_parameter('bin_score', torch.nn.Parameter(torch.tensor(0.)))

    def forward(self, img_emb, cap_emb):
        cap_emb = l2norm(cap_emb, -1)
        n_cap = cap_emb.size(0)
        n_img = img_emb.size(0)
        sims = []
        for block_dim in self.block_dim:
            img_chunks = img_emb.size(1) // block_dim
            cap_chunks = cap_emb.size(1) // block_dim
            # (bs, n_chunks, block_dim) after chunk + stack
            img_blocks = torch.stack(torch.chunk(img_emb, img_chunks, -1), dim=1)
            cap_blocks = torch.stack(torch.chunk(cap_emb, cap_chunks, -1), dim=1)

            img_blocks = l2norm(img_blocks, -1)
            cap_blocks = l2norm(cap_blocks, -1)

            # (n_img, n_cap, img_chunks, cap_chunks) block-pair cosine logits
            logits = torch.einsum("avc,btc->abvt", [img_blocks, cap_blocks])

            transported = log_optimal_transport(
                logits.reshape(-1, img_chunks, cap_chunks), self.bin_score, 20)
            transported = transported[:, :-1, :-1].reshape(
                n_img, n_cap, img_chunks, cap_chunks)
            # best image block per caption block, summed over caption blocks
            t2i = transported.max(dim=-2)[0]
            sims.append(t2i.sum(dim=-1))

        return torch.stack(sims, -1).sum(-1)
|