|
- """Evaluation"""
- from __future__ import print_function
- import logging
- import time
- import os
- import torch
- import numpy as np
- import sys
- from collections import OrderedDict
-
- from transformers import BertTokenizer
-
- from data import get_test_loader
- from vse import VSEModel
-
- logger = logging.getLogger(__name__)
-
def forward_sims(model, img_embs, cap_embs):
    """Compute the full image-caption similarity matrix in GPU-sized shards.

    Args:
        model: model exposing `forward_sims(im, cap)` -> similarity tensor.
        img_embs: (N, ...) numpy array of image embeddings.
        cap_embs: (M, ...) numpy array of caption embeddings.

    Returns:
        (N, M) numpy array of similarities.
    """
    shard_size = 1000
    num_im_shards = (len(img_embs) - 1) // shard_size + 1
    num_cap_shards = (len(cap_embs) - 1) // shard_size + 1

    sims_all = np.zeros((len(img_embs), len(cap_embs)))
    for i in range(num_im_shards):
        im_lo = shard_size * i
        im_hi = min(shard_size * (i + 1), len(img_embs))
        for j in range(num_cap_shards):
            sys.stdout.write('\r>> shard_xattn_t2i batch (%d,%d)' % (i, j))
            cap_lo = shard_size * j
            cap_hi = min(shard_size * (j + 1), len(cap_embs))
            # move only the current shard to the GPU to bound memory use
            im_shard = torch.from_numpy(img_embs[im_lo:im_hi]).float().cuda()
            cap_shard = torch.from_numpy(cap_embs[cap_lo:cap_hi]).float().cuda()
            with torch.no_grad():
                shard_sims = model.forward_sims(im_shard, cap_shard)
            sims_all[im_lo:im_hi, cap_lo:cap_hi] = shard_sims.data.cpu().numpy()
    sys.stdout.write('\n')
    return sims_all
-
class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=0):
        """Record `val`, weighted by `n` observations (n=0 stores it exactly)."""
        self.val = val
        self.sum = self.sum + n * val
        self.count = self.count + n
        # small epsilon keeps the division safe while count is still zero
        self.avg = self.sum / (self.count + .0001)

    def __str__(self):
        """Render the meter for log lines."""
        if self.count != 0:
            # running statistic: current value plus the average so far
            return '%.4f (%.4f)' % (self.val, self.avg)
        # values recorded with n == 0 (e.g. iteration numbers) print exactly
        return str(self.val)
-
-
class LogCollector(object):
    """A named collection of AverageMeters that can change from train to val."""

    def __init__(self):
        # OrderedDict keeps the logged variables in insertion order,
        # so log lines are deterministic
        self.meters = OrderedDict()

    def update(self, k, v, n=0):
        """Record value `v` (weight `n`) under key `k`, creating the meter lazily."""
        try:
            meter = self.meters[k]
        except KeyError:
            meter = self.meters[k] = AverageMeter()
        meter.update(v, n)

    def __str__(self):
        """Concatenate all meters into a single space-separated log line."""
        return ' '.join('%s %s' % (k, v) for k, v in self.meters.items())
-
-
def encode_data(model, data_loader, log_step=10, logging=logger.info, backbone=False):
    """Encode all images and captions loadable by `data_loader`.

    Args:
        model: model exposing `val_start()` and `forward_emb(...)`.
        data_loader: loader yielding either
            (images, image_lengths, captions, lengths, ids) or, with
            `backbone=True`, (images, captions, lengths, ids).
        log_step: print progress every `log_step` batches.
        logging: kept for interface compatibility (progress uses `print`).
        backbone: whether the loader feeds raw images to a CNN backbone.

    Returns:
        (img_embs, cap_embs): numpy arrays covering the whole dataset,
        indexed by the sample ids emitted by the loader.
    """
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.val_start()

    end = time.time()

    # numpy arrays holding every embedding; allocated lazily once the
    # embedding dimensionality is known from the first batch
    img_embs = None
    cap_embs = None

    for i, data_i in enumerate(data_loader):
        # backbone loaders yield raw images without per-image region lengths
        if not backbone:
            images, image_lengths, captions, lengths, ids = data_i
        else:
            images, captions, lengths, ids = data_i
        # make sure the val logger is used by the model
        model.logger = val_logger

        # compute the embeddings
        if not backbone:
            img_emb, cap_emb = model.forward_emb(images, captions, lengths, image_lengths=image_lengths)
        else:
            img_emb, cap_emb = model.forward_emb(images, captions, lengths)

        if img_embs is None:
            # 3-D image embeddings keep a region axis; 2-D are pooled vectors
            if img_emb.dim() == 3:
                img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
            else:
                img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
            cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1)))
        # cache embeddings at their dataset positions (loader may shuffle)
        img_embs[ids] = img_emb.data.cpu().numpy().copy()
        cap_embs[ids, :] = cap_emb.data.cpu().numpy().copy()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % log_step == 0:
            print('Test: [{0}/{1}]\t'
                  '{e_log}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  .format(
                      i, len(data_loader.dataset) // data_loader.batch_size + 1, batch_time=batch_time,
                      e_log=str(model.logger)))
        # free the batch tensors before loading the next one
        del images, captions
    return img_embs, cap_embs
-
-
def ensemble_evalrank(model_path1, model_path2, data_path=None, split='dev', fold5=False):
    """Evaluate the ensemble (averaged similarities) of two trained models.

    Loads both checkpoints, encodes the chosen split with each model,
    averages the two similarity matrices and reports i2t/t2i recalls.
    If `fold5=True`, 5-fold cross-validation is used (MSCOCO only).
    """
    checkpoint1 = torch.load(model_path1)
    opt1 = checkpoint1['opt']
    print(opt1)

    checkpoint2 = torch.load(model_path2)
    opt2 = checkpoint2['opt']
    print(opt2)

    if data_path is not None:
        opt1.data_path = data_path

    # load vocabulary used by the model
    # NOTE(review): `deserialize_vocab` is not imported in this module, so this
    # call raises NameError unless the import is restored upstream — confirm.
    vocab = deserialize_vocab(os.path.join(opt1.vocab_path, '%s_vocab.json' % opt1.data_name))
    vocab.add_word('<mask>')

    opt1.vocab_size = len(vocab)
    opt2.vocab_size = len(vocab)
    opt1.txt_enc_type = 'rnn'
    opt2.txt_enc_type = 'rnn'

    model1 = VSEModel(opt1)
    model1.load_state_dict(checkpoint1['model'])

    model2 = VSEModel(opt2)
    model2.load_state_dict(checkpoint2['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt1.data_name, vocab,
                                  opt1.batch_size, 0, opt1)
    print('Computing results...')
    img_embs_1, cap_embs_1 = encode_data(model1, data_loader)
    img_embs_2, cap_embs_2 = encode_data(model2, data_loader)

    if not fold5:
        # no cross-validation, full evaluation: keep one image per 5 captions
        img_embs_1 = np.array([img_embs_1[i] for i in range(0, len(img_embs_1), 5)])
        img_embs_2 = np.array([img_embs_2[i] for i in range(0, len(img_embs_2), 5)])
        start = time.time()

        with torch.no_grad():
            sims1 = forward_sims(model1, img_embs_1, cap_embs_1)
            sims2 = forward_sims(model2, img_embs_2, cap_embs_2)

        # ensemble by simple averaging of the two similarity matrices
        sims = (sims1 + sims2) / 2
        npts = img_embs_1.shape[0]

        end = time.time()
        print("calculate similarity time:", end - start)

        r, rt = i2t(npts, sims, return_ranks=True)
        ri, rti = t2i(npts, sims, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            img_embs_shard_1 = img_embs_1[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard_1 = cap_embs_1[i * 5000:(i + 1) * 5000]

            img_embs_shard_2 = img_embs_2[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard_2 = cap_embs_2[i * 5000:(i + 1) * 5000]
            start = time.time()

            sims1 = forward_sims(model1, img_embs_shard_1, cap_embs_shard_1)
            sims2 = forward_sims(model2, img_embs_shard_2, cap_embs_shard_2)

            sims = (sims1 + sims2) / 2
            end = time.time()
            print("calculate similarity time:", end - start)
            npts = img_embs_shard_1.shape[0]
            r, rt0 = i2t(npts, sims, return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(npts, sims, return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[12]))
        # BUGFIX: each results row is [r(5), ri(5), ar, ari, rsum], so the
        # mean i2t/t2i recalls live at indices 10 and 11 (previously printed
        # 11 and 12, i.e. ari and rsum — inconsistent with evalrank()).
        print("Average i2t Recall: %.1f" % mean_metrics[10])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
              mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[11])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
              mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
-
-
def evalrank(model_path, data_path=None, split='dev', fold5=False, save_path=None, cxc=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: path to a checkpoint containing 'opt' and 'model'.
        data_path: optional override for the dataset root stored in the checkpoint.
        split: dataset split to evaluate ('dev' or 'test').
        fold5: run 5x1000 cross-validation instead of the full 5000-image eval.
        save_path: if given, the similarity matrix is dumped there via np.save.
        cxc: if True, run the CxC evaluation protocol instead of recall@K.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    # override the loader worker count stored in the checkpoint
    opt.workers = 5

    print(opt)

    # load vocabulary used by the model
    # opt.vocab_path = '/tmp/data/vocab'

    # load vocabulary used by the model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab = tokenizer.vocab
    opt.vocab_size = len(vocab)

    # NOTE(review): hard-coded backbone weight path — confirm it exists on this host
    opt.backbone_path = '/tmp/data/weights/original_updown_backbone.pth'
    if data_path is not None:
        opt.data_path = data_path

    # construct model
    model = VSEModel(opt)

    # backbone models are wrapped in DataParallel before loading weights
    if opt.precomp_enc_type == 'backbone':
        model.make_data_parallel()

    # load model state
    model.load_state_dict(checkpoint['model'])
    model.val_start()

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    with torch.no_grad():
        # 'basic' uses precomputed region features; otherwise the raw-image backbone path
        if opt.precomp_enc_type == 'basic':
            img_embs, cap_embs = encode_data(model, data_loader)
        else:
            img_embs, cap_embs = encode_data(model, data_loader, backbone=True)
    # each image appears 5 times (once per caption), hence the division by 5
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if cxc:
        eval_cxc(img_embs, cap_embs, data_path)
    else:
        if not fold5:
            # no cross-validation, full evaluation: keep one image per 5 captions
            img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
            start = time.time()
            sims = forward_sims(model, img_embs, cap_embs)
            end = time.time()
            print("calculate similarity time: {}".format(end - start))
            npts = img_embs.shape[0]
            # embeddings are no longer needed; free them before ranking
            del img_embs
            del cap_embs
            if save_path is not None:
                np.save(save_path, {'npts': npts, 'sims': sims})
                print('Save the similarity into {}'.format(save_path))

            r, rt = i2t(npts, sims, return_ranks=True)
            ri, rti = t2i(npts, sims, return_ranks=True)
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f" % rsum)
            print("Average i2t Recall: %.1f" % ar)
            print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
            print("Average t2i Recall: %.1f" % ari)
            print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
        else:
            # 5fold cross-validation, only for MSCOCO
            results = []
            for i in range(5):
                # each fold: 1000 images (every 5th row) and their 5000 captions
                img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
                cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
                start = time.time()
                sims = forward_sims(model, img_embs_shard, cap_embs_shard)
                end = time.time()
                print("calculate similarity time: {}".format(end - start))

                npts = img_embs_shard.shape[0]
                r, rt0 = i2t(npts, sims, return_ranks=True)
                print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
                ri, rti0 = t2i(npts, sims, return_ranks=True)
                print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

                # keep the first fold's raw ranks for later inspection
                if i == 0:
                    rt, rti = rt0, rti0
                ar = (r[0] + r[1] + r[2]) / 3
                ari = (ri[0] + ri[1] + ri[2]) / 3
                rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
                print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
                results += [list(r) + list(ri) + [ar, ari, rsum]]

            print("-----------------------------------")
            print("Mean metrics: ")
            # each results row is [r(5), ri(5), ar, ari, rsum] -> 13 columns
            mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
            print("rsum: %.1f" % (mean_metrics[12]))
            print("Average i2t Recall: %.1f" % mean_metrics[10])
            print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
                  mean_metrics[:5])
            print("Average t2i Recall: %.1f" % mean_metrics[11])
            print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
                  mean_metrics[5:10])
-
-
def compute_sim(images, captions):
    """Dot-product similarity between every image and caption embedding.

    Args:
        images: (N, d) array of image embeddings.
        captions: (M, d) array of caption embeddings.

    Returns:
        (N, M) similarity matrix.
    """
    # `captions.T` replaces the deprecated `np.matrix.transpose` call;
    # for 2-D ndarrays the result is identical.
    return np.matmul(images, captions.T)
-
-
def i2t(npts, sims, return_ranks=False, mode='coco'):
    """
    Images->Text (Image Annotation)

    Args:
        npts: number of images N.
        sims: (N, 5N) similarity matrix im-cap ((N, N) for non-coco mode).
        return_ranks: also return the per-image ranks and top-1 retrievals.
        mode: 'coco' assumes 5 ground-truth captions per image.

    Returns:
        (r1, r5, r10, medr, meanr), optionally followed by (ranks, top1).
    """
    ranks = np.zeros(npts)
    top1 = np.zeros(npts)
    for index in range(npts):
        order = np.argsort(sims[index])[::-1]
        top1[index] = order[0]
        if mode == 'coco':
            # five ground-truth captions per image: score the best-ranked one
            ranks[index] = min(
                np.where(order == gt)[0][0]
                for gt in range(5 * index, 5 * index + 5)
            )
        else:
            ranks[index] = np.where(order == index)[0][0]

    # Compute metrics
    r1 = 100.0 * np.count_nonzero(ranks < 1) / len(ranks)
    r5 = 100.0 * np.count_nonzero(ranks < 5) / len(ranks)
    r10 = 100.0 * np.count_nonzero(ranks < 10) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    metrics = (r1, r5, r10, medr, meanr)
    if return_ranks:
        return metrics, (ranks, top1)
    return metrics
-
-
def t2i(npts, sims, return_ranks=False, mode='coco'):
    """
    Text->Images (Image Search)

    Args:
        npts: number of images N.
        sims: (N, 5N) similarity matrix im-cap ((N, N) for non-coco mode).
        return_ranks: also return the per-caption ranks and top-1 retrievals.
        mode: 'coco' assumes 5 captions per image.

    Returns:
        (r1, r5, r10, medr, meanr), optionally followed by (ranks, top1).
    """
    n_queries = 5 * npts if mode == 'coco' else npts
    ranks = np.zeros(n_queries)
    top1 = np.zeros(n_queries)

    # reorient to (caption, image) so each row is one text query
    sims = sims.T

    for index in range(npts):
        if mode == 'coco':
            # the 5 captions of image `index` all have it as ground truth
            for offset in range(5):
                q = 5 * index + offset
                order = np.argsort(sims[q])[::-1]
                ranks[q] = np.where(order == index)[0][0]
                top1[q] = order[0]
        else:
            order = np.argsort(sims[index])[::-1]
            ranks[index] = np.where(order == index)[0][0]
            top1[index] = order[0]

    # Compute metrics
    r1 = 100.0 * np.count_nonzero(ranks < 1) / len(ranks)
    r5 = 100.0 * np.count_nonzero(ranks < 5) / len(ranks)
    r10 = 100.0 * np.count_nonzero(ranks < 10) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    metrics = (r1, r5, r10, medr, meanr)
    if return_ranks:
        return metrics, (ranks, top1)
    return metrics
-
-
- """
- CxC related evaluation.
- """
-
def eval_cxc(images, captions, data_path):
    """Run the Crisscrossed Captions (CxC) evaluation protocol.

    Args:
        images: (5N, d) image embeddings (one row per caption; deduplicated below).
        captions: (5N, d) caption embeddings.
        data_path: dataset root containing the 'cxc_annots' directory.

    Prints inter-modal (i2t/t2i) and intra-modal (i2i/t2t) recalls.
    """
    # `os` is already imported at module level; only json is needed locally
    import json
    cxc_annot_base = os.path.join(data_path, 'cxc_annots')
    img_id_path = os.path.join(cxc_annot_base, 'testall_ids.txt')
    cap_id_path = os.path.join(cxc_annot_base, 'testall_capids.txt')

    # keep one image row per group of 5 identical caption-aligned rows
    images = images[::5, :]

    with open(img_id_path) as f:
        img_ids = f.readlines()
    with open(cap_id_path) as f:
        cap_ids = f.readlines()

    # image ids are also repeated 5x; keep every 5th to match `images`
    img_ids = [img_id.strip() for i, img_id in enumerate(img_ids) if i % 5 == 0]
    cap_ids = [cap_id.strip() for cap_id in cap_ids]

    with open(os.path.join(cxc_annot_base, 'cxc_it.json')) as f_it:
        cxc_it = json.load(f_it)
    with open(os.path.join(cxc_annot_base, 'cxc_i2i.json')) as f_i2i:
        cxc_i2i = json.load(f_i2i)
    with open(os.path.join(cxc_annot_base, 'cxc_t2t.json')) as f_t2t:
        cxc_t2t = json.load(f_t2t)

    sims = compute_sim(images, captions)
    t2i_recalls = cxc_inter(sims.T, img_ids, cap_ids, cxc_it['t2i'])
    i2t_recalls = cxc_inter(sims, cap_ids, img_ids, cxc_it['i2t'])
    print('T2I R@1: {}, R@5: {}, R@10: {}'.format(*t2i_recalls))
    print('I2T R@1: {}, R@5: {}, R@10: {}'.format(*i2t_recalls))

    i2i_recalls = cxc_intra(images, img_ids, cxc_i2i)
    t2t_recalls = cxc_intra(captions, cap_ids, cxc_t2t, text=True)
    print('I2I R@1: {}, R@5: {}, R@10: {}'.format(*i2i_recalls))
    print('T2T R@1: {}, R@5: {}, R@10: {}'.format(*t2t_recalls))
-
-
def cxc_inter(sims, data_ids, query_ids, annot):
    """Inter-modal CxC recall: rank each query's best positive item.

    Args:
        sims: (len(query_ids), len(data_ids)) similarity matrix.
        data_ids: ids of the retrieved items, aligned with sims columns.
        query_ids: ids of the queries, aligned with sims rows.
        annot: query id -> list of positive [data_id, ...] entries.

    Returns:
        (r1, r5, r10) recall percentages.

    Raises:
        ValueError: if a query id has no annotation entry.
    """
    ranks = []
    for row, query_id in enumerate(query_ids):
        if query_id not in annot:
            raise ValueError('unexpected query id {}'.format(query_id))
        # keep only positives that are actually present in the retrieval pool
        positives = [p for p in annot[query_id] if str(p[0]) in data_ids]
        pos_indices = [data_ids.index(str(p[0])) for p in positives]
        order = np.argsort(sims[row])[::-1]
        # best (lowest) rank over all positives; stays 1e20 when none exist
        best = 1e20
        for pos_idx in pos_indices:
            found = np.where(order == pos_idx)[0][0]
            if found < best:
                best = found
        ranks.append(best)
    ranks = np.array(ranks)
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    return (r1, r5, r10)
-
-
def cxc_intra(embs, data_ids, annot, text=False):
    """Intra-modal CxC recall (i2i or t2t): rank each item's best positive peer.

    Args:
        embs: (M, d) embeddings of one modality.
        data_ids: ids aligned with the rows of `embs`.
        annot: data id -> list of [peer_id, score] similarity annotations.
        text: True for caption embeddings (adds same-image COCO captions
              as positives and uses the text threshold).

    Returns:
        (r1, r5, r10) recall percentages.
    """
    # threshold for positive pairs according to the CxC paper
    pos_thresh = 3.0 if text else 2.5

    sims = compute_sim(embs, embs)
    # an item must not retrieve itself
    np.fill_diagonal(sims, 0)

    ranks = []
    for row, data_id in enumerate(data_ids):
        pos_items = [item for item in annot[data_id] if item[1] >= pos_thresh]
        order = np.argsort(sims[row])[::-1]
        if text:
            # the other four COCO captions of the same image count as positives
            group = list(range(row // 5 * 5, (row // 5 + 1) * 5))
            group.remove(row)
            pos_indices = group + [data_ids.index(str(item[0])) for item in pos_items]
        else:
            pos_indices = [data_ids.index(str(item[0])) for item in pos_items]
        if not pos_indices:
            # skip it since there is no positive example in the annotation
            continue
        best = 1e20
        for pos_idx in pos_indices:
            found = np.where(order == pos_idx)[0][0]
            if found < best:
                best = found
        ranks.append(best)

    ranks = np.array(ranks)
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    return (r1, r5, r10)
|