|
- import torch
- import os
- import numpy as np
- from datasets.crowd_multi import Crowd_Multi
- from models.fpn import vgg19_fpn
- import argparse
-
-
def test_1(data_dir, save_dir):
    """Evaluate the saved best model on the test split and report MAE / MSE.

    Args:
        data_dir: Dataset root; images are read from its ``test``
            sub-directory.
        save_dir: Directory that contains ``best_model.pth``.

    Returns:
        The string ``'Final Test: mae {}, mse {}'`` filled in with the mean
        absolute error and the square root of the mean squared error of
        (ground-truth count - predicted count) over all test images.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set vis gpu

    # NOTE(review): 512 and 8 are presumably crop size and downsample ratio;
    # confirm against Crowd_Multi's signature.
    datasets = Crowd_Multi(os.path.join(data_dir, 'test'), 512, 8, is_gray=False, method='val')
    dataloader = torch.utils.data.DataLoader(datasets, 1, shuffle=False,
                                             num_workers=8, pin_memory=False)

    # Fall back to the CPU when no GPU is visible instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = vgg19_fpn()
    model.to(device)
    model.load_state_dict(
        torch.load(os.path.join(save_dir, 'best_model.pth'), map_location=device))
    # Inference mode: without this, BatchNorm/Dropout layers (if the network
    # has any) would behave as during training and skew the predictions.
    model.eval()

    epoch_minus = []
    for inputs, count, name in dataloader:
        inputs = inputs.to(device)
        assert inputs.size(0) == 1, 'the batch size should equal to 1'
        with torch.no_grad():
            outputs = model(inputs)
        # Read each scalar once; every .item() call forces a device sync.
        gt_count = count[0].item()
        pred_count = torch.sum(outputs).item()
        temp_minu = gt_count - pred_count
        print(name, temp_minu, gt_count, pred_count)
        epoch_minus.append(temp_minu)

    epoch_minus = np.array(epoch_minus)
    # Reported as "mse" by convention, but this is the root of the mean
    # squared error.
    mse = np.sqrt(np.mean(np.square(epoch_minus)))
    mae = np.mean(np.abs(epoch_minus))
    log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
    return log_str
-
def test_2(data_dir, save_dir):
    """Run the saved best model on the test split and return predicted counts.

    Args:
        data_dir: Dataset root; images are read from its ``test``
            sub-directory.
        save_dir: Directory that contains ``best_model.pth``.

    Returns:
        A list with one ``int`` per test image, in dataloader order: the sum
        of the model output for that image, truncated toward zero by
        ``int()``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set vis gpu

    # NOTE(review): 512 and 8 are presumably crop size and downsample ratio;
    # confirm against Crowd_Multi's signature.
    datasets = Crowd_Multi(os.path.join(data_dir, 'test'), 512, 8, is_gray=False, method='val')
    dataloader = torch.utils.data.DataLoader(datasets, 1, shuffle=False,
                                             num_workers=8, pin_memory=False)

    # Fall back to the CPU when no GPU is visible instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = vgg19_fpn()
    model.to(device)
    model.load_state_dict(
        torch.load(os.path.join(save_dir, 'best_model.pth'), map_location=device))
    # Inference mode: without this, BatchNorm/Dropout layers (if the network
    # has any) would behave as during training and skew the predictions.
    model.eval()

    pred_counts = []
    for inputs, count, name in dataloader:
        inputs = inputs.to(device)
        assert inputs.size(0) == 1, 'the batch size should equal to 1'
        with torch.no_grad():
            outputs = model(inputs)
        # Read each scalar once; every .item() call forces a device sync.
        gt_count = count[0].item()
        pred_count = torch.sum(outputs).item()
        print(name, gt_count - pred_count, gt_count, pred_count)
        pred_counts.append(int(pred_count))

    # Print the collected predictions once, after all images are processed.
    print(pred_counts)

    return pred_counts
|