from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel


def sort_pack_padded_sequence(input, lengths):
    # Sort the batch by decreasing length (as packing requires) and build the
    # inverse permutation needed to restore the original order afterwards.
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    # Recent PyTorch versions require the lengths tensor to live on the CPU.
    tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix
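
# For example, lengths = [2, 5, 3] sorts (descending) to [5, 3, 2] with
# indices = [1, 2, 0] and hence inv_ix = [2, 0, 1]; since indices[inv_ix[i]] == i,
# indexing the padded output with inv_ix in pad_unsort_packed_sequence below
# restores the original batch order.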


def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp


def pack_wrapper(module, att_feats, att_masks):
    # Apply `module` to the unpadded timesteps of att_feats only: pack by the
    # per-sample lengths (the mask sums), run the module on the packed data,
    # then pad back and restore the original batch order.
    if att_masks is not None:
        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)
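
# A minimal usage sketch (shapes assumed, not part of the original code): with
# att_feats of shape (batch, num_regions, d_vf) and att_masks of shape
# (batch, num_regions) holding 1s over valid regions,
#     embedded = pack_wrapper(nn.Linear(d_vf, d_model), att_feats, att_masks)
# applies the linear layer to the valid region features only.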


class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.input_encoding_size = args.d_model
        self.rnn_size = args.d_ff
        self.num_layers = args.num_layers
        self.drop_prob_lm = args.drop_prob_lm
        self.max_seq_length = args.max_seq_length
        self.att_feat_size = args.d_vf
        self.att_hid_size = args.d_model

        self.bos_idx = args.bos_idx
        self.eos_idx = args.eos_idx
        self.pad_idx = args.pad_idx

        self.use_bn = args.use_bn

        # embed and fc_embed are identity placeholders; subclasses are expected
        # to replace them with real word/feature embeddings.
        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        # A truthy use_bn prepends a BatchNorm1d over the raw attention features;
        # use_bn == 2 additionally appends one after the projection.
        self.att_embed = nn.Sequential(*(
            ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
            (nn.Linear(self.att_feat_size, self.input_encoding_size),
             nn.ReLU(),
             nn.Dropout(self.drop_prob_lm)) +
            ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))
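
        # Subclasses must also provide self.core, self.logit, self.ctx2att and
        # self.init_hidden, which the methods below rely on.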

    def clip_att(self, att_feats, att_masks):
        # Clip att_feats and att_masks to the longest actual sequence in the
        # batch (e.g. masks summing to 3 and 5 valid regions clip both tensors
        # to 5 columns), saving memory and computation downstream.
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed the fc and att features
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        # Project the attention features up front (via the subclass-provided
        # ctx2att) to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats)

        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        # `it` holds one previous word index per sequence
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt=None):
        opt = opt or {}
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size, every beam is returned as a sample
        assert sample_n == 1 or sample_n == beam_size // group_size, \
            'when beam search is used, sample_n must be 1 or beam_size // group_size'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, \
            'assume beam_size <= vocab size for now; otherwise this corner case causes a few headaches down the road'
        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        # process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]

        state = self.init_hidden(batch_size)

        # first step: feed <bos>
        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            beam_size, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])
        self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
        for k in range(batch_size):
            if sample_n == beam_size:
                for _n in range(sample_n):
                    seq_len = self.done_beams[k][_n]['seq'].shape[0]
                    seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
                    seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
            else:
                seq_len = self.done_beams[k][0]['seq'].shape[0]
                seq[k, :seq_len] = self.done_beams[k][0]['seq']  # the first beam has the highest cumulative score
                seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs

    def _sample(self, fc_feats, att_feats, att_masks=None):
        opt = self.args.__dict__
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
                sample_n, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

        trigrams = []  # will be a list of seq.size(0) dictionaries

        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        # seqLogprobs keeps the full distribution at every step, hence the vocab dimension
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        for t in range(self.max_seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                                                      output_logsoftmax=output_logsoftmax)

            if decoding_constraint and t > 0:
                # forbid emitting the same word twice in a row
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # Block repeated trigrams
            # Copied from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store the trigram generated at the last step
                prev_two_batch = seq[:, t - 3:t - 1]
                for i in range(seq.size(0)):  # covers all batch_size * sample_n rows
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t - 1]
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})  # {2-gram tuple: list of continuation tokens}
                    elif t > 3:
                        if prev_two in trigrams[i]:  # add to the list
                            trigrams[i][prev_two].append(current)
                        else:  # create the list
                            trigrams[i][prev_two] = [current]
                # Block the used trigrams at the next step
                prev_two_batch = seq[:, t - 2:t]
                mask = logprobs.new_zeros(logprobs.size())  # rows x vocab, on logprobs' device
                for i in range(seq.size(0)):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i, j] += 1
                # Apply the mask to the log probs as a soft penalty
                # (a hard block would be: logprobs = logprobs - mask * 1e9)
                alpha = 2.0
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)
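
            # With alpha = 2.0 each previously generated trigram has its
            # continuation probability multiplied by exp(2 * ln(1/2)) = 1/4 per
            # occurrence, a soft penalty rather than a hard block.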

            # sample the next word
            if t == self.max_seq_length:  # stop once maximum length is reached
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all sequences are finished
            if t == 0:
                unfinished = it != self.eos_idx
            else:
                it[~unfinished] = self.pad_idx  # pad finished sequences instead of overwriting their eos
                logprobs = logprobs * unfinished.unsqueeze(1).float()
                unfinished = unfinished * (it != self.eos_idx)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            # quit the loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs

    def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt=None):
        opt = opt or {}
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)

        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        # one trigram dictionary list, sequence buffer, logprob buffer and state per group
        trigrams_table = [[] for _ in range(group_size)]
        seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long)
                     for _ in range(group_size)]
        seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
        state_table = [self.init_hidden(batch_size) for _ in range(group_size)]

        # the groups are staggered: group divm generates its step-t token at tt = t + divm
        for tt in range(self.max_seq_length + group_size):
            for divm in range(group_size):
                t = tt - divm
                seq = seq_table[divm]
                seqLogprobs = seqLogprobs_table[divm]
                trigrams = trigrams_table[divm]
                if 0 <= t <= self.max_seq_length - 1:
                    if t == 0:  # input <bos>
                        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
                    else:
                        it = seq[:, t - 1]  # feed back the previous word

                    logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
                                                                          p_att_masks, state_table[divm])
                    logprobs = F.log_softmax(logprobs / temperature, dim=-1)

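                    # Hamming-diversity term in the style of Diverse Beam Search
                    # (Vijayakumar et al., 2016): token ids already chosen by earlier
                    # groups at this timestep are discouraged by diversity_lambda.
                    # Note that the column indexing below penalizes those ids for
                    # every sample in the batch, not only the row that chose them.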
                    # Add diversity
                    if divm > 0:
                        unaug_logprobs = logprobs.clone()
                        for prev_choice in range(divm):
                            prev_decisions = seq_table[prev_choice][:, t]
                            logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda

                    if decoding_constraint and t > 0:
                        # forbid emitting the same word twice in a row
                        tmp = logprobs.new_zeros(logprobs.size())
                        tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                        logprobs = logprobs + tmp

                    # Block repeated trigrams, as in _sample above
                    if block_trigrams and t >= 3:
                        # Store the trigram generated at the last step
                        prev_two_batch = seq[:, t - 3:t - 1]
                        for i in range(batch_size):  # = seq.size(0) here
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            current = seq[i][t - 1]
                            if t == 3:  # initialize
                                trigrams.append({prev_two: [current]})  # {2-gram tuple: list of continuation tokens}
                            elif t > 3:
                                if prev_two in trigrams[i]:  # add to the list
                                    trigrams[i][prev_two].append(current)
                                else:  # create the list
                                    trigrams[i][prev_two] = [current]
                        # Block the used trigrams at the next step
                        prev_two_batch = seq[:, t - 2:t]
                        mask = logprobs.new_zeros(logprobs.size())  # batch_size x vocab, on logprobs' device
                        for i in range(batch_size):
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            if prev_two in trigrams[i]:
                                for j in trigrams[i][prev_two]:
                                    mask[i, j] += 1
                        # Apply the mask to the log probs as a soft penalty
                        # (a hard block would be: logprobs = logprobs - mask * 1e9)
                        alpha = 2.0
                        logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

                    it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)

                    # stop when all sequences in this group have finished
                    if t == 0:
                        unfinished = it != self.eos_idx
                    else:
                        unfinished = (seq[:, t - 1] != self.pad_idx) & (seq[:, t - 1] != self.eos_idx)
                        it[~unfinished] = self.pad_idx
                        unfinished = unfinished & (it != self.eos_idx)
                    seq[:, t] = it
                    seqLogprobs[:, t] = sampleLogprobs.view(-1)

        return (torch.stack(seq_table, 1).reshape(batch_size * group_size, -1),
                torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1))
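
# A minimal decoding sketch (names assumed, not part of the original code): for a
# trained subclass `model` that defines core/logit/ctx2att/init_hidden,
#     seq, seqLogprobs = model._sample(fc_feats, att_feats, att_masks)
# returns token ids of shape (batch_size * sample_n, max_seq_length) together
# with the per-step log probabilities used to generate them.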