|
-
- import os
- from PIL import Image
- import numpy as np
- import pickle
- import copy
- import sys
-
class AttDataset():
    """Person-attribute dataset interface backed by two pickle files.

    The *dataset* pickle is read as a mapping that provides the keys
    ``'image'`` (per-sample image file names), ``'att'`` (per-sample
    attribute vectors), ``'att_name'`` (names of all attributes) and
    ``'selected_attribute'`` (positions of the attributes to keep).
    The *partition* pickle is indexed as ``partition[split][partition_idx]``
    and must yield the sample indices belonging to that split.

    Indexing an instance returns ``(img, target)`` where ``img`` is the
    object returned by ``Image.open`` and ``target`` is a ``float32``
    NumPy array holding the selected attribute labels.
    """

    def __init__(
            self,
            dataset,
            partition,
            split='train',
            root='/home/work/user-job-dir/data/images',
            partition_idx=0,
            transform=None,
            target_transform=None,
            **kwargs):
        """Load the dataset description and materialise one partition.

        Args:
            dataset: Path of the pickle file describing the dataset.
            partition: Path of the pickle file describing the partitions.
            split: Key of the split to use inside the partition mapping.
            root: Directory that image file names are joined onto.
            partition_idx: Which entry of ``partition[split]`` to select.
            transform: Stored on the instance; not applied by
                ``__getitem__`` (images are returned untransformed).
            target_transform: Stored on the instance; not applied by
                ``__getitem__``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If either pickle file does not exist, or if
                ``partition_idx`` is larger than the last valid index of
                ``partition[split]``.
        """
        self.dataset = self._load_pickle(dataset)
        self.dataset['root'] = root
        self.partition = self._load_pickle(partition)

        # Only the upper bound is checked; a negative index is passed
        # through to normal Python sequence indexing.
        if partition_idx > len(self.partition[split]) - 1:
            raise ValueError('partition_idx is out of range in partition.')

        self.transform = transform
        self.target_transform = target_transform
        self.root = root

        # Build the image / label lists for the selected partition and split.
        self.root_path = self.dataset['root']
        # 'selected_attribute' lists the positions (within the full attribute
        # list) of the attributes that are kept.
        selected = self.dataset['selected_attribute']
        self.att_name = [self.dataset['att_name'][i] for i in selected]
        self.image = []
        self.label = []
        for idx in self.partition[split][partition_idx]:
            self.image.append(self.dataset['image'][idx])
            # Keep only the wanted attributes; each kept attribute
            # contributes one label value per sample.
            self.label.append(
                np.array(self.dataset['att'][idx])[selected].tolist())

    @staticmethod
    def _load_pickle(path):
        """Return the object unpickled from *path*; ValueError if missing."""
        if not os.path.exists(path):
            raise ValueError(path + ' does not exist in dataset.')
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at *index*.

        Args:
            index (int): Position in the selected partition.

        Returns:
            tuple: ``(img, target)`` where ``img`` is the result of
            ``Image.open`` on the file under ``self.root`` and ``target``
            is a ``float32`` array of the selected attribute labels with
            every value ``2`` replaced by ``0``.
        """
        imgname, target = self.image[index], self.label[index]
        img = Image.open(os.path.join(self.root, imgname))

        # No transform is applied here by design; callers receive the
        # image exactly as opened.
        target = np.array(target).astype(np.float32)
        # Label value 2 is folded into 0. The original author noted that
        # PETA contains no 2 labels, so this presumably serves other
        # datasets -- TODO confirm.
        target[target == 2] = 0

        return (img, target)

    # useless for personal batch sampler
    def __len__(self):
        """Return the number of samples in the selected partition."""
        return len(self.image)
|