|
- # -*- coding: utf-8 -*-
- """
- @author: huangxs
- @License: (C)Copyright 2021, huangxs
- @CreateTime: 2021/11/20 17:09:05
- @Filename: dataset
- image / segmentation-target dataset loading for MindSpore (MoNuSegGenerator, create_dataset)
- """
-
- import numpy as np
- import random
- import glob
- import os
- from PIL import Image
- import mindspore.dataset as ds
- import mindspore.dataset.vision.py_transforms as py_vision
-
-
class MoNuSegGenerator:
    """Random-access source of (image, target) pairs matched by file name.

    Implements ``__getitem__`` and ``__len__`` so it can be wrapped by
    ``mindspore.dataset.GeneratorDataset``.

    Args:
        image_dir: Directory searched (non-recursively) for files ending in
            ``image_ext``.
        target_dir: Directory holding the targets; the target for
            ``<stem><image_ext>`` is ``<stem><target_ext>`` in this directory.
            Images with no matching target file are skipped.
        transform_list: Stored as ``self.transform_list``; not applied by this
            class.
        shuffle: If True, shuffle the pairs once at construction time;
            otherwise keep them sorted by image path.
        image_ext: Image file extension to collect. Defaults to ``'.jpeg'``.
        target_ext: Target file extension. Defaults to ``'.png'``.
    """

    def __init__(self, image_dir, target_dir, transform_list=None, shuffle=True,
                 image_ext='.jpeg', target_ext='.png'):
        self.transform_list = transform_list
        self.image_list = []
        self.target_list = []
        # Escape the directory so names containing glob metacharacters
        # ('[', '*', '?') are matched literally; sort because glob returns
        # files in filesystem-dependent order.
        pattern = os.path.join(glob.escape(image_dir), '*' + glob.escape(image_ext))
        image_paths = sorted(glob.glob(pattern))
        if shuffle:
            random.shuffle(image_paths)
        for image_path in image_paths:
            file_name = os.path.basename(image_path)
            # Remove only the trailing extension (glob guarantees it is there).
            stem = file_name[:len(file_name) - len(image_ext)]
            target_path = os.path.join(target_dir, stem + target_ext)
            if os.path.exists(target_path):
                self.image_list.append(image_path)
                self.target_list.append(target_path)

    def __getitem__(self, index):
        """Return ``[image, target]`` for the pair at ``index``.

        Both files are decoded as RGB, turned into numpy arrays and passed
        through ``py_vision.ToTensor()`` (per MindSpore docs this yields a CHW
        float array scaled to [0, 1] -- confirm for the installed version).
        """
        to_tensor = py_vision.ToTensor()
        # Context managers close the underlying file handles deterministically.
        with Image.open(self.image_list[index]) as image_file:
            image = np.array(image_file.convert('RGB'))
        # NOTE(review): the target is also decoded as RGB and scaled by
        # ToTensor like an image; confirm this is intended for mask targets.
        with Image.open(self.target_list[index]) as target_file:
            target = np.array(target_file.convert('RGB'))
        return [to_tensor(image), to_tensor(target)]

    def __len__(self):
        """Return the number of matched (image, target) pairs."""
        return len(self.image_list)
-
-
def create_dataset(image_dir, target_dir, column_names, batch_size, shuffle=True, transform_list=None,
                   num_parallel_workers=4):
    """Build a batched MindSpore dataset of (image, target) pairs.

    Args:
        image_dir: Directory containing the input images.
        target_dir: Directory containing the matching target images.
        column_names: Names for the two columns produced by the generator,
            e.g. ``["input", "target"]``.
        batch_size: Number of rows per batch.
        shuffle: Whether to shuffle the data. Forwarded both to the generator
            (one-time file-list shuffle) and to ``GeneratorDataset``.
        transform_list: Optional operations mapped over all ``column_names``
            before batching; skipped when falsy.
        num_parallel_workers: Worker count for ``GeneratorDataset``
            (defaults to 4, the previously hard-coded value).

    Returns:
        The batched dataset.
    """
    monuseg = MoNuSegGenerator(image_dir=image_dir, target_dir=target_dir, shuffle=shuffle)
    # GeneratorDataset shuffles random-access sources by default
    # (shuffle=None), so the flag must be passed explicitly for
    # shuffle=False to be honored.
    dataset = ds.GeneratorDataset(monuseg, column_names=column_names,
                                  num_parallel_workers=num_parallel_workers, shuffle=shuffle)
    if transform_list:
        # The operations receive every column in column_names together;
        # presumably they must return the same number of columns -- confirm
        # against the MindSpore `map` documentation.
        dataset = dataset.map(operations=transform_list, input_columns=column_names)
    return dataset.batch(batch_size)
-
-
if __name__ == '__main__':
    # Smoke run: build a small batched dataset and print each batch's shapes.
    # Paths are built from separate components so they work on any OS.
    # NOTE(review): MoNuSegGenerator only collects '*.jpeg' images by default;
    # confirm that matches the image files in this directory.
    img_dir = os.path.join("ade20k", "images", "training")
    target_dir = os.path.join("ade20k", "annotations", "training")
    train_set = create_dataset(image_dir=img_dir, target_dir=target_dir, column_names=["input", "target"],
                               batch_size=2, transform_list=None, shuffle=True)

    for image, target in train_set:
        print(image.shape, target.shape)
|