|
- import matplotlib.pyplot as plt
- import torch
- import torch.nn as nn
- import torchvision
- import torchvision.transforms as transforms
- import numpy as np
- from .misc import *
-
- __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
-
- # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalize a channel-first image tensor and return it as an HWC array.

    Channel ``i`` is mapped to ``img[i] * std[i] + mean[i]`` for every
    channel that has a matching ``mean``/``std`` entry.  The result is
    converted to NumPy and its axes are reordered from (C, H, W) to (H, W, C).

    The arithmetic is done on a detached CPU copy, so the tensor passed in is
    never modified and tensors that track gradients or live on another device
    can still be converted.

    Args:
        img: image tensor indexed as ``img[channel]``; three dimensions are
            required by the final transpose.
        mean: per-channel offsets added after scaling.
        std: per-channel scale factors.

    Returns:
        ``numpy.ndarray`` with the channel axis moved last.
    """
    # cpu() returns the same tensor when it is already on the CPU, so the
    # explicit clone() is what guarantees the caller's data stays untouched.
    out = img.detach().cpu().clone()
    # Never index past the channels (or statistics) that actually exist.
    n_channels = min(out.size(0), len(mean), len(std))
    for i in range(n_channels):
        out[i] = out[i] * std[i] + mean[i]  # unnormalize
    return np.transpose(out.numpy(), (1, 2, 0))
-
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)**2 / (2 * c**2))`` element-wise on ``x``.

    Args:
        x: input tensor.
        a: value the curve takes where ``x == b``.
        b: position of the peak.
        c: width parameter; it enters only through ``2*c*c``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    centered = torch.sub(x, b)
    exponent = torch.pow(centered, 2) / (2*c*c)
    return a * torch.exp(-exponent)
-
def colorize(x):
    """Converts a one-channel grayscale image to a color heatmap image.

    Each output channel is a Gaussian response (see ``gauss``) of the input
    intensity; the first channel is the sum of two such responses.  Values
    above 1 are clipped to 1.

    Args:
        x: grayscale tensor of rank 2, 3 or 4.  A rank-2 input gets a leading
            axis of size 1 added (the caller's tensor is not modified).
            Rank 3 is treated as a single image and rank 4 as a batch whose
            axis 1 has size 1.
            NOTE(review): the peak positions (.2, .5, .6, .8) suggest
            intensities in [0, 1] are expected -- confirm with callers.

    Returns:
        A new ``torch.zeros``-allocated tensor of shape (3, H, W) for
        rank-2/3 input or (N, 3, H, W) for rank-4 input.

    Raises:
        ValueError: if ``x`` has a rank other than 2, 3 or 4.
    """
    if x.dim() == 2:
        # torch.unsqueeze() has no ``out=`` keyword, so add the axis
        # functionally instead of trying to write back into ``x``.
        x = x.unsqueeze(0)

    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
    elif x.dim() == 4:
        # Drop the singleton channel axis so each (N, H, W) response matches
        # the (N, H, W) slice it is assigned to, for any batch size.
        gray = x.squeeze(1)
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:,0,:,:] = gauss(gray,.5,.6,.2) + gauss(gray,1,.8,.3)
        cl[:,1,:,:] = gauss(gray,1,.5,.3)
        cl[:,2,:,:] = gauss(gray,1,.2,.3)
    else:
        raise ValueError(
            'colorize expects a 2-D, 3-D or 4-D tensor, got %d dimensions' % x.dim())

    # The summed first channel can exceed 1; clip for single images and
    # batches alike.
    cl.clamp_(max=1)
    return cl
-
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Display a batch of images as a single grid in a matplotlib window.

    The batch is tiled with ``torchvision.utils.make_grid``, un-normalized
    with ``make_image`` using ``Mean``/``Std``, drawn with ``plt.imshow`` and
    shown with ``plt.show()``.

    Args:
        images: image batch accepted by ``torchvision.utils.make_grid``.
        Mean: per-channel offsets forwarded to ``make_image``.
        Std: per-channel scale factors forwarded to ``make_image``.
    """
    grid = torchvision.utils.make_grid(images)
    canvas = make_image(grid, Mean, Std)
    plt.imshow(canvas)
    plt.show()
-
-
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch and, below it, the same batch blended with a mask.

    Two stacked subplots are drawn on the current matplotlib figure: the
    un-normalized image grid on top, and ``0.3 * image + 0.7 * mask`` at the
    bottom.  ``plt.show()`` is not called here.

    Args:
        images: rank-4 tensor with at least three channels on axis 1
            (indexed as ``[:, i, :, :]`` for ``i`` in 0..2).
        mask: tensor whose axis 2 gives its spatial size; after upsampling it
            must be expandable to the shape of ``images``.
            NOTE(review): ``upsampling`` comes from ``.misc`` -- presumably it
            resizes by ``scale_factor``; confirm there.
        Mean: per-channel offsets used to un-normalize ``images``.
        Std: per-channel scale factors used to un-normalize ``images``.
    """
    im_size = images.size(2)

    # Work on a private, un-normalized copy for the blended overlay.
    im_data = images.clone()
    for ch in range(3):
        rescaled = im_data[:, ch] * Std[ch] + Mean[ch]
        im_data[:, ch] = rescaled

    # Top panel: the plain image grid.
    grid = torchvision.utils.make_grid(images)
    top_panel = make_image(grid, Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(top_panel)
    plt.axis('off')

    # Bottom panel: mask resized to the image size and blended in.
    # NOTE: the mask is used as-is; no per-sample min/max rescaling is done.
    ratio = im_size/mask.size(2)
    upsampled = upsampling(mask, scale_factor=ratio)
    blended = 0.3*im_data + 0.7*upsampled.expand_as(im_data)
    bottom_panel = make_image(torchvision.utils.make_grid(blended))
    plt.subplot(2, 1, 2)
    plt.imshow(bottom_panel)
    plt.axis('off')
-
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch followed by one blended overlay row per mask.

    The current matplotlib figure is split into ``1 + len(masklist)`` stacked
    subplots: the un-normalized image grid first, then for every mask a grid
    of ``0.3 * image + 0.7 * mask``.  ``plt.show()`` is not called here.

    Args:
        images: rank-4 tensor with at least three channels on axis 1
            (indexed as ``[:, i, :, :]`` for ``i`` in 0..2).
        masklist: sequence of mask objects exposing ``.data``; each is moved
            to the CPU, upsampled and expanded to the shape of ``images``.
            NOTE(review): ``upsampling`` comes from ``.misc`` -- presumably it
            resizes by ``scale_factor``; confirm there.
        Mean: per-channel offsets used to un-normalize ``images``.
        Std: per-channel scale factors used to un-normalize ``images``.
    """
    im_size = images.size(2)
    n_rows = 1+len(masklist)

    # Work on a private, un-normalized copy for the blended overlays.
    im_data = images.clone()
    for ch in range(3):
        rescaled = im_data[:, ch] * Std[ch] + Mean[ch]
        im_data[:, ch] = rescaled

    # First row: the plain image grid.
    grid = torchvision.utils.make_grid(images)
    top_panel = make_image(grid, Mean, Std)
    plt.subplot(n_rows, 1, 1)
    plt.imshow(top_panel)
    plt.axis('off')

    # Remaining rows: one overlay per mask, in list order.
    # NOTE: masks are used as-is; no per-sample min/max rescaling is done.
    for idx, entry in enumerate(masklist):
        mask = entry.data.cpu()
        ratio = im_size/mask.size(2)
        upsampled = upsampling(mask, scale_factor=ratio)
        blended = 0.3*im_data + 0.7*upsampled.expand_as(im_data)
        overlay = make_image(torchvision.utils.make_grid(blended))
        plt.subplot(n_rows, 1, idx+2)
        plt.imshow(overlay)
        plt.axis('off')
-
-
-
- # x = torch.zeros(1, 3, 3)
- # out = colorize(x)
- # out_im = make_image(out)
- # plt.imshow(out_im)
- # plt.show()
|