|
- import os
- import torch
- import torch.backends.cudnn as cudnn
- import torch.nn.functional as F
- from PIL import Image
- from scipy import misc
- from models import utils
- import time
- from options import Options
- from my_transforms import get_transforms
- from models.FullNet import MultiTaskFullNet
- from models.MultiTaskNet import MultiTaskUNet
- from data_folder import DataFolder
- from torch.utils.data import DataLoader
- import numpy as np
-
-
def main():
    """Run the trained multi-task network over the validation set.

    Loads the checkpoint at ``model_path`` into a ``MultiTaskUNet``, feeds the
    validation images through it one sample at a time, and takes the arg-max
    over axis 1 of both network outputs (segmentation and classification).
    The shapes of the two raw outputs are printed for every sample.

    All paths and model settings are hard-coded below, and the model is moved
    to the GPU with ``.cuda()``, so a CUDA device is required.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    img_dir = "endoscope400/ade20k/images/validation"
    label_dir = "endoscope400/ade20k/annotations/validation"
    model_path = "runs/checkpoints/checkpoint_289.pth"

    # NOTE(review): the settings below are declared but nothing in this
    # function reads them yet (no saving, TTA or metric computation happens).
    save_dir = "runs/endofullnet/validation"
    save_flag = True
    tta = True
    eval_flag = True  # if label_dir else False

    # ----- create model ----- #
    # Alternative architecture:
    #   MultiTaskFullNet(color_channels=3, output_channels=8)
    model = MultiTaskUNet(n_channels=3, n_classes=8)
    # The state dict is loaded through the DataParallel wrapper and the bare
    # module is unwrapped afterwards -- presumably the checkpoint was saved
    # from a DataParallel model; confirm against the training script.
    model = torch.nn.DataParallel(model).cuda()

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    print(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    print("=> Test begins:")

    # ----- validation data ----- #
    val_transform = get_transforms({
        'scale': 240,
        'to_tensor': 1,
    })
    # DataFolder receives [image dir, label dir], the file postfix list and
    # the channel count for each directory (3 for images, 1 for labels).
    val_set = DataFolder([img_dir, label_dir], ['.png'], [3, 1], val_transform)
    # Evaluation does not benefit from shuffling; keep the order deterministic.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False,
                            pin_memory=True, num_workers=4)

    # Inference only: disable autograd so no graph is built per forward pass.
    with torch.no_grad():
        for sample in val_loader:
            # Each sample is a 3-tuple; only the image is consumed here.
            image, _target, _category = sample
            image = image.cuda()
            segoutput, clsoutput = model(image)

            # Segmentation prediction: arg-max over axis 1.
            pred = np.argmax(segoutput.cpu().numpy(), axis=1)

            # Classification prediction: arg-max over axis 1.
            pred_cls = np.argmax(clsoutput.cpu().numpy(), axis=1)

            print(segoutput.shape)
            print(clsoutput.shape)
-
-
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|