|
- import torch
- import torch.utils.data
- import numpy as np
- import os
- from src import model as model
- from src import anchor as anchor
- from src.dataset import my_dataloader
- from tqdm import tqdm
- import matplotlib.pyplot as plt
- import matplotlib.patches as patch
- from matplotlib.lines import Line2D
- import random
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pin all CUDA work to GPU 0
# Skeleton edges: pairs of keypoint indices connected when drawing the pose;
# index 13 appears to be the root/palm joint — TODO confirm against dataset docs.
connectivity = [(0, 1), (1, 13), (2, 3), (3, 13), (4, 5), (5, 13), (6, 7),
                (7, 13), (8, 9), (9, 10), (10, 13), (11, 13), (12, 13)]

# DataHyperParams
keypointsNumber = 14  # joints per sample (matches len of connectivity fan-in)
cropWidth = 128       # input crop width in pixels
cropHeight = 128      # input crop height in pixels
batch_size = 32

# Seed Python, NumPy and torch (CPU) RNGs for reproducible runs.
randomseed = 0
random.seed(randomseed)
np.random.seed(randomseed)
torch.manual_seed(randomseed)

model_dir = 'result/pretrained.pth'  # path to the pretrained checkpoint
-
-
def visualize(img, pred, label):
    """Overlay predicted (yellow) and ground-truth (red) keypoints on crops.

    NOTE: `pred` and `label` are modified IN PLACE — normalized coordinates
    are mapped back to pixel space (scale by 64, recenter, flip y) and the
    depth channel is remapped to a drawable circle radius. Pass copies if
    the caller still needs the originals.

    Args:
        img: batch of crops, indexed as img[k, 0] — assumes shape
            (batch, 1, H, W); TODO confirm against the dataloader.
        pred: predicted keypoints, shape (batch, keypointsNumber, 3)
            holding (x, y, depth) in normalized coordinates.
        label: ground-truth keypoints, same shape/convention as `pred`.

    Side effects:
        Saves the figure to './result/visual_result.jpg' and shows it.
    """
    # Identical denormalization for predictions and labels — do it once.
    for points in (pred, label):
        points[:, :, 0:2] *= 64
        points[:, :, 0:2] += (cropWidth / 2 - 0.5)
        points[:, :, 1] = cropWidth - points[:, :, 1]  # flip y into image coords
        points[:, :, 2] = -points[:, :, 2] + 2         # depth -> circle radius

    num_show = 3  # first three samples of the batch
    fig, axes = plt.subplots(1, num_show, figsize=(10, 4))
    for k in range(num_show):
        axes[k].imshow(img[k, 0], cmap='gray')
        for i in range(keypointsNumber):
            axes[k].add_patch(patch.Circle(label[k, i, 0:2], radius=label[k, i, 2], color='red'))
            axes[k].add_patch(patch.Circle(pred[k, i, 0:2], radius=pred[k, i, 2], color='yellow'))
        for a, b in connectivity:
            # Pass linewidth by keyword: positional Line2D arguments beyond
            # xdata/ydata are deprecated in recent matplotlib releases.
            axes[k].add_line(Line2D((label[k, a, 0], label[k, b, 0]),
                                    (label[k, a, 1], label[k, b, 1]),
                                    linewidth=2, color='red'))
            axes[k].add_line(Line2D((pred[k, a, 0], pred[k, b, 0]),
                                    (pred[k, a, 1], pred[k, b, 1]),
                                    linewidth=2, color='yellow'))
    # savefig raises if the target directory does not exist yet.
    os.makedirs('./result', exist_ok=True)
    fig.savefig('./result/visual_result.jpg', dpi=300)
    plt.show()
-
-
def test():
    """Evaluate the pretrained A2J model on the test split.

    Loads the checkpoint from `model_dir`, runs inference over the test
    dataloader on the GPU, visualizes the second batch, and prints the mean
    per-keypoint error computed by `errorCompute`.
    """
    test_image_datasets = my_dataloader(FileDir="NYU_part", mode="test", augment=False)
    test_dataloaders = torch.utils.data.DataLoader(test_image_datasets, batch_size=batch_size,
                                                   shuffle=False, num_workers=4)

    net = model.A2J_model(num_classes=keypointsNumber)
    net.load_state_dict(torch.load(model_dir))
    net = net.cuda()
    net.eval()

    # Anchor post-processing maps network heads back to keypoint coordinates.
    post_precess = anchor.post_process(shape=[cropHeight // 16, cropWidth // 16],
                                       stride=16, P_h=None, P_w=None)

    # Accumulate per-batch results in lists and concatenate once at the end:
    # repeated torch.cat inside the loop is quadratic in the number of batches.
    outputs = []
    targets = []
    torch.cuda.synchronize()
    # enumerate(tqdm(...)) (not tqdm(enumerate(...))) so the bar knows its total.
    for i, (img_ori, label) in enumerate(tqdm(test_dataloaders)):
        with torch.no_grad():
            img, label = img_ori.cuda(), label.cuda()
            heads = net(img)
            pred_keypoints = post_precess(heads)
            if i == 1:  # visualize the second batch only (as the original `j` counter did)
                visualize(img_ori, pred_keypoints.data.cpu(), label.data.cpu())
            outputs.append(pred_keypoints.data.cpu())
            targets.append(label.data.cpu())

    torch.cuda.synchronize()

    # Tensors are already on CPU; no extra .cpu() round-trip needed.
    result = torch.cat(outputs, 0).numpy()
    labels = torch.cat(targets, 0).numpy()
    error = errorCompute(result, labels)
    print('Error:', error)
-
-
def errorCompute(source, target):
    """Return the mean Euclidean keypoint error between two keypoint arrays.

    Args:
        source: predictions, array of shape (N, K, C) — the last axis holds
            coordinates (x, y, depth here).
        target: ground truth, same shape as `source`.

    Returns:
        float: the per-keypoint Euclidean distance along axis 2, averaged
        over all samples and keypoints.

    Raises:
        AssertionError: if the two arrays have different shapes.
    """
    assert np.shape(source) == np.shape(target), "source has different shape with target"

    # No defensive copies needed: the arithmetic below never mutates inputs.
    errors = np.sqrt(np.sum((target - source) ** 2, axis=2))
    return np.mean(errors)
-
-
if __name__ == '__main__':
    # Script entry point: run evaluation on the test split.
    test()
|