|
import enum
import glob
import os

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
-
-
# Human-readable CIFAR-10 class names, index-aligned with the integer labels.
label_name = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
-
-
# Augmentation + normalization pipeline applied to every training image.
_train_steps = [
    transforms.RandomHorizontalFlip(),  # random left-right mirror
    transforms.RandomVerticalFlip(),    # random top-bottom mirror
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    # Per-channel (R, G, B) mean / std normalization.
    transforms.Normalize(mean=(0.49, 0.48, 0.44),
                         std=(0.21, 0.18, 0.22)),
]
train_transform = transforms.Compose(_train_steps)
-
-
# Deterministic preprocessing for evaluation images (no random augmentation).
test_transform = transforms.Compose([
    # Crop the central 32x32 region of the image.
    transforms.CenterCrop(size=(32, 32)),
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    # Same per-channel normalization statistics as the training pipeline.
    transforms.Normalize(mean=(0.49, 0.48, 0.44),
                         std=(0.21, 0.18, 0.22)),
])
-
-
# Map class name -> integer label index (the inverse of label_name).
# A dict comprehension replaces the original mutate-in-a-loop construction.
label_dict = {name: idx for idx, name in enumerate(label_name)}
-
-
def default_loader(path):
    """Load the image at *path* and return it as an RGB PIL image.

    ``Image.open`` is lazy and keeps the underlying file handle open;
    the original one-liner never closed it, leaking a descriptor per
    image. Opening inside ``with`` closes the file promptly, and
    ``convert("RGB")`` returns a new in-memory image, so it remains
    valid after the source file is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")
-
-
# Wraps a list of image file paths as a torch Dataset so a DataLoader can
# batch/shuffle them efficiently.
class MyDataset(Dataset):
    """Dataset whose samples are image files labelled by parent directory.

    Each entry of ``im_list`` is a path like ``data/TRAIN/airplane/xxx.png``
    (``/`` or ``\\`` separated); the second-to-last path component is looked
    up in ``label_dict`` to obtain the integer class label.
    """

    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # Normalize Windows backslashes so the class name can be
            # extracted with either separator. The original
            # ``split("\\")[-2]`` only handled backslashes and raised
            # IndexError for POSIX-style paths.
            im_label_name = im_item.replace("\\", "/").split("/")[-2]
            # Store (path, integer label); the image itself is loaded lazily.
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs            # list of [path, label] pairs
        self.transform = transform  # optional preprocessing pipeline
        self.loader = loader        # callable: path -> PIL image

    def __getitem__(self, index):
        """Load, preprocess and return ``(image, label)`` for ``index``.

        ``index`` ranges over ``0 .. len(self) - 1`` as reported by
        ``__len__``.
        """
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)  # apply image preprocessing
        return im_data, im_label

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.imgs)
-
# Collect every PNG under data/<SPLIT>/<class name>/ for each split.
# os.path.join keeps the glob patterns portable — the original hard-coded
# Windows "\\" separators, which match nothing on POSIX systems.
im_train_list = glob.glob(os.path.join("data", "TRAIN", "*", "*.png"))
im_test_list = glob.glob(os.path.join("data", "TEST", "*", "*.png"))


train_dataset = MyDataset(im_train_list,
                          transform=train_transform)
test_dataset = MyDataset(im_test_list,
                         transform=test_transform)

# Training loader: reshuffle every epoch for SGD.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=128,
                          shuffle=True)

# Test loader: fixed order for reproducible evaluation.
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=128,
                         shuffle=False)

print("训练集数据量:", len(train_dataset))
print("测试集数据量:", len(test_dataset))
-
-
-
-
-
-
-
-
-
-
|