|
- # -*- coding: utf-8 -*-
- from __future__ import print_function, division
-
- # import sys
- # sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')
-
- import time
- import yaml
- import pickle
- import torch
- import torch.nn as nn
- import numpy as np
- from torchvision import datasets,transforms
- import os
- import scipy.io
- from tqdm import tqdm
- from data_utils.model_train import ft_net
- from utils.util import get_stream_logger
- from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH
-
-
-
def fliplr(img):
    """Mirror a batch of images horizontally.

    Args:
        img: tensor laid out as N x C x H x W.

    Returns:
        A new tensor with the width dimension (dim 3) reversed.
    """
    return torch.flip(img, dims=[3])
-
def extract_feature(model, dataloaders, flip):
    """Extract an L2-normalized feature vector for every image in the loader.

    Args:
        model: feature extractor already moved to the GPU (inputs are sent
            via ``.cuda()``, so a CUDA device is required).
        dataloaders: iterable yielding (images, labels) batches, N x C x H x W.
        flip: if True, also forward the horizontally mirrored batch and sum
            the two outputs before normalization (test-time augmentation).

    Returns:
        CPU FloatTensor of shape (num_images, feat_dim); empty tensor if the
        loader yields nothing.
    """
    batch_features = []
    for img, _ in tqdm(dataloaders):
        ff = model(img.cuda())
        if flip:
            # add features of the mirrored image to the originals
            ff += model(fliplr(img).cuda())
        # L2-normalize each row so cosine similarity is a plain dot product
        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
        ff = ff.div(fnorm.expand_as(ff))
        batch_features.append(ff.data.cpu().float())
    # concatenate once at the end: torch.cat inside the loop re-copied the
    # whole accumulated tensor every batch (quadratic); the unused `count`
    # accumulator and the discarded enumerate index are dropped
    if not batch_features:
        return torch.FloatTensor()
    return torch.cat(batch_features, 0)
-
-
def get_id(img_path):
    """Parse camera IDs and identity labels from image file names.

    Example file name: ``0769_c013_00074310_0`` — the leading field is the
    identity, the ``c``-prefixed field the camera, the next the frame.
    This variant reads 5-character labels and 2-digit camera IDs (the
    "benchmark_person" layout noted in the original comments).

    Args:
        img_path: list of (path, class_index) pairs as produced by
            ``torchvision.datasets.ImageFolder.imgs``.

    Returns:
        (camera_ids, labels) — two parallel lists of ints. Files without a
        ``c`` marker get the sentinel 9999999 in both lists; junk images
        whose label starts with ``-1`` get label -1.
    """
    camera_ids = []
    labels = []
    for path, _ in img_path:
        fname = os.path.basename(path)
        if 'c' not in fname:
            # no camera marker: distractor entry, use sentinel values
            labels.append(9999999)
            camera_ids.append(9999999)
            continue
        id_field = fname[0:5]
        labels.append(-1 if id_field[0:2] == '-1' else int(id_field))
        # two digits immediately after the first 'c' form the camera ID
        camera_ids.append(int(fname.split('c')[1][0:2]))
    return camera_ids, labels
-
-
def test(config_file_path:str, logger):
    """Extract gallery/query features with a trained model and dump results.

    Reads hyper-parameters from the YAML config, loads the trained network
    from OUTPUT_RESULT_DIR, extracts L2-normalized features for the
    'bounding_box_test' (gallery) and 'query' image folders, and writes:
      * ``<name>_feature.mat``  — features/labels/cams for MATLAB inspection
      * ``test_result.pkl``     — cosine-distance matrix + metadata

    Args:
        config_file_path: path to the YAML configuration file.
        logger: logger used for timing output.
    """
    # ---------------- read config ----------------
    with open(config_file_path, encoding='utf-8') as f:
        opts = yaml.load(f, Loader=yaml.SafeLoader)

    data_dir = opts['input']['dataset']['data_dir']
    name = "trained_" + opts['input']['config']['name']
    trained_model_name = name + "_last.pth"
    save_path = OUTPUT_RESULT_DIR

    nclass = opts['input']['config']['nclass']
    stride = opts['input']['config']['stride']
    pool = opts['input']['config']['pool']
    droprate = opts['input']['config']['droprate']
    inputsize = opts['input']['config']['inputsize']
    w = opts['input']['config']['w']
    h = opts['input']['config']['h']
    batchsize = opts['input']['config']['batchsize']
    flip = opts['test']['flip_test']

    trained_model_path = os.path.join(save_path, trained_model_name)

    # ---------------- load model (self-trained) ----------------
    model = ft_net(class_num=nclass, droprate=droprate, stride=stride,
                   init_model=None, pool=pool, return_f=False)
    try:
        model.load_state_dict(torch.load(trained_model_path))
    except RuntimeError:
        # Key mismatch: the checkpoint was saved from a DataParallel wrapper
        # ('module.'-prefixed keys). Retry with the same wrapping, then unwrap.
        # (Was a bare `except:`, which also hid missing-file errors.)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(trained_model_path))
        model = model.module
    # drop the classification head so the network ends at the
    # 512-dim feature extractor
    model.classifier.classifier = nn.Sequential()

    # ---------------- dataset & transforms ----------------
    # interpolation=3 is PIL.Image.BICUBIC; inputs are upscaled by 10%
    # relative to the training size (e.g. h == w == 299, inputsize == 256)
    if h == w:
        resize = transforms.Resize(
            (round(inputsize * 1.1), round(inputsize * 1.1)), interpolation=3)
    else:
        resize = transforms.Resize(
            (round(h * 1.1), round(w * 1.1)), interpolation=3)
    data_transforms = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    splits = ['bounding_box_test', 'query']
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms)
                      for x in splits}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,
                                                  shuffle=False, num_workers=8)
                   for x in splits}

    # ---------------- device ----------------
    use_gpu = torch.cuda.is_available()
    model = model.eval()
    if use_gpu:
        model = model.cuda()

    # ---------------- labels & cameras ----------------
    gallery_cam, gallery_label = get_id(image_datasets['bounding_box_test'].imgs)
    query_cam, query_label = get_id(image_datasets['query'].imgs)

    gallery_label = np.asarray(gallery_label)
    query_label = np.asarray(query_label)
    gallery_cam = np.asarray(gallery_cam)
    query_cam = np.asarray(query_cam)
    print('Gallery Size: %d' % len(gallery_label))
    print('Query Size: %d' % len(query_label))

    # ---------------- extract features ----------------
    since = time.time()
    with torch.no_grad():
        gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)
        query_feature = extract_feature(model, dataloaders['query'], flip)
    process_time = time.time() - since
    logger.info('total forward time: %.2f minutes' % (process_time / 60))

    # features are L2-normalized, so cosine distance == 1 - dot product
    dist = 1 - torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))

    # ---------------- save to MATLAB format for checking ----------------
    extracted_feature = {
        'gallery_feature': gallery_feature.numpy(),
        'gallery_label': gallery_label,
        'gallery_cam': gallery_cam,
        'query_feature': query_feature.numpy(),
        'query_label': query_label,
        'query_cam': query_cam,
    }
    result_name = os.path.join(save_path, name + '_feature.mat')
    scipy.io.savemat(result_name, extracted_feature)

    # ---------------- dump result for the evaluator ----------------
    return_dict = {
        'dist': dist.numpy(),
        'feature_example': query_feature[0].numpy(),
        'gallery_label': gallery_label,
        'gallery_cam': gallery_cam,
        'query_label': query_label,
        'query_cam': query_cam,
    }
    # os.path.join (not string concat) so a missing trailing separator on
    # OUTPUT_RESULT_DIR cannot mangle the path; `with` closes the handle
    # instead of leaking it
    result_pkl = os.path.join(OUTPUT_RESULT_DIR, 'test_result.pkl')
    with open(result_pkl, 'wb') as pkl_file:
        pickle.dump(return_dict, pkl_file, protocol=4)

    return

    # eval_result = evaluator(result, logger)
    # full_table = display_eval_result(dict = eval_result)
    # logger.info(full_table)
-
if __name__ == "__main__":
    # Standalone entry point: run the feature-extraction pipeline with the
    # project's default config and a stream logger.
    test(CONFIG_PATH, get_stream_logger('TEST'))
|