|
- """VSE model"""
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.nn.init
- import torch.backends.cudnn as cudnn
- from torch.nn.utils import clip_grad_norm_
-
- import logging
-
- from encoder import l2norm, get_image_encoder, get_text_encoder, get_sim_encoder, Grouping
- from loss import ContrastiveLoss
-
- logger = logging.getLogger(__name__)
-
-
def off_diagonal(x):
    """Return the off-diagonal entries of the square matrix ``x`` as a 1-D tensor.

    Entries are produced in row-major order with the main diagonal skipped, so
    an ``n x n`` input yields ``n * (n - 1)`` values.
    """
    size, other = x.shape
    assert size == other
    # Dropping the final element and reshaping the remaining n*n - 1 values to
    # (n - 1, n + 1) puts every diagonal entry in column 0 of its row; slicing
    # that column away leaves exactly the off-diagonal entries.
    shifted = x.flatten()[:-1].view(size - 1, size + 1)
    return shifted[:, 1:].flatten()
-
-
class VSEModel(object):
    """The standard VSE model, extended with per-group image encoders.

    ``group_enc`` is asked for ``n_group`` groups of image regions; each group is
    embedded by its own image encoder in ``img_enc`` and the per-group embeddings
    are concatenated along dim 1. Captions go through ``txt_enc`` and ``sim_enc``
    scores image/caption pairs.

    NOTE: ``forward_loss`` and ``train_emb`` call ``self.logger.update(...)``, but
    ``__init__`` does not create ``self.logger``; it must be assigned externally
    before those methods are used.
    """

    def __init__(self, opt):
        # Number of region groups, and therefore of image encoders.
        # Defaults to 10 unless ``opt`` provides ``n_group``.
        self.n_group = getattr(opt, 'n_group', 10)
        self.grad_clip = opt.grad_clip

        # Build models: one image encoder per region group.
        self.img_enc = nn.ModuleList([
            get_image_encoder(opt.img_dim, opt.embed_size,
                              precomp_enc_type=opt.precomp_enc_type,
                              backbone_source=opt.backbone_source,
                              backbone_path=opt.backbone_path,
                              no_imgnorm=opt.no_imgnorm)
            for _ in range(self.n_group)
        ])
        self.txt_enc = get_text_encoder(opt.vocab_size, opt.embed_size, opt.word_dim, opt.num_layers, opt,
                                        use_bi_gru=True, no_txtnorm=opt.no_txtnorm)
        self.sim_enc = get_sim_encoder(opt)
        self.group_enc = Grouping()

        if torch.cuda.is_available():
            self.img_enc.cuda()
            self.txt_enc.cuda()
            self.sim_enc.cuda()
            self.group_enc.cuda()
            cudnn.benchmark = True

        # Loss
        self.criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation)

        # All trainable parameters; also the set that train_emb gradient-clips.
        params = list(self.txt_enc.parameters())
        params += list(self.img_enc.parameters())
        params += list(self.sim_enc.parameters())
        params += list(self.group_enc.parameters())
        self.params = params
        self.opt = opt

        # Optimizer, with per-part learning rates in backbone mode.
        if opt.precomp_enc_type == 'basic':
            if self.opt.optim == 'adam':
                self.optimizer = torch.optim.AdamW(self.params, lr=opt.learning_rate)
            else:
                raise ValueError('Invalid optim option {}'.format(self.opt.optim))
        else:
            decay_factor = 1e-4
            if self.opt.optim == 'adam':
                self.optimizer = torch.optim.AdamW(self._backbone_param_groups(opt),
                                                   lr=opt.learning_rate, weight_decay=decay_factor)
            else:
                raise ValueError('Invalid optim option {}'.format(self.opt.optim))

        print('Use {} as the optimizer, with init lr {}'.format(self.opt.optim, opt.learning_rate))

        self.Eiters = 0
        self.training = False
        self.data_parallel = False

    def _image_encoders(self):
        """Return the per-group image encoders with any DataParallel wrapper removed."""
        encoders = self.img_enc
        if isinstance(encoders, nn.DataParallel):
            # Tolerate a DataParallel wrapped around the whole ModuleList.
            encoders = encoders.module
        return [enc.module if isinstance(enc, nn.DataParallel) else enc for enc in encoders]

    def _backbone_param_groups(self, opt):
        """Build AdamW parameter groups for backbone mode, pooled over all group encoders.

        The first four groups keep the original order: text encoder, backbone top,
        backbone base, image head. sim_enc and group_enc follow; they are appended
        only when they have parameters.
        """
        encoders = self._image_encoders()
        backbone_lr = opt.learning_rate * opt.backbone_lr_factor
        groups = [
            {'params': list(self.txt_enc.parameters()), 'lr': opt.learning_rate},
            {'params': [p for enc in encoders for p in enc.backbone.top.parameters()],
             'lr': backbone_lr, },
            {'params': [p for enc in encoders for p in enc.backbone.base.parameters()],
             'lr': backbone_lr, },
            {'params': [p for enc in encoders for p in enc.image_encoder.parameters()],
             'lr': opt.learning_rate},
        ]
        # sim_enc / group_enc feed the loss and are clipped via self.params, so they
        # must be optimized too. Otherwise they would never update in this mode.
        for extra in (self.sim_enc, self.group_enc):
            extra_params = list(extra.parameters())
            if extra_params:
                groups.append({'params': extra_params, 'lr': opt.learning_rate})
        return groups

    def set_max_violation(self, max_violation):
        """Toggle hardest-negative mining in the contrastive criterion."""
        if max_violation:
            self.criterion.max_violation_on()
        else:
            self.criterion.max_violation_off()

    def state_dict(self):
        """Return ``[img_enc, txt_enc, sim_enc, group_enc]`` state dicts as a list."""
        state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(),
                      self.group_enc.state_dict()]
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the list produced by ``state_dict`` (non-strict; unmatched keys are ignored)."""
        self.img_enc.load_state_dict(state_dict[0], strict=False)
        self.txt_enc.load_state_dict(state_dict[1], strict=False)
        self.sim_enc.load_state_dict(state_dict[2], strict=False)
        self.group_enc.load_state_dict(state_dict[3], strict=False)

    def train_start(self):
        """Switch all sub-modules to train mode."""
        self.training = True
        self.img_enc.train()
        self.txt_enc.train()
        self.sim_enc.train()
        self.group_enc.train()

    def val_start(self):
        """Switch all sub-modules to evaluation mode."""
        self.img_enc.eval()
        self.txt_enc.eval()
        self.sim_enc.eval()
        self.group_enc.eval()
        self.training = False

    def freeze_backbone(self):
        """Freeze the backbone of every group image encoder (backbone modes only)."""
        if 'backbone' in self.opt.precomp_enc_type:
            for enc in self._image_encoders():
                enc.freeze_backbone()

    def unfreeze_backbone(self, fixed_blocks):
        """Unfreeze every group encoder's backbone, keeping ``fixed_blocks`` frozen."""
        if 'backbone' in self.opt.precomp_enc_type:
            for enc in self._image_encoders():
                enc.unfreeze_backbone(fixed_blocks)

    def make_data_parallel(self):
        """Wrap the sub-modules in ``nn.DataParallel``."""
        # Wrap each group encoder separately. forward_emb indexes self.img_enc[i],
        # and a DataParallel around the whole ModuleList can be neither indexed nor called.
        self.img_enc = nn.ModuleList([nn.DataParallel(enc) for enc in self._image_encoders()])
        self.txt_enc = nn.DataParallel(self.txt_enc)
        self.sim_enc = nn.DataParallel(self.sim_enc)
        self.group_enc = nn.DataParallel(self.group_enc)
        self.data_parallel = True
        print('Image encoder is data paralleled now.')

    @property
    def is_data_parallel(self):
        return self.data_parallel

    def forward_emb(self, images, captions, lengths, image_lengths=None):
        """Compute the grouped image embedding and the caption embedding.

        Args:
            images: image features handed to ``group_enc``.
            captions: caption tokens handed to ``txt_enc``.
            lengths: caption lengths; converted with ``torch.Tensor`` (float).
            image_lengths: per-image region counts. 80% of their maximum
                (truncated) is the number of regions requested per group.
                Required despite the ``None`` default.

        Returns:
            ``(img_emb, cap_emb)``. ``img_emb`` concatenates the ``n_group``
            per-group image embeddings along dim 1.

        Raises:
            ValueError: if ``image_lengths`` is None.
        """
        if image_lengths is None:
            raise ValueError('forward_emb requires image_lengths to size the region groups')
        image_lengths = torch.as_tensor(image_lengths)
        if torch.cuda.is_available():
            images = images.cuda()
            captions = captions.cuda()
            image_lengths = image_lengths.cuda()

        n_region = int(image_lengths.max() * 0.8)
        grouped_images = self.group_enc(images, n_group=self.n_group, n_region=n_region)

        # Every group is treated as fully valid, using the first group's shape:
        # batch = shape[0], length = shape[1].
        first = grouped_images[0]
        group_lengths = torch.full((first.shape[0],), float(first.shape[1]), device=first.device)
        img_emb = torch.cat([self.img_enc[i](grouped_img, group_lengths)
                             for i, grouped_img in enumerate(grouped_images)], dim=1)

        # Caption lengths follow the captions onto their device (GPU when available).
        lengths = torch.Tensor(lengths).to(captions.device)
        cap_emb = self.txt_enc(captions, lengths)
        return img_emb, cap_emb

    def forward_sims(self, img_emb, cap_emb):
        """Score every image/caption pair with ``sim_enc``."""
        sims = self.sim_enc(img_emb, cap_emb)
        return sims

    def _group_redundancy_loss(self, img_emb):
        """Average cross-correlation penalty over every pair of group embeddings.

        ``img_emb`` is split into ``n_group`` chunks along its last dimension.
        Each chunk is passed through ``l2norm(chunk, 0)``. For each pair ``i < j``,
        ``c = e_i^T e_j / batch`` contributes
        ``sum((diag(c) - 1)^2) + sum(offdiag(c)^2) / 1023``.
        The result is the mean over all pairs, or a zero tensor when there are
        fewer than two groups.
        """
        n_pairs = self.n_group * (self.n_group - 1) / 2
        if n_pairs == 0:
            # A single group has no pairs to compare.
            return img_emb.new_zeros(())
        embs = torch.chunk(img_emb, self.n_group, -1)
        batch_size = embs[0].size(0)
        # Normalise each chunk once instead of once per pair.
        normed = [l2norm(emb, 0) for emb in embs]
        total = 0
        for i in range(self.n_group - 1):
            for j in range(i + 1, self.n_group):
                c = normed[i].T @ normed[j] / batch_size
                on_diag = (torch.diagonal(c) - 1).pow(2).sum()
                off_diag = off_diagonal(c).pow(2).sum()
                # NOTE(review): 1023 appears to equal (per-group embedding width - 1)
                # for embed_size=1024. Confirm before changing embed_size.
                total = total + (on_diag + off_diag / 1023)
        return total / n_pairs

    def forward_loss(self, img_emb, cap_emb):
        """Return contrastive loss (logged as Le0) plus group redundancy loss (logged as Le1)."""
        sims = self.forward_sims(img_emb, cap_emb)

        loss0 = self.criterion(sims)
        self.logger.update('Le0', loss0.data.item(), img_emb.size(0))

        loss1 = self._group_redundancy_loss(img_emb)
        self.logger.update('Le1', loss1.data.item(), img_emb.size(0))

        loss = loss0 + loss1
        return loss

    def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None, ids=None):
        """Run one training step on a batch of images and captions.

        ``warmup_alpha``, when given, scales the loss before backpropagation.
        ``ids`` is accepted but unused.
        """
        self.Eiters += 1
        self.logger.update('Eit', self.Eiters)
        self.logger.update('lr', self.optimizer.param_groups[0]['lr'])

        # Compute the embeddings.
        img_emb, cap_emb = self.forward_emb(images, captions, lengths, image_lengths=image_lengths)

        # Measure accuracy and record the loss.
        self.optimizer.zero_grad()
        loss = self.forward_loss(img_emb, cap_emb)

        if warmup_alpha is not None:
            loss = loss * warmup_alpha

        # Compute the gradient and update.
        loss.backward()
        if self.grad_clip > 0:
            clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()
|