|
- """Training script"""
- import os
- import time
- import numpy as np
- import torch
- import random
- import wandb
- import logging
- # import tensorboard_logger as tb_logger
-
- from arguments import get_argument_parser
- from data import get_loaders
- from evalution import forward_sims, encode_data, evalrank, AverageMeter, LogCollector, t2i, i2t
- from transformers import BertTokenizer
- from vse import VSEModel
-
- os.system("cp /tmp/dataset/bert-base-uncased.zip /tmp/code/avse_bert/")
- os.system("unzip bert-base-uncased.zip")
-
def setup_seed(seed):
    """Seed all RNGs (numpy, random, torch CPU/GPU) for reproducible runs.

    Args:
        seed: integer seed applied to every random source.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # CPU RNG
    torch.cuda.manual_seed_all(seed)  # all GPUs (multi-GPU safe)
    torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
    # BUG FIX: benchmark was True, which lets cuDNN autotune and pick
    # potentially nondeterministic algorithms — contradicting the
    # deterministic=True setting above. Per PyTorch reproducibility
    # guidance, benchmark must be False for reproducible results.
    torch.backends.cudnn.benchmark = False
-
-
def main():
    """Entry point: parse options, build data loaders and model, then train
    with periodic validation, checkpointing the best R@K sum."""
    # Hyper Parameters
    parser = get_argument_parser()
    opt = parser.parse_known_args()[0]
    setup_seed(opt.seed)
    if not os.path.exists(opt.model_name):
        os.makedirs(opt.model_name)
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    logger = logging.getLogger(__name__)
    print(opt)

    # Tokenizer vocabulary size feeds the text encoder's embedding table.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab = tokenizer.vocab
    opt.vocab_size = len(vocab)

    train_loader, val_loader = get_loaders(
        opt.data_path, opt.data_name, tokenizer, opt.batch_size, opt.workers, opt)

    model = VSEModel(opt)

    wandb.init(project='avse_bert', config=opt)
    wandb.watch(model.img_enc)
    wandb.watch(model.txt_enc)
    wandb.watch(model.sim_enc)
    wandb.watch(model.group_enc)

    lr_schedules = [opt.lr_update, ]

    # Optionally resume from a checkpoint.
    # BUG FIX: best_rsum is initialized once here, BEFORE the resume block.
    # Previously it was reset to 0 after resuming, discarding the checkpoint's
    # best score so the first validation would overwrite model_best.pth.
    start_epoch = 0
    best_rsum = 0
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            # Wrap in DataParallel before loading so state-dict keys match.
            if opt.precomp_enc_type == 'backbone' and not model.is_data_parallel:
                model.make_data_parallel()
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
            if opt.reset_start_epoch:
                start_epoch = 0
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    if opt.precomp_enc_type == 'backbone' and not model.is_data_parallel:
        model.make_data_parallel()

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch, lr_schedules)

        # Switch from mean loss to max-violation (hardest negative) loss
        # after the warm-up epochs.
        if epoch >= opt.vse_mean_warmup_epochs:
            opt.max_violation = True
            model.set_max_violation(opt.max_violation)

        # Backbone warm-up schedule: train embeddings first, then unfreeze
        # the backbone progressively from the last block to the whole net.
        if opt.precomp_enc_type == 'backbone':
            if epoch < opt.embedding_warmup_epochs:
                model.freeze_backbone()
                print('All backbone weights are frozen, only train the embedding layers')
            else:
                model.unfreeze_backbone(3)

            if epoch < opt.embedding_warmup_epochs:
                print('Warm up the embedding layers')
            elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs:
                model.unfreeze_backbone(3)  # only train the last block of resnet backbone
            elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 2:
                model.unfreeze_backbone(2)
            elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 3:
                model.unfreeze_backbone(1)
            else:
                model.unfreeze_backbone(0)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs (not mkdir) so a missing parent cannot crash checkpointing.
        if not os.path.exists(opt.model_name):
            os.makedirs(opt.model_name)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }, is_best, filename='checkpoint.pth', prefix=opt.model_name + '/')
-
-
def train(opt, train_loader, model, epoch, val_loader):
    """Run one training epoch of *model* over *train_loader*.

    Args:
        opt: parsed options namespace (precomp_enc_type, log_step, ...).
        train_loader: training DataLoader.
        model: VSEModel being trained (updated in place).
        epoch: current epoch index, used for the backbone warm-up alpha.
        val_loader: unused here; kept for interface compatibility.
    """
    logger = logging.getLogger(__name__)
    # average meters to record the training statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    train_logger = LogCollector()

    print('image encoder trainable parameters: {}'.format(count_params(model.img_enc)))
    print('txt encoder trainable parameters: {}'.format(count_params(model.txt_enc)))
    # BUG FIX: this print previously said "txt encoder" while counting sim_enc.
    print('sim encoder trainable parameters: {}'.format(count_params(model.sim_enc)))

    num_loader_iter = len(train_loader.dataset) // train_loader.batch_size + 1

    end = time.time()
    for i, train_data in enumerate(train_loader):
        # switch to train mode
        model.train_start()

        # measure data loading time
        data_time.update(time.time() - end)

        # make sure train logger is used
        model.logger = train_logger

        # Update the model
        if opt.precomp_enc_type == 'basic':
            images, img_lengths, captions, lengths, ids = train_data
            model.train_emb(images, captions, lengths, image_lengths=img_lengths, ids=ids)
        else:
            images, captions, lengths, _ = train_data
            if epoch == opt.embedding_warmup_epochs:
                # Linearly ramp the backbone contribution over the first
                # backbone-training epoch.
                warmup_alpha = float(i) / num_loader_iter
                model.train_emb(images, captions, lengths, warmup_alpha=warmup_alpha)
            else:
                model.train_emb(images, captions, lengths)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print log info
        if model.Eiters % opt.log_step == 0:
            if opt.precomp_enc_type == 'backbone' and epoch == opt.embedding_warmup_epochs:
                print('Current epoch-{}, the first epoch for training backbone, warmup alpha {}'.format(epoch,
                                                                                                        warmup_alpha))
            print(
                'Epoch: [{0}][{1}/{2}]\t'
                '{e_log}\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                .format(
                    # reuse num_loader_iter instead of recomputing the length
                    epoch, i, num_loader_iter, batch_time=batch_time,
                    data_time=data_time, e_log=str(model.logger)))
-
-
def validate(opt, val_loader, model):
    """Evaluate *model* on the validation split and return the recall sum.

    Encodes every image/caption pair, computes cross-modal similarities and
    reports i2t / t2i recall@{1,5,10}, median and mean rank to the log and
    to wandb.

    Returns:
        currscore: sum of the six recall values, used for early stopping.
    """
    logger = logging.getLogger(__name__)
    model.val_start()
    with torch.no_grad():
        # compute the encoding for all the validation images and captions
        img_embs, cap_embs = encode_data(
            model, val_loader, opt.log_step, logging.info, backbone=opt.precomp_enc_type == 'backbone')

        # Each image appears 5 times (one per caption); keep one copy each.
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])

        start = time.time()
        sims = forward_sims(model, img_embs, cap_embs)
        end = time.time()
        # BUG FIX: the format string had no placeholder, so the elapsed
        # time was silently dropped from the message.
        print("calculate similarity time: {}".format(end - start))

    # caption retrieval
    npts = img_embs.shape[0]
    (r1, r5, r10, medr, meanr) = i2t(npts, sims)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1, r5, r10, medr, meanr))
    # image retrieval
    (r1i, r5i, r10i, medri, meanri) = t2i(npts, sims)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1i, r5i, r10i, medri, meanri))
    # sum of recalls to be used for early stopping
    currscore = r1 + r5 + r10 + r1i + r5i + r10i
    print('Current rsum is {}'.format(currscore))

    wandb.log(
        {
            "Val/i2t/r1": r1,
            "Val/i2t/r5": r5,
            "Val/i2t/r10": r10,
            "Val/i2t/medr": medr,
            "Val/i2t/meanr": meanr,
            "Val/t2i/r1": r1i,
            "Val/t2i/r5": r5i,
            "Val/t2i/r10": r10i,
            "Val/t2i/medr": medri,
            "Val/t2i/meanr": meanri,
            "Val/Score": currscore
        })

    return currscore
-
-
def save_checkpoint(state, is_best, filename='checkpoint.pth', prefix=''):
    """Serialize *state* to ``prefix + filename``, retrying on I/O errors.

    When *is_best* is true the same state is also written to
    ``prefix + 'model_best.pth'``. Up to 15 attempts are made to deal with
    unstable storage; if all fail, the last IOError is re-raised.
    """
    logger = logging.getLogger(__name__)
    max_tries = 15
    last_error = None

    for attempt in range(max_tries):
        try:
            torch.save(state, prefix + filename)
            if is_best:
                torch.save(state, prefix + 'model_best.pth')
            return  # success — stop retrying
        except IOError as exc:
            last_error = exc
            print('model save {} failed, remaining {} trials'.format(
                filename, max_tries - attempt - 1))

    # every attempt failed
    raise last_error
-
-
def adjust_learning_rate(opt, optimizer, epoch, lr_schedules):
    """Sets the learning rate to the initial LR
    decayed by 10 every opt.lr_update epochs.

    Args:
        opt: options namespace; unused, kept for interface compatibility.
        optimizer: optimizer whose ``param_groups`` lrs are decayed in place.
        epoch: current epoch index.
        lr_schedules: epochs at which to multiply every lr by 0.1.
    """
    # FIX: the docstring above was previously a dead string literal placed
    # after the first statement (and the unused logger was removed).
    if epoch in lr_schedules:
        print('Current epoch num is {}, decrease all lr by 10'.format(epoch, ))
        for param_group in optimizer.param_groups:
            old_lr = param_group['lr']
            new_lr = old_lr * 0.1
            param_group['lr'] = new_lr
            print('new lr {}'.format(new_lr))
-
-
def count_params(model):
    """Return the total element count of all trainable parameters of *model*."""
    return sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
-
-
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    main()
|