- import torch.nn.functional as F
- import math
- import torch
- import torch.nn as nn
- import torchtext
- import numpy as np
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- import logging
- import random
- from transformers import BertModel
- from timm.models.layers import DropPath  # stochastic depth; CrossAttnBlock below expects DropPath in scope
- # NOTE: get_image_encoder(precomp_enc_type='backbone') also references ResnetFeatureExtractor,
- # which must be imported from the repo's backbone module (not defined in this file).
-
- logger = logging.getLogger(__name__)
-
-
- class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
- class Attention(nn.Module):
- def __init__(self,
- dim,
- num_heads,
- out_dim=None,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.):
- super().__init__()
- if out_dim is None:
- out_dim = dim
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
- self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
- self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
-
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, out_dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, query, key=None, *, value=None):
- B, N, C = query.shape
- if key is None:
- key = query
- if value is None:
- value = key
- S = key.size(1)
- # project and reshape to [B, nh, L, C//nh]
- q = self.q_proj(query)
- q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
- k = self.k_proj(key)
- k = k.reshape(B, S, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
- v = self.v_proj(value)
- v = v.reshape(B, v.shape[1], self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
- # [B, nh, N, S]
- attn = (q @ k.transpose(-2, -1)) * self.scale
-
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
- out = self.proj(out)
- out = self.proj_drop(out)
- return out
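-
- # Minimal usage sketch (illustrative only; the tensors and sizes below are assumptions):
- #   attn = Attention(dim=768, num_heads=6)
- #   regions = torch.randn(2, 36, 768)
- #   tokens = torch.randn(2, 4, 768)
- #   out_self = attn(regions)               # self-attention: key and value default to query
- #   out_cross = attn(tokens, key=regions)  # cross-attention: value defaults to key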
-
- class CrossAttnBlock(nn.Module):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- post_norm=False):
- super().__init__()
- self.norm_q = norm_layer(dim)
- self.norm_k = norm_layer(dim)
-
- self.attn = Attention(
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- def forward(self, query, key, *, mask=None):
- x = query
- x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key)))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x
-
-
- class Grouping(nn.Module):
- def __init__(self, n_group, img_dim):
- super(Grouping, self).__init__()
-
- self.fc = nn.Linear(img_dim, 768)
- self.init_weights()
- embed_size = 768
- # GSR
- self.group = CrossAttnBlock(
- dim=embed_size,
- num_heads=6,
- mlp_ratio=4,
- qkv_bias=True,
- )
-
- self.group_token = nn.Parameter(torch.randn(n_group, embed_size))
- self.group_proj = nn.Linear(embed_size, img_dim)
-
- def init_weights(self):
- """Xavier initialization for the fully connected layer
- """
- r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
- self.fc.out_features)
- self.fc.weight.data.uniform_(-r, r)
- self.fc.bias.data.fill_(0)
-
- def forward(self, images, n_group=4, n_region=30):
- fc_img_emd = self.fc(images)
-
- # Group tokens cross-attend to the projected region features
- # group_tokens: B x n_group x 768
- group_tokens = self.group_token.expand(images.shape[0], -1, -1)
-
- grouped = self.group_proj(self.group(group_tokens, fc_img_emd))
-
- grouped = l2norm(grouped, dim=-1)
- features = l2norm(images, dim=-1)
-
- A = (grouped @ features.transpose(-1, -2))
- att = A.softmax(dim=-1) + (1e-8)
-
- n = images.shape[1]
- b = images.shape[0]
-
- # Rank regions by their affinity to each group token, then keep the top-j regions per
- # group: j is randomized in [0.9*n, n] during training and fixed to n_region at eval time.
- att_sort, sort_ind = att.sort(-1, True)
- att_mask = torch.zeros_like(sort_ind)
- if self.training:
- for i in range(n_group):
- j = random.randint(int(0.9 * n), n)
- att_mask[:, i, :j] = 1
- else:
- att_mask[:, :, :n_region] = 1
-
- # Undo the sort so the mask lines up with the original region order
- _, sort_ind2 = sort_ind.sort()
- att_mask_ = att_mask.gather(-1, sort_ind2).bool()
- grouped_images = []
- grouped_lengths = []
- for i in range(n_group):
- base_features = images.masked_select(att_mask_[:,i].unsqueeze(-1)).reshape(b, -1, images.shape[-1])
-
-
- feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device)
- feat_lengths[:] = base_features.size(1)
-
- grouped_images.append(base_features)
- grouped_lengths.append(feat_lengths)
-
- return grouped_images, grouped_lengths
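-
- # Note: forward returns plain Python lists of length n_group. grouped_images[i] holds the
- # regions selected for group i with shape (B, n_kept_i, img_dim), and grouped_lengths[i] the
- # matching per-sample lengths, in the format expected by GPO-style pooling downstream.
- # Rough usage sketch (assumed sizes, e.g. 36 regions of dimension 2048):
- #   grouping = Grouping(n_group=4, img_dim=2048)
- #   groups, lengths = grouping(torch.randn(8, 36, 2048))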
-
-
-
- class Rs_GCN(nn.Module):
-
- def __init__(self, in_channels, inter_channels, bn_layer=True):
- super(Rs_GCN, self).__init__()
-
- self.in_channels = in_channels
- self.inter_channels = inter_channels
-
- if self.inter_channels is None:
- self.inter_channels = in_channels // 2
- if self.inter_channels == 0:
- self.inter_channels = 1
-
- conv_nd = nn.Conv1d
- max_pool = nn.MaxPool1d
- bn = nn.BatchNorm1d
-
- self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
- kernel_size=1, stride=1, padding=0)
-
- if bn_layer:
- self.W = nn.Sequential(
- conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
- kernel_size=1, stride=1, padding=0),
- bn(self.in_channels)
- )
- nn.init.constant_(self.W[1].weight, 0)
- nn.init.constant_(self.W[1].bias, 0)
- else:
- self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
- kernel_size=1, stride=1, padding=0)
- nn.init.constant_(self.W.weight, 0)
- nn.init.constant_(self.W.bias, 0)
-
-
- self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
- kernel_size=1, stride=1, padding=0)
- self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
- kernel_size=1, stride=1, padding=0)
-
- def forward(self, v):
- '''
- :param v: (B, D, N)
- :return:
- '''
- batch_size = v.size(0)
-
- g_v = self.g(v).view(batch_size, self.inter_channels, -1)
- g_v = g_v.permute(0, 2, 1)
-
- theta_v = self.theta(v).view(batch_size, self.inter_channels, -1)
- theta_v = theta_v.permute(0, 2, 1)
- phi_v = self.phi(v).view(batch_size, self.inter_channels, -1)
- R = torch.matmul(theta_v, phi_v)
- N = R.size(-1)
- R_div_C = R / N
-
- y = torch.matmul(R_div_C, g_v)
- y = y.permute(0, 2, 1).contiguous()
- y = y.view(batch_size, self.inter_channels, *v.size()[2:])
- W_y = self.W(y)
- v_star = W_y + v
-
- return v_star
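-
- # Rs_GCN is a non-local-style relation block over region features: theta/phi build an
- # N x N affinity matrix across the columns of v (B, D, N), g projects the values, and W maps
- # the aggregated messages back to D channels with a residual connection. With bn_layer=True
- # the BatchNorm affine weights start at zero, so the block initially acts as the identity.
- # Shape sketch (assumed sizes): Rs_GCN(2048, 1024)(torch.randn(8, 2048, 36)) -> (8, 2048, 36).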
-
-
- class MLP(nn.Module):
- """ Very simple multi-layer perceptron (also called FFN)"""
-
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
- super().__init__()
- self.output_dim = output_dim
- self.num_layers = num_layers
- h = [hidden_dim] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
- self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in h + [output_dim])
-
- def forward(self, x):
- B, N, D = x.size()
- x = x.reshape(B*N, D)
- for i, (bn, layer) in enumerate(zip(self.bns, self.layers)):
- x = F.relu(bn(layer(x))) if i < self.num_layers - 1 else layer(x)
- x = x.view(B, N, self.output_dim)
- return x
-
-
- def positional_encoding_1d(d_model, length):
- """
- :param d_model: dimension of the model
- :param length: length of positions
- :return: length*d_model position matrix
- """
- if d_model % 2 != 0:
- raise ValueError("Cannot use sin/cos positional encoding with "
- "odd dim (got dim={:d})".format(d_model))
- pe = torch.zeros(length, d_model)
- position = torch.arange(0, length).unsqueeze(1)
- div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
- -(math.log(10000.0) / d_model)))
- pe[:, 0::2] = torch.sin(position.float() * div_term)
- pe[:, 1::2] = torch.cos(position.float() * div_term)
-
- return pe
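-
- # Example: positional_encoding_1d(32, 10) returns a (10, 32) matrix whose even columns are
- # sin(pos / 10000^(2i/d_model)) and whose odd columns are the matching cos terms, i.e. the
- # standard Transformer sinusoidal positional encoding.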
-
-
- class GPO(nn.Module):
- def __init__(self, d_pe, d_hidden):
- super(GPO, self).__init__()
- self.d_pe = d_pe
- self.d_hidden = d_hidden
-
- self.pe_database = {}
- self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
- self.linear = nn.Linear(self.d_hidden, 1, bias=False)
-
- def compute_pool_weights(self, lengths, features):
- max_len = int(lengths.max())
- pe_max_len = self.get_pe(max_len)
- pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
- mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
- mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
- pes = pes.masked_fill(mask == 0, 0)
-
- self.gru.flatten_parameters()
- packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
- out, _ = self.gru(packed)
- padded = pad_packed_sequence(out, batch_first=True)
- out_emb, out_len = padded
- out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2
- scores = self.linear(out_emb)
- scores[torch.where(mask == 0)] = -10000
-
- weights = torch.softmax(scores / 0.1, 1)
- return weights, mask
-
- def forward(self, features, lengths):
- """
- :param features: features with shape B x K x D
- :param lengths: B x 1, specify the length of each data sample.
- :return: pooled feature with shape B x D
- """
- pool_weights, mask = self.compute_pool_weights(lengths, features)
-
- features = features[:, :int(lengths.max()), :]
- sorted_features = features.masked_fill(mask == 0, -10000)
- sorted_features, sort = sorted_features.sort(dim=1, descending=True)
-
- sorted_features = sorted_features.masked_fill(mask == 0, 0)
-
- pooled_features = (sorted_features * pool_weights).sum(1)
- return pooled_features, pool_weights
-
- def get_pe(self, length):
- """
-
- :param length: the length of the sequence
- :return: the positional encoding of the given length
- """
- length = int(length)
- if length in self.pe_database:
- return self.pe_database[length]
- else:
- pe = positional_encoding_1d(self.d_pe, length)
- self.pe_database[length] = pe
- return pe
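-
- # GPO (the Generalized Pooling Operator from VSE-infinity) learns how to pool a variable-length
- # set of features: a BiGRU over positional encodings yields one score per (sorted) position,
- # the scores are softmax-normalized with temperature 0.1, and the features -- sorted in
- # descending order along the sequence dimension, independently per channel -- are combined as
- # a weighted sum. Rough sketch (assumed sizes):
- #   gpo = GPO(d_pe=32, d_hidden=32)
- #   pooled, weights = gpo(torch.randn(8, 36, 1024), torch.full((8,), 36.))
- #   # pooled: (8, 1024), weights: (8, 36, 1)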
-
-
- def l1norm(X, dim, eps=1e-8):
- """L1-normalize columns of X
- """
- norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
- X = torch.div(X, norm)
- return X
-
-
- def l2norm(X, dim, eps=1e-8):
- """L2-normalize columns of X
- """
- norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
- X = torch.div(X, norm)
- return X
-
-
- def maxk_pool1d_var(x, dim, k, lengths):
- results = list()
- lengths = list(lengths.cpu().numpy())
- lengths = [int(x) for x in lengths]
- for idx, length in enumerate(lengths):
- k_i = min(k, length)  # clamp per sample so a short sample does not shrink k for later ones
- max_k_i = maxk(x[idx, :length, :], dim - 1, k_i).mean(dim - 1)
- results.append(max_k_i)
- results = torch.stack(results, dim=0)
- return results
-
-
- def maxk_pool1d(x, dim, k):
- max_k = maxk(x, dim, k)
- return max_k.mean(dim)
-
-
- def maxk(x, dim, k):
- index = x.topk(k, dim=dim)[1]
- return x.gather(dim, index)
-
-
- def get_text_encoder(vocab_size, embed_size, word_dim, num_layers, opt, use_bi_gru=True, no_txtnorm=False):
- return EncoderText(vocab_size, embed_size, word_dim, num_layers, opt, use_bi_gru=use_bi_gru,
- no_txtnorm=no_txtnorm)
-
-
- def get_sim_encoder(opt):
- return EncoderSimilarityGraph(opt)
-
-
- def get_image_encoder(img_dim, embed_size, precomp_enc_type='basic',
- backbone_source=None, backbone_path=None, no_imgnorm=False):
- """A wrapper to image encoders. Chooses between an different encoders
- that uses precomputed image features.
- """
- if precomp_enc_type == 'basic':
- img_enc = EncoderImageAggr(
- img_dim, embed_size, precomp_enc_type, no_imgnorm)
- elif precomp_enc_type == 'backbone':
- backbone_cnn = ResnetFeatureExtractor(backbone_source, backbone_path, fixed_blocks=2)
- img_enc = EncoderImageFull(backbone_cnn, img_dim, embed_size, precomp_enc_type, no_imgnorm)
- else:
- raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type))
-
- return img_enc
-
-
- class EncoderImageAggr(nn.Module):
- def __init__(self, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
- super(EncoderImageAggr, self).__init__()
- # embed_size = embed_size * 2
- self.embed_size = embed_size
- self.no_imgnorm = no_imgnorm
- self.dropout = nn.Dropout(0.2)
- self.fc = nn.Linear(img_dim, embed_size)
-
- self.precomp_enc_type = precomp_enc_type
- if precomp_enc_type == 'basic':
- self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2)
- self.gpool = GPO(32, 32)
- self.init_weights()
-
- def init_weights(self):
- """Xavier initialization for the fully connected layer
- """
- r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
- self.fc.out_features)
- self.fc.weight.data.uniform_(-r, r)
- self.fc.bias.data.fill_(0)
-
-
-
- def forward(self, images, image_lengths):
- """Extract image feature vectors."""
- # if torch.cuda.is_available():
- # img_p1 = images[0].cuda()
- # img_p2 = images[1].cuda()
-
- # img_len1 = image_lengths[0].cuda()
- # img_len2 = image_lengths[1].cuda()
-
- images = self.dropout(images)
- # img_p2 = self.dropout(img_p2)
-
- reg_emb = self.fc(images)
- # reg_emb_p2 = self.fc2(img_p2)
- if self.precomp_enc_type == 'basic':
- # When using pre-extracted region features, add an extra MLP for embedding transformation
- reg_emb = self.mlp(images) + reg_emb
- # reg_emb_p2 = self.mlp2(img_p2) + reg_emb_p2
-
- img_emb, _ = self.gpool(reg_emb, image_lengths)
- # img_emb_2, _ = self.gpool2(reg_emb_p2, img_len2)
- # features = features.mean(dim=1)
-
- # img_emb = torch.cat([img_emb_1, img_emb_2], dim=1)
- if not self.no_imgnorm:
- img_emb = l2norm(img_emb, dim=-1)
- # reg_emb = l2norm(reg_emb, dim=-1)
-
- return img_emb
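-
- # Aggregation path: region features -> dropout -> fc (plus a residual two-layer MLP when
- # precomp_enc_type == 'basic') -> GPO pooling over the variable number of regions -> optional
- # L2 normalization. Rough sketch (assumed sizes):
- #   EncoderImageAggr(2048, 1024)(torch.randn(8, 36, 2048), torch.full((8,), 36.)) -> (8, 1024)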
-
-
- class EncoderImageFull(nn.Module):
- def __init__(self, backbone_cnn, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
- super(EncoderImageFull, self).__init__()
- self.backbone = backbone_cnn
- self.image_encoder = EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm)
- self.backbone_freezed = False
-
- def forward(self, images):
- """Extract image feature vectors."""
- base_features = self.backbone(images)
-
- if self.training:
- # Size Augmentation during training, randomly drop grids
- base_length = base_features.size(1)
- features = []
- feat_lengths = []
- rand_list_1 = np.random.rand(base_features.size(0), base_features.size(1))
- rand_list_2 = np.random.rand(base_features.size(0))
- for i in range(base_features.size(0)):
- if rand_list_2[i] > 0.2:
- feat_i = base_features[i][np.where(rand_list_1[i] > 0.20 * rand_list_2[i])]
- len_i = len(feat_i)
- pads_i = torch.zeros(base_length - len_i, base_features.size(-1)).to(base_features.device)
- feat_i = torch.cat([feat_i, pads_i], dim=0)
- else:
- feat_i = base_features[i]
- len_i = base_length
- feat_lengths.append(len_i)
- features.append(feat_i)
- base_features = torch.stack(features, dim=0)
- base_features = base_features[:, :max(feat_lengths), :]
- feat_lengths = torch.tensor(feat_lengths).to(base_features.device)
- else:
- feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device)
- feat_lengths[:] = base_features.size(1)
-
- features = self.image_encoder(base_features, feat_lengths)
-
- return features
-
- def freeze_backbone(self):
- for param in self.backbone.parameters():
- param.requires_grad = False
- print('Backbone frozen.')
-
- def unfreeze_backbone(self, fixed_blocks):
- for param in self.backbone.parameters(): # open up all params first, then adjust the base parameters
- param.requires_grad = True
- self.backbone.set_fixed_blocks(fixed_blocks)
- self.backbone.unfreeze_base()
- print('Backbone unfrozen, fixed blocks {}'.format(self.backbone.get_fixed_blocks()))
-
-
- # Language Model with BiGRU
- class EncoderText(nn.Module):
- def __init__(self, vocab_size, embed_size, word_dim, num_layers, opt, use_bi_gru=True, no_txtnorm=False):
- super(EncoderText, self).__init__()
- self.opt = opt
- self.embed_size = embed_size
- self.no_txtnorm = no_txtnorm
- # word embedding
- # self.embed = nn.Embedding(vocab_size, word_dim)
-
- self.bert = BertModel.from_pretrained('bert-base-uncased')
- # for p in self.bert.parameters():
- # p.requires_grad=False
- # self.embed = bert.em
-
- self.linear = nn.Linear(768, embed_size)
-
- self.dropout = nn.Dropout(0.2)
- self.gpool = GPO(32, 32)
- self.ln = nn.LayerNorm(embed_size)
- # caption embedding
-
- embed_size1 = embed_size
- self.fc = nn.Linear(embed_size, embed_size1)
- self.mlp = MLP(768, embed_size1 // 2, embed_size1, 2)
-
- def forward(self, x, lengths):
- """Handles variable size captions
- """
- # Embed word ids to vectors
- # x_emb = self.embed(x)
- # x_emb = self.dropout(x_emb)
- # self.rnn.flatten_parameters()
- # packed = pack_padded_sequence(x_emb, lengths.cpu(), batch_first=True)
- #
- # # Forward propagate RNN
- # out, _ = self.rnn(packed)
-
- bert_attention_mask = (x != 0).float()
- bert_emb = self.bert(x, bert_attention_mask)[0] # B x N x D
-
- bert_emb = self.dropout(bert_emb)
-
- cap_len = lengths
- cap_emb = self.linear(bert_emb)
-
- # features = self.fc(cap_emb)
- features = self.mlp(bert_emb) + cap_emb
-
- pooled_features, pool_weights = self.gpool(features, cap_len.to(features.device))
-
- pooled_features = self.ln(pooled_features)
-
- return pooled_features
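-
- # Rough usage sketch (names and sizes are assumptions): x is a (B, L) batch of BERT input ids
- # padded with 0, and lengths holds the true caption lengths used by GPO:
- #   enc = EncoderText(vocab_size=30522, embed_size=1024, word_dim=300, num_layers=1, opt=opt)
- #   cap_emb = enc(input_ids, cap_lens)   # (B, 1024)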
-
-
- def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
- """ Perform Sinkhorn Normalization in Log-space for stability"""
- u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
- for _ in range(iters):
- u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
- v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
- return Z + u.unsqueeze(2) + v.unsqueeze(1)
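-
- # Each iteration applies the standard log-domain Sinkhorn updates
- #   u_i = log mu_i - logsumexp_j (Z_ij + v_j),    v_j = log nu_j - logsumexp_i (Z_ij + u_i),
- # so that exp(Z + u[:, :, None] + v[:, None, :]) has (approximately) row marginals mu and
- # column marginals nu.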
-
-
- def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
- """ Perform Differentiable Optimal Transport in Log-space for stability"""
- b, m, n = scores.shape
- one = scores.new_tensor(1)
- ms, ns = (m * one).to(scores), (n * one).to(scores)
-
- bins0 = alpha.expand(b, m, 1)
- bins1 = alpha.expand(b, 1, n)
- alpha = alpha.expand(b, 1, 1)
-
- couplings = torch.cat([torch.cat([scores, bins0], -1),
- torch.cat([bins1, alpha], -1)], 1)
-
- norm = - (ms + ns).log()
- log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
- log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
- log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
-
- Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
- Z = Z - norm # multiply probabilities by M+N
- return Z
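-
- # As in SuperGlue, a learnable "dustbin" score alpha pads the m x n score matrix to
- # (m + 1) x (n + 1) so that unmatched rows/columns can dump their mass into the extra bin.
- # The marginals assign each real row/column mass 1/(m + n) and the bins the remaining mass;
- # the returned Z is the log transport plan shifted by log(m + n) ("multiply probabilities by M+N").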
-
-
- class EncoderSimilarityGraph(nn.Module):
- def __init__(self, opt):
- super(EncoderSimilarityGraph, self).__init__()
- self.opt = opt
- self.block_dim = opt.block_dim
- self.alpha = torch.Tensor(opt.alpha).cuda()
- bin_score = torch.nn.Parameter(torch.tensor(0.))
- self.register_parameter('bin_score', bin_score)
-
- def forward(self, img_emb, cap_emb):
- cap_emb = l2norm(cap_emb, -1)
- n_cap, cap_dim = cap_emb.size(0), cap_emb.size(1)
- n_img, img_dim = img_emb.size(0), img_emb.size(1)
- sims = []
- for i, block_dim in enumerate(self.block_dim):
- img_blk_num, cap_blk_num = img_emb.size(1) // block_dim, cap_emb.size(1) // block_dim
- img_emb_blocks = torch.chunk(img_emb, img_blk_num, -1)  # tuple of (n_img, block_dim) chunks
- cap_emb_blocks = torch.chunk(cap_emb, cap_blk_num, -1)  # tuple of (n_cap, block_dim) chunks
-
- img_emb_blocks = torch.stack(img_emb_blocks, dim=1)  # (n_img, img_blk_num, block_dim)
- cap_emb_blocks = torch.stack(cap_emb_blocks, dim=1)  # (n_cap, cap_blk_num, block_dim)
-
- img_emb_blocks = l2norm(img_emb_blocks, -1)
- cap_emb_blocks = l2norm(cap_emb_blocks, -1)
-
- logits = torch.einsum("avc,btc->abvt", [img_emb_blocks, cap_emb_blocks])  # (n_img, n_cap, img_blk_num, cap_blk_num)
-
- logits = log_optimal_transport(logits.reshape(-1, img_blk_num, cap_blk_num), self.bin_score, 20)
- logits = logits[:, :-1, :-1].reshape(n_img, n_cap, img_blk_num, cap_blk_num)
- t2i_logits = logits.max(dim=-2)[0]
- sims.append(t2i_logits.sum(dim=-1))
-
- sims = torch.stack(sims, -1).sum(-1)
-
- return sims
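-
- # Similarity computation in words: each embedding is cut into block_dim-sized chunks, every
- # image/caption chunk pair is scored by cosine similarity, the chunk-level score matrix is
- # rebalanced with log_optimal_transport (dropping the dustbin row/column), and for every caption
- # chunk the best-matching image chunk score is kept and summed. Summing over the entries of
- # opt.block_dim yields the final (n_img, n_cap) similarity matrix used for retrieval.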