|
- """VSE model"""
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.nn.init
- import torch.backends.cudnn as cudnn
- from torch.nn.utils import clip_grad_norm_
-
- import logging
-
- from encoder import l2norm, get_image_encoder, get_text_encoder, get_sim_encoder, Grouping
- from loss import ContrastiveLoss
-
- logger = logging.getLogger(__name__)
-
-
def off_diagonal(x):
    """Return a flat 1-D tensor of all off-diagonal entries of a square matrix.

    Elements are returned in row-major order, exactly as the classic
    Barlow-Twins flatten/view trick produces them.
    """
    n, m = x.shape
    assert n == m
    # Select everything except the main diagonal via a boolean mask;
    # advanced indexing flattens the result in row-major order.
    keep = ~torch.eye(n, dtype=torch.bool, device=x.device)
    return x[keep]
-
-
def get_bert_layerwise_lr_groups(bert_model, learning_rate=1e-5, layer_decay=0.9):
    """Build optimizer parameter groups with depth-decayed learning rates.

    Layers closer to the output receive a higher learning rate; the
    embedding layer (deepest) receives the most decayed one.

    Args:
        bert_model: A huggingface bert-like model exposing ``embeddings``
            and ``encoder.layer``.
        learning_rate: Base learning rate that the decay is applied to.
        layer_decay: Per-depth multiplicative decay (recommended 0.9-0.95).

    Returns:
        list: ``{"params": ..., "lr": ...}`` dicts, embedding first, then
        encoder layers from input to output.

    NOTE(review): with these exponents the topmost encoder layer gets
    ``learning_rate * layer_decay**2``, never the undecayed rate — kept
    as-is to preserve training behavior; confirm this is intended.
    """
    # Depth budget: all encoder layers plus one for the embedding layer.
    total = len(bert_model.encoder.layer) + 1

    groups = [{
        'params': bert_model.embeddings.parameters(),
        'lr': learning_rate * (layer_decay ** (total + 1)),
    }]
    # Encoder layer i (input-to-output order) decays by total - i.
    groups.extend(
        {'params': layer.parameters(), 'lr': learning_rate * (layer_decay ** (total - i))}
        for i, layer in enumerate(bert_model.encoder.layer)
    )
    return groups
-
-
class VSEModel(object):
    """
    The standard VSE model.

    Owns one image encoder per region group (``n_group`` of them, wrapped in
    an ``nn.ModuleList``), a text encoder, a similarity encoder, and a
    ``Grouping`` module that partitions image regions into groups.  Training
    combines a contrastive loss over image-caption similarities (``Le0``)
    with a Barlow-Twins-style decorrelation loss between pairs of group
    embeddings (``Le1``).

    NOTE(review): this is a plain object, not an ``nn.Module``.  It reads
    ``self.logger`` in ``forward_loss``/``train_emb`` but never assigns it —
    presumably the training script attaches a logger before training;
    confirm against the caller.
    """

    def __init__(self, opt):
        # Build Models
        # Number of region groups; hard-coded and must match what the
        # Grouping module below is constructed with.
        self.n_group = 6
        self.grad_clip = opt.grad_clip
        # One independent image encoder per region group.
        self.img_enc = nn.ModuleList([get_image_encoder(opt.img_dim, opt.embed_size,
                                                        precomp_enc_type=opt.precomp_enc_type,
                                                        backbone_source=opt.backbone_source,
                                                        backbone_path=opt.backbone_path,
                                                        no_imgnorm=opt.no_imgnorm) for i in range(self.n_group)])

        self.txt_enc = get_text_encoder(opt.vocab_size, opt.embed_size, opt.word_dim, opt.num_layers, opt,
                                        use_bi_gru=True, no_txtnorm=opt.no_txtnorm)

        self.sim_enc = get_sim_encoder(opt)

        # Splits the region features into n_group groups of regions.
        self.group_enc = Grouping(self.n_group, opt.img_dim)

        if torch.cuda.is_available():
            self.img_enc.cuda()
            self.txt_enc.cuda()
            self.sim_enc.cuda()
            self.group_enc.cuda()

            cudnn.benchmark = True

        # Loss and Optimizer
        self.criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation)

        # Flat parameter list kept for gradient clipping in train_emb.
        params = list(self.txt_enc.parameters())
        params += list(self.img_enc.parameters())
        params += list(self.sim_enc.parameters())
        params += list(self.group_enc.parameters())

        self.params = params
        self.opt = opt

        decay_factor = 1e-4
        if opt.precomp_enc_type == 'basic':
            if self.opt.optim == 'adam':
                # Separate BERT parameters from the rest of the text encoder
                # by data pointer, so BERT can get its own (layerwise-decayed)
                # learning rates while everything else uses the base lr.
                all_text_params = list(self.txt_enc.parameters())
                bert_params = list(self.txt_enc.bert.parameters())
                bert_params_ptr = [p.data_ptr() for p in bert_params]
                text_params_no_bert = list()
                for p in all_text_params:
                    if p.data_ptr() not in bert_params_ptr:
                        text_params_no_bert.append(p)

                # Layerwise-decayed groups for BERT, anchored at 0.1x base lr
                # (replaces the flat 0.1x group commented out below).
                bert_params = get_bert_layerwise_lr_groups(self.txt_enc.bert, learning_rate=opt.learning_rate * 0.1)
                train_params = [
                    {'params': text_params_no_bert, 'lr': opt.learning_rate},
                    # {'params': bert_params, 'lr': opt.learning_rate * 0.1},
                    {'params': self.img_enc.parameters(), 'lr': opt.learning_rate},
                ]
                self.optimizer = torch.optim.AdamW(train_params + bert_params, lr=opt.learning_rate, weight_decay=decay_factor)
            elif self.opt.optim == 'sgd':
                self.optimizer = torch.optim.SGD(self.params, lr=opt.learning_rate, momentum=0.9)
            else:
                raise ValueError('Invalid optim option {}'.format(self.opt.optim))
        else:
            # Trainable-backbone configuration.
            # NOTE(review): self.img_enc is an nn.ModuleList here, so
            # self.img_enc.backbone / .image_encoder would raise
            # AttributeError — this branch looks unreachable/stale for the
            # grouped-encoder setup; confirm which precomp_enc_type values
            # are actually used.
            if self.opt.optim == 'adam':
                all_text_params = list(self.txt_enc.parameters())
                bert_params = list(self.txt_enc.bert.parameters())
                bert_params_ptr = [p.data_ptr() for p in bert_params]
                text_params_no_bert = list()
                for p in all_text_params:
                    if p.data_ptr() not in bert_params_ptr:
                        text_params_no_bert.append(p)
                self.optimizer = torch.optim.AdamW([
                    {'params': text_params_no_bert, 'lr': opt.learning_rate},
                    {'params': bert_params, 'lr': opt.learning_rate * 0.1},
                    {'params': self.img_enc.backbone.top.parameters(),
                     'lr': opt.learning_rate * opt.backbone_lr_factor, },
                    {'params': self.img_enc.backbone.base.parameters(),
                     'lr': opt.learning_rate * opt.backbone_lr_factor, },
                    {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
                ], lr=opt.learning_rate, weight_decay=decay_factor)
            elif self.opt.optim == 'sgd':
                self.optimizer = torch.optim.SGD([
                    {'params': self.txt_enc.parameters(), 'lr': opt.learning_rate},
                    {'params': self.img_enc.backbone.parameters(), 'lr': opt.learning_rate * opt.backbone_lr_factor,
                     'weight_decay': decay_factor},
                    {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
                ], lr=opt.learning_rate, momentum=0.9, nesterov=True)
            else:
                raise ValueError('Invalid optim option {}'.format(self.opt.optim))

        print('Use {} as the optimizer, with init lr {}'.format(self.opt.optim, opt.learning_rate))

        self.Eiters = 0
        self.training = False
        self.data_parallel = False

    def set_max_violation(self, max_violation):
        """Toggle hardest-negative mining in the contrastive loss."""
        if max_violation:
            self.criterion.max_violation_on()
        else:
            self.criterion.max_violation_off()

    def state_dict(self):
        """Return the sub-module state dicts as a list.

        Order matters and must match load_state_dict:
        [img_enc, txt_enc, sim_enc, group_enc].
        """
        state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(), self.group_enc.state_dict()]
        return state_dict

    def load_state_dict(self, state_dict):
        """Load sub-module states saved by state_dict (same list order).

        strict=False tolerates missing/unexpected keys, e.g. checkpoints
        from slightly different configurations.
        """
        self.img_enc.load_state_dict(state_dict[0], strict=False)
        self.txt_enc.load_state_dict(state_dict[1], strict=False)
        self.sim_enc.load_state_dict(state_dict[2], strict=False)
        self.group_enc.load_state_dict(state_dict[3], strict=False)

    def train_start(self):
        """switch to train mode
        """
        self.training = True
        self.img_enc.train()
        self.txt_enc.train()
        self.sim_enc.train()
        self.group_enc.train()

    def val_start(self):
        """switch to evaluate mode
        """
        self.img_enc.eval()
        self.txt_enc.eval()
        self.sim_enc.eval()
        self.group_enc.eval()
        self.training = False

    def freeze_backbone(self):
        """Freeze the image backbone (only for trainable-backbone configs).

        NOTE(review): self.img_enc is an nn.ModuleList (or DataParallel
        around one), which has no freeze_backbone method — like the
        non-basic optimizer branch above, this looks stale for the grouped
        setup; confirm it is never reached with 'backbone' enc types.
        """
        if 'backbone' in self.opt.precomp_enc_type:
            if isinstance(self.img_enc, nn.DataParallel):
                self.img_enc.module.freeze_backbone()
            else:
                self.img_enc.freeze_backbone()

    def unfreeze_backbone(self, fixed_blocks):
        """Unfreeze the image backbone except the given fixed blocks.

        See the NOTE(review) on freeze_backbone — same caveat applies.
        """
        if 'backbone' in self.opt.precomp_enc_type:
            if isinstance(self.img_enc, nn.DataParallel):
                self.img_enc.module.unfreeze_backbone(fixed_blocks)
            else:
                self.img_enc.unfreeze_backbone(fixed_blocks)

    def make_data_parallel(self):
        """Wrap all sub-modules in nn.DataParallel for multi-GPU training."""
        self.img_enc = nn.DataParallel(self.img_enc)
        self.txt_enc = nn.DataParallel(self.txt_enc)
        self.sim_enc = nn.DataParallel(self.sim_enc)
        self.group_enc = nn.DataParallel(self.group_enc)
        self.data_parallel = True
        print('Image encoder is data paralleled now.')

    @property
    def is_data_parallel(self):
        # True once make_data_parallel() has been called.
        return self.data_parallel

    def forward_emb(self, images, captions, lengths, image_lengths=None):
        """Compute the image and caption embeddings.

        Groups image regions with group_enc, encodes each group with its own
        image encoder, and concatenates the per-group embeddings along the
        feature dimension.

        NOTE(review): img_lens is only assigned inside the CUDA branch but
        is read unconditionally below, and lengths is moved with .cuda()
        unconditionally — as written this method requires a CUDA device;
        confirm CPU execution is not a supported path.
        """

        if torch.cuda.is_available():
            images = images.cuda()
            img_lens = image_lengths.cuda()
            captions = captions.cuda()
        img_len = img_lens.max()
        grouped_images, grouped_lengths = self.group_enc(images, n_group=self.n_group, n_region=int(img_len))
        img_p1 = grouped_images[0]  # NOTE(review): unused — leftover from debugging?
        img_embds = []

        # Each group i is encoded by its dedicated encoder img_enc[i].
        for i, grouped_img in enumerate(grouped_images):
            # img_length = torch.Tensor([grouped_images[i].shape[1]]*grouped_images[i].shape[0])
            # img_length = img_length - torch.randint(0,5, img_length.shape)
            img_embds.append(self.img_enc[i](grouped_img, grouped_lengths[i]))
        # Concatenate group embeddings along the feature dimension; chunked
        # back apart in forward_loss for the decorrelation term.
        img_emb = torch.cat(img_embds, dim=1)

        lengths = torch.Tensor(lengths).cuda()
        cap_emb = self.txt_enc(captions, lengths)
        return img_emb, cap_emb

    def forward_sims(self, img_emb, cap_emb):
        """Compute the image-caption similarity matrix via the sim encoder."""

        # sims = img_emb.mm(cap_emb.t())

        sims = self.sim_enc(img_emb, cap_emb)

        return sims

    def forward_loss(self, img_emb, cap_emb):
        """Compute the loss given pairs of image and caption embeddings.

        Total loss = contrastive loss over similarities (Le0) + a
        Barlow-Twins-style decorrelation penalty averaged over all pairs of
        group embeddings (Le1).
        """
        sims = self.forward_sims(img_emb, cap_emb)

        loss0 = self.criterion(sims)

        self.logger.update('Le0', loss0.data.item(), img_emb.size(0))

        # Split the concatenated image embedding back into its per-group
        # chunks (inverse of the cat in forward_emb).
        embs = torch.chunk(img_emb, self.n_group, -1)
        loss1 = 0
        for i in range(self.n_group - 1):
            # l2norm(..., 0) normalizes each feature over the batch
            # dimension, so c below is a feature cross-correlation matrix.
            embs1 = l2norm(embs[i], 0)
            for j in range(i+1, self.n_group):
                embs2 = l2norm(embs[j], 0)
                c = embs1.T @ embs2
                # # # sum the cross-correlation matrix between all gpus
                c.div_(embs[0].size(0))

                # Pull diagonal to 1 (aligned features), push off-diagonal
                # to 0 (decorrelated features).  The 1/1023 weight plays the
                # role of Barlow Twins' lambda — presumably tied to the
                # per-group feature dimension minus one; confirm.
                on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
                off_diag = off_diagonal(c).pow_(2).sum()
                loss1 += on_diag + off_diag / 1023
        # loss1 = sum(loss1)
        # Average over the n_group-choose-2 pairs.
        loss1 = loss1 / (self.n_group * (self.n_group - 1) / 2)

        self.logger.update('Le1', loss1.data.item(), img_emb.size(0))

        # c1 = embs[2].T @ embs[3]

        # # sum the cross-correlation matrix between all gpus
        # c1.div_(embs[2].size(0))

        # on_diag1 = torch.diagonal(c1).add_(-1).pow_(2).sum()
        # off_diag1 = off_diagonal(c1).pow_(2).sum()
        # loss2 = on_diag1 + 0.0051 * off_diag1
        # self.logger.update('Le2', loss2.data.item(), img_emb.size(0))

        # loss = loss0 + loss1 + loss2
        loss = loss0 + loss1

        return loss

    def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None, ids=None):
        """One training step given images and captions.

        warmup_alpha, when given, scales the loss (linear warm-up);
        ids is accepted but unused here.
        """
        self.Eiters += 1
        self.logger.update('Eit', self.Eiters)
        self.logger.update('lr', self.optimizer.param_groups[0]['lr'])

        # compute the embeddings
        img_emb, cap_emb = self.forward_emb(images, captions, lengths, image_lengths=image_lengths)

        # measure accuracy and record loss
        self.optimizer.zero_grad()
        loss = self.forward_loss(img_emb, cap_emb)

        if warmup_alpha is not None:
            loss = loss * warmup_alpha

        # compute gradient and update
        loss.backward()
        if self.grad_clip > 0:
            # Clip over the flat list built in __init__, covering all
            # sub-modules.
            clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()
|