|
- import os
-
- import pandas as pd
- import numpy as np
- import cv2
- import mindspore as ms
- import mindspore.dataset as dataset
- from mindspore import dtype
- from mindspore.dataset.transforms import c_transforms as c_tr
- import mindspore.dataset.vision.c_transforms as tr
- from mindspore.communication.management import get_rank, get_group_size
- from mindspore.dataset.vision import Inter
-
- import config as cfg
- from config import DATASET_PATH
-
-
def str2onehot(s):
    """Encode a sequence of digit strings as a multi-hot float32 vector.

    One slot marks how many labels are present (index ``len(s) - 1``) and
    one slot per label marks its value (index ``int(label) + 3``).

    # NOTE(review): the +3 offset presumably reserves slots 0..2 for the
    # label count — confirm against cfg.output_num's intended layout.
    """
    hot = np.zeros(cfg.output_num, dtype=np.float32)
    hot_slots = [len(s) - 1] + [int(digit) + 3 for digit in s]
    hot[hot_slots] = 1
    return hot
-
-
def data_preprocess():
    """Embed class indices into every image file name, one split at a time.

    For each split (train/test/val) the split's label CSV is read; an image
    ``foo.jpg`` whose label cell is ``"a b"`` is renamed on disk to
    ``foo_<idx_a>_<idx_b>.jpg``, where each index comes from
    ``cfg.class_name``.

    Run this exactly once: after renaming, the original file names listed in
    the CSVs no longer exist, so a second run raises FileNotFoundError.
    """
    for split in ('train', 'test', 'val'):
        rows = pd.read_csv(
            DATASET_PATH + '/' + split + '/' + split + '_label.csv').values
        image_dir = DATASET_PATH + '/' + split + '/images/'
        for row in rows:
            # row[0] is the file name, row[1] the space-separated class names.
            stem = row[0].split('.')[0]
            for class_name in row[1].split(' '):
                stem += '_%d' % cfg.class_name.index(class_name)
            os.rename(image_dir + row[0], image_dir + stem + '.jpg')
-
-
class PlantDatasetIter:
    """Random-access source for GeneratorDataset over one dataset split.

    Each item is ``(image, onehot_label)`` where the label indices are parsed
    from the file name (everything after the first ``_`` in the stem — the
    naming produced by ``data_preprocess``).
    """

    def __init__(self, mode):
        # Fail fast on a bad mode; previously __getitem__/__len__ silently
        # returned None and callers crashed later with an opaque TypeError.
        if mode not in ('train', 'test', 'val'):
            raise ValueError(
                "mode must be 'train', 'test' or 'val', got %r" % (mode,))
        self.mode = mode
        # All three listings are kept (as before) so the attributes stay
        # available regardless of mode.
        self.train_list = os.listdir(DATASET_PATH + '/train/images/')
        self.test_list = os.listdir(DATASET_PATH + '/test/images/')
        self.val_list = os.listdir(DATASET_PATH + '/val/images/')

    def _files(self):
        # Single dispatch point replacing the triplicated if/elif chains.
        return {'train': self.train_list,
                'test': self.test_list,
                'val': self.val_list}[self.mode]

    def __getitem__(self, index):
        fname = self._files()[index]
        # File stem looks like '<name>_<idx>_<idx>...'; drop the name part.
        label = fname.split('.')[0].split('_')[1:]
        image = np.array(
            cv2.imread(DATASET_PATH + '/' + self.mode + '/images/' + fname))
        return image, str2onehot(label)

    def __len__(self):
        return len(self._files())
-
-
def create_plant_dataset(mode='train', rank_size=1):
    """Build a batched MindSpore dataset pipeline for one split.

    Args:
        mode: 'train' (sharded across devices, shuffled, augmented) or
            'test'/'val' (single shard, unshuffled, no augmentation).
        rank_size: number of devices; the train batch size is divided by it.

    Returns:
        The batched ``mindspore.dataset`` pipeline.

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'test', 'val'.
    """
    if mode == 'train':
        ds = dataset.GeneratorDataset(PlantDatasetIter(mode), ['image', 'label'],
                                      shuffle=True, num_shards=get_group_size(),
                                      shard_id=get_rank(), num_parallel_workers=4)
    elif mode in ('test', 'val'):
        ds = dataset.GeneratorDataset(PlantDatasetIter(mode), ['image', 'label'],
                                      shuffle=False, num_parallel_workers=4)
    else:
        # Was a bare `raise Exception`; ValueError is a subclass of Exception,
        # so any existing `except Exception` caller still works.
        raise ValueError("mode must be 'train', 'test' or 'val', got %r" % (mode,))

    # Per sample, apply 2 randomly chosen ops out of these 3 (train only).
    aug_tr = tr.UniformAugment(
        [tr.RandomHorizontalFlip(0.6),
         tr.RandomRotation(15, Inter.BILINEAR),
         tr.RandomColorAdjust(0.1, 0.1, 0.1, 0.1)],
        2)

    # Shared tail of both pipelines: cast, resize, scale to [0,1], CHW layout.
    common_ops = [
        c_tr.TypeCast(dtype.float32),
        tr.Resize([cfg.new_height, cfg.new_weight]),
        tr.Rescale(1 / 255, 0),
        tr.HWC2CHW(),
    ]

    if mode == 'train':
        ds = ds.map(operations=[aug_tr] + common_ops, input_columns='image')
        # NOTE(review): sharding above uses get_group_size() while the batch
        # size is divided by the rank_size argument — confirm callers always
        # pass rank_size == get_group_size().
        ds = ds.batch(int(cfg.train_batch_size / rank_size), drop_remainder=True)
    else:
        ds = ds.map(operations=common_ops, input_columns='image')
        ds = ds.batch(cfg.test_batch_size, drop_remainder=False)
    return ds
-
-
if __name__ == '__main__':
    # Smoke test: build the train pipeline and print each one-hot label batch.
    plant_ds = create_plant_dataset('train')
    for _image_batch, label_batch in plant_ds:
        print(label_batch)
|