|
- import pickle
def unpickle(file):
    """Load and return the object stored in a pickle file.

    ``encoding='bytes'`` makes 8-bit strings pickled by Python 2 come back as
    ``bytes`` objects rather than being decoded to ``str``. For the CIFAR-10
    batches this is why the resulting dict is indexed with byte keys such as
    ``b'data'``.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.

    Warning:
        ``pickle.load`` can execute arbitrary code while loading. Only call
        this on trusted files, e.g. the official CIFAR-10 archive.
    """
    with open(file, 'rb') as fo:
        # Return directly instead of binding to ``dict``, which shadowed the builtin.
        return pickle.load(fo, encoding='bytes')
-
# Class names of the CIFAR-10 categories, in label order: the integer label
# stored with each image is used directly as an index into this list
# (0 -> "airplane", ..., 9 -> "truck").
label_name = "airplane automobile bird cat deer dog frog horse ship truck".split()
-
- import glob
- import numpy as np
- # pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple/
- import cv2
- import os
-
def save_image(filenames, save_path):
    """Decode CIFAR-10 python-format batches and write every image to disk.

    Each image is written to ``<save_path>/<class name>/<file name>``. The
    class name is ``label_name[label]``, and the file name is the entry
    stored in the batch's ``b'filenames'`` list.

    Args:
        filenames: iterable of batch-file paths, each loaded with ``unpickle``.
        save_path: output root directory. It and the per-class
            sub-directories are created when missing.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    created_dirs = set()  # class directories already ensured to exist
    for batch_file in filenames:
        batch = unpickle(batch_file)  # each batch holds 10000 images
        # Each row stores one 32x32 image as three consecutive colour planes.
        # Reshape the whole batch to (N, C, H, W), then move channels last
        # (N, H, W, C), which is the layout OpenCV works with.
        images = np.reshape(batch[b'data'], (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
        # CIFAR-10 stores planes in R, G, B order, but cv2.imwrite interprets
        # arrays as B, G, R. Reverse the channel axis so colours are not
        # swapped in the saved files. Copy to a contiguous buffer because
        # OpenCV may reject negative-stride views.
        images = np.ascontiguousarray(images[..., ::-1])

        for image, label, name in zip(images, batch[b'labels'], batch[b'filenames']):
            # os.path.join instead of hard-coded "\\" so this works off Windows.
            class_dir = os.path.join(save_path, label_name[label])
            if class_dir not in created_dirs:
                # makedirs also creates save_path itself if needed, and
                # exist_ok avoids the exists()/mkdir() race.
                os.makedirs(class_dir, exist_ok=True)
                created_dirs.add(class_dir)

            out_file = os.path.join(class_dir, name.decode("utf-8"))
            # imwrite signals failure by returning False rather than raising.
            if not cv2.imwrite(out_file, image):
                raise OSError("cv2.imwrite failed to write {}".format(out_file))
-
-
if __name__ == '__main__':
    # Paths are relative to the working directory. They are built with
    # os.path.join rather than hard-coded "\\" so the script also works on
    # Linux/macOS. glob returns matches in arbitrary order, so results are
    # sorted to process batches reproducibly.

    # Training set: CIFAR10/data_batch_* -> data/TRAIN/<class>/
    train_filenames = sorted(glob.glob(os.path.join("CIFAR10", "data_batch_*")))
    if not train_filenames:
        raise SystemExit("No training batches matching CIFAR10/data_batch_* were found")
    train_save_path = os.path.join("data", "TRAIN")
    # Create the output root so per-class sub-directories can be made inside it.
    os.makedirs(train_save_path, exist_ok=True)
    save_image(train_filenames, train_save_path)
    print("训练数据集保存完毕!")

    # Test set: CIFAR10/test_batch* -> data/TEST/<class>/
    test_filenames = sorted(glob.glob(os.path.join("CIFAR10", "test_batch*")))
    if not test_filenames:
        raise SystemExit("No test batches matching CIFAR10/test_batch* were found")
    test_save_path = os.path.join("data", "TEST")
    os.makedirs(test_save_path, exist_ok=True)
    save_image(test_filenames, test_save_path)
    print("测试数据集保存完毕!")
-
-
-
-
-
|