|
- # data loader
- from __future__ import print_function, division
- import torch
- from skimage import io, transform, color
- import numpy as np
- import math
- #import matplotlib.pyplot as plt
- from torch.utils.data import Dataset
- #==========================dataset load==========================
-
class RescaleT(object):
    """Resize image and label to a fixed size, ignoring aspect ratio.

    Args:
        output_size (int or tuple): target size. An int ``s`` yields an
            ``s x s`` output; a tuple is used directly as the output shape
            (normally ``(new_h, new_w)``).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    @staticmethod
    def _output_shape(output_size):
        """Return the shape handed to ``transform.resize`` for ``output_size``."""
        if isinstance(output_size, int):
            return (output_size, output_size)
        # A tuple is already a complete shape; it must not be duplicated
        # into ((h, w), (h, w)).
        return tuple(output_size)

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        out_shape = self._output_shape(self.output_size)

        # The image goes through resize's default float conversion
        # (uint8 [0, 255] -> float [0, 1]). The label uses nearest-neighbour
        # interpolation (order=0) with preserve_range=True, so its values are
        # neither blended nor rescaled.
        img = transform.resize(image, out_shape, mode='constant')
        lbl = transform.resize(label, out_shape, mode='constant', order=0, preserve_range=True)

        return {'image': img, 'label': lbl}
-
class Rescale(object):
    """Resize image and label, optionally keeping the aspect ratio.

    Args:
        output_size (int or tuple): with an int, the shorter side of the
            image is scaled to this value and the longer side proportionally;
            with a tuple, the output is exactly ``(new_h, new_w)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        rows, cols = image.shape[:2]
        size = self.output_size

        if not isinstance(size, int):
            target_h, target_w = size
        elif rows > cols:
            # Portrait: width becomes `size`, height keeps the ratio.
            target_h, target_w = size * rows / cols, size
        else:
            # Landscape or square: height becomes `size`.
            target_h, target_w = size, size * cols / rows
        target = (int(target_h), int(target_w))

        # Image: resize's default float conversion ([0, 255] -> [0, 1] for
        # uint8). Label: nearest-neighbour, original value range preserved.
        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0, preserve_range=True)

        return {'image': img, 'label': lbl}
-
class CenterCrop(object):
    """Cut the central window of size ``output_size`` out of image and label.

    Args:
        output_size (int or tuple): crop size; an int means a square crop,
            a tuple must be ``(new_h, new_w)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, tuple):
            assert len(output_size) == 2
            size = output_size
        else:
            size = (output_size, output_size)
        self.output_size = size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        crop_h, crop_w = self.output_size
        src_h, src_w = image.shape[:2]

        # The crop must fit inside the input.
        assert src_h >= crop_h and src_w >= crop_w

        # When the slack is odd, the extra pixel ends up on the bottom/right.
        top = (src_h - crop_h) // 2
        left = (src_w - crop_w) // 2
        window = (slice(top, top + crop_h), slice(left, left + crop_w))

        return {'image': image[window], 'label': label[window]}
-
class RandomCrop(object):
    """Crop image and label together at a uniformly random position.

    Args:
        output_size (int or tuple): crop size; an int means a square crop,
            a tuple must be ``(new_h, new_w)``.

    Raises:
        ValueError: from ``__call__`` when the crop is larger than the input.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        if h < new_h or w < new_w:
            raise ValueError(
                "RandomCrop: crop size (%d, %d) exceeds input size (%d, %d)"
                % (new_h, new_w, h, w))

        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable and allows a crop equal to the full input size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'image': image, 'label': label}
-
-
- #csx
class ToTensor(object):
    """Convert the ndarrays of a sample to (C, H, W) torch tensors.

    No mean subtraction or std division is applied. Instead, each array is
    divided by its own maximum, so non-negative input ends up in [0, 1].
    An array whose maximum is below ``eps`` (e.g. an all-black image or an
    empty mask) is left unscaled rather than divided by ~0, which would give
    NaN/inf. It is still promoted to the same floating dtype the scaled
    branch would produce.

    Expects ``sample['image']`` and ``sample['label']`` to be 3-D (H, W, C)
    arrays; both are returned transposed to (C, H, W).
    """

    # Maximum values below this are treated as "all zero" and not rescaled.
    eps = 1e-6

    @staticmethod
    def _scale_by_max(arr, eps):
        """Return ``arr / max(arr)``, or ``arr`` as float when ``max(arr) < eps``."""
        peak = np.max(arr)
        if peak < eps:
            # Dividing by 1 keeps the values but yields the same float dtype
            # as ``arr / peak`` (integer input -> float64).
            return arr / 1
        return arr / peak

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        image = self._scale_by_max(image, self.eps)
        label = self._scale_by_max(label, self.eps)

        # HWC -> CHW.
        return {'image': torch.from_numpy(image.transpose((2, 0, 1))),
                'label': torch.from_numpy(label.transpose((2, 0, 1)))}
class BuildingDataset(Dataset):
    """Image/mask dataset backed by parallel lists of file paths.

    Each item is a dict ``{'image': ndarray, 'label': ndarray}``. Files are
    read with ``io.imread(..., plugin='pil')``. The label is reduced to a
    single plane and then given a trailing channel axis alongside the image
    so both are (H, W, C) when possible. If ``transform`` is set (truthy), it
    is applied to that dict before it is returned.

    Args:
        img_name_list: paths of the input images.
        lbl_name_list: paths of the label images; may be empty, in which case
            an all-zero label shaped like the image is used.
        transform: optional callable taking and returning a sample dict.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx], plugin='pil')

        # Without label files, fall back to an all-zero mask of the image's shape.
        if len(self.label_name_list) == 0:
            raw_label = np.zeros(image.shape)
        else:
            raw_label = io.imread(self.label_name_list[idx], plugin='pil')

        # Reduce the label to one 2-D plane: the first channel of a
        # multi-channel mask, or the mask itself when already 2-D.
        if raw_label.ndim == 3:
            label = raw_label[:, :, 0]
        elif raw_label.ndim == 2:
            label = raw_label
        else:
            label = np.zeros(raw_label.shape[0:2])

        # Add a trailing channel axis so downstream transforms can treat
        # both arrays as (H, W, C).
        if label.ndim == 2:
            if image.ndim == 3:
                label = label[:, :, np.newaxis]
            elif image.ndim == 2:
                image = image[:, :, np.newaxis]
                label = label[:, :, np.newaxis]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
|