|
import json
import os

from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import Tensor
-
# Recognized image filename suffixes. Both lower- and upper-case variants are
# listed because the match in is_image_file() is case-sensitive.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

# Maps the textual category read from the category JSON file to an integer
# class label (wrapped in a one-element list).  'invisible' and
# 'bubble_invisible' are deliberately merged into the same class id 2.
category_map = {
    'clearness': [0],
    'blur': [1],
    'invisible': [2],
    'bubble_invisible': [2],
}
-
-
def is_image_file(filename):
    """Return True if *filename* ends with one of the known image suffixes."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
-
-
def img_loader(path, num_channels):
    """Open the image at *path*.

    A single-channel request returns the image as stored; any other channel
    count forces a conversion to RGB.
    """
    img = Image.open(path)
    if num_channels != 1:
        img = img.convert('RGB')
    return img
-
-
- # get the image list pairs
# get the image list pairs
def get_imgs_list(dir_list, post_fix=None):
    """Build the list of matched image tuples across several directories.

    :param dir_list: [img_dir, label_dir, ...]; the first directory holds the
        input images, the remaining ones hold companion images (labels,
        weights, ...).
    :param post_fix: suffix appended to the input image's base name to locate
        the matching file in each extra directory,
        e.g. ['label.png', 'weight.png'].
    :return: e.g. [(img1.ext, img1_label.png, img1_weight.png), ...]; a sample
        is kept only when a companion exists in every directory.
    :raises RuntimeError: if len(dir_list) != len(post_fix) + 1.
    """
    img_list = []
    if not dir_list:
        return img_list
    if len(dir_list) != len(post_fix) + 1:
        raise RuntimeError('Should specify the postfix of each img type except the first input.')

    img_filename_list = [os.listdir(d) for d in dir_list]

    for img in img_filename_list[0]:
        # skip non-images and hidden files such as .DS_Store
        if (not is_image_file(img)) or img.startswith('.'):
            continue
        base_name = os.path.splitext(img)[0]
        item = [os.path.join(dir_list[0], img)]
        for i in range(1, len(img_filename_list)):
            # BUGFIX: build each companion name from the original base name.
            # The previous code kept appending to the same accumulator, so
            # with more than one postfix the second lookup became
            # '<base><postfix0><postfix1>' instead of '<base><postfix1>'.
            companion = '{:s}{:s}'.format(base_name, post_fix[i - 1])
            if companion in img_filename_list[i]:
                item.append(os.path.join(dir_list[i], companion))

        # keep the sample only if every directory produced a match
        if len(item) == len(dir_list):
            img_list.append(tuple(item))

    return img_list
-
-
- # dataset that supports one input image, one target image, and one weight map (optional)
# dataset that supports one input image, one target image, and one weight map (optional)
class DataFolder:
    """Dataset yielding [input image, target image(, weight map), category].

    Images are matched across directories via get_imgs_list(); each sample's
    textual category is read from a JSON file and mapped to an integer class
    id through category_map.
    """

    def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader,
                 category="endoscope400/category.json"):
        """
        :param dir_list: directories; the first holds inputs, the rest hold
            companion images (labels, weights, ...).
        :param post_fix: filename suffixes for the companion directories.
        :param num_channels: channel count per directory (1 keeps the image
            as stored, otherwise it is converted to RGB by the loader).
        :param data_transform: optional callable applied to the list of
            loaded images.
        :param loader: callable(path, num_channels) -> image.
        :param category: path of the JSON file mapping image base names to
            {'category': <name>}.
        :raises RuntimeError: on inconsistent argument lengths or when no
            image pair is found.
        """
        if len(dir_list) != len(post_fix) + 1:
            raise RuntimeError('Length of dir_list is different from length of post_fix + 1.')
        if len(dir_list) != len(num_channels):
            raise RuntimeError('Length of dir_list is different from length of num_channels.')

        self.img_list = get_imgs_list(dir_list, post_fix)
        if not self.img_list:
            raise RuntimeError('Found 0 image pairs in given directories.')

        self.data_transform = data_transform
        self.num_channels = num_channels
        self.loader = loader
        # json is now imported at module level instead of inside this block
        with open(category, "rb") as f:
            self.categorylabel = json.load(f)

    def __getitem__(self, index):
        """Load, transform, and return sample *index* plus its category tensor."""
        img_paths = self.img_list[index]
        # the category JSON is keyed by the input image's base name (no extension)
        img_name = os.path.splitext(os.path.basename(img_paths[0]))[0]
        sample = [self.loader(path, channels)
                  for path, channels in zip(img_paths, self.num_channels)]

        if self.data_transform is not None:
            sample = self.data_transform(sample)

        sample = list(sample)
        sample.append(Tensor(category_map[self.categorylabel[img_name]['category']]))

        return sample

    def __len__(self):
        """Number of matched image tuples."""
        return len(self.img_list)
-
-
def create_dataset(dir_list, post_fix, num_channels, data_transforms, column_names, batch_size, shuffle=False):
    """Wrap a DataFolder in a batched MindSpore GeneratorDataset."""
    source = DataFolder(dir_list=dir_list, post_fix=post_fix, num_channels=num_channels,
                        data_transform=data_transforms)
    generator = ds.GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    return generator.batch(batch_size)
-
-
if __name__ == "__main__":
    # Smoke test: build the dataset over the sample layout, fetch one sample,
    # then iterate batches printing the input shape.
    # (Removed the unused `import numpy as np`.)
    from my_transforms import get_transforms

    dir_list = ["endoscope400/ade20k/images/train", "endoscope400/ade20k/annotations/train"]
    post_fix = [".png"]
    num_channels = [3, 1]
    data_transforms = get_transforms({
        'scale': 240,
        'horizontal_flip': True,
        'random_affine': 0.3,
        'random_rotation': 90,
        'random_crop': 240,
        'to_tensor': 1
    })
    # use indexing instead of calling __getitem__ explicitly
    image, target, category = DataFolder(dir_list, post_fix, num_channels, data_transforms)[0]
    train_set = create_dataset(dir_list=dir_list, post_fix=post_fix, num_channels=num_channels,
                               data_transforms=data_transforms,
                               column_names=["input", "target", "category"], batch_size=10)
    for data in train_set:
        image, target, category = data
        print(image.shape)
|