|
# -*- coding: utf-8 -*-
- '''
- # @time:2022/10/30 15:11
- # Author:DFTL
- # @File:Dataset_df.py
- '''
- import json
- import os
- import torch
- from torch.utils.data import Dataset,DataLoader
- import numpy as np
- from tqdm import tqdm
-
class MyDataset(Dataset):
    """Image/label dataset backed by two ``.npy`` files.

    Each item is ``(image, label)``:

    * ``image`` is the sample's 3-D array with its last axis moved to the
      front (``np.transpose(..., (2, 0, 1))``), returned as a float32 tensor.
    * ``label`` is ``torch.tensor`` of the stored label value.

    Args:
        image_data: Path to an ``.npy`` file holding a 4-D image array,
            indexed by sample along the first axis. Presumably laid out as
            (N, H, W, C) -- TODO confirm against the data-preparation script.
        label_path: Path to an ``.npy`` file with one label per image.

    Raises:
        ValueError: If the image and label arrays do not contain the same
            number of samples.
    """

    def __init__(self, image_data, label_path):
        # Both arrays are loaded fully into memory (np.load without mmap).
        self.image_data = np.load(image_data)
        self.label_data = np.load(label_path)
        # Fail fast.
        # Otherwise a mismatch surfaces as an IndexError at a random point
        # mid-epoch (fewer labels), or extra labels are silently ignored.
        if len(self.image_data) != len(self.label_data):
            raise ValueError(
                f"image/label count mismatch: {len(self.image_data)} images "
                f"vs {len(self.label_data)} labels"
            )

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*."""
        image = self.image_data[index, :, :, :]
        # Move the last axis to the front (presumably HWC -> CHW).
        image = np.transpose(image, (2, 0, 1))

        label = self.label_data[index]

        return torch.tensor(image).float(), torch.tensor(label)

    def __len__(self):
        """Return the number of samples (first axis of the image array)."""
        return self.image_data.shape[0]
-
class MyDataset_v2(Dataset):
    """Image/label dataset that also returns a knowledge "event" per sample.

    Each item is ``(image, label, event)``:

    * ``image`` is the sample's 3-D array with its last axis moved to the
      front (``np.transpose(..., (2, 0, 1))``), returned as a float32 tensor.
    * ``label`` is ``torch.tensor`` of the stored label value.
    * ``event`` is built from the first triple stored for that label in the
      JSON knowledge file.

    Args:
        image_path: Path to an ``.npy`` file holding a 4-D image array,
            indexed by sample along the first axis. Presumably (N, H, W, C)
            -- TODO confirm.
        label_path: Path to an ``.npy`` file with one numeric label per image.
        triples_path: Path to a UTF-8 JSON object whose keys are integer
            labels written as strings (e.g. ``"3"``). Each value is a list
            of triples, and at least the first triple must have three
            elements.

    Raises:
        ValueError: If the image and label arrays do not contain the same
            number of samples.
    """

    def __init__(self, image_path, label_path, triples_path):
        self.image_data = np.load(image_path)
        self.label_data = np.load(label_path)
        # Fail fast.
        # Otherwise a mismatch surfaces as an IndexError at a random point
        # mid-epoch (fewer labels), or extra labels are silently ignored.
        if len(self.image_data) != len(self.label_data):
            raise ValueError(
                f"image/label count mismatch: {len(self.image_data)} images "
                f"vs {len(self.label_data)} labels"
            )
        with open(triples_path, 'r', encoding='utf-8') as f:
            self.knowledge_dict = json.load(f)

    def __getitem__(self, item):
        """Return ``(image, label, event)`` for sample *item*.

        Raises:
            KeyError: If the sample's label has no entry in the knowledge file.
        """
        image = self.image_data[item, :, :, :]
        # Move the last axis to the front (presumably HWC -> CHW).
        image = np.transpose(image, (2, 0, 1))

        label = self.label_data[item]

        # JSON object keys are always strings, so look up by the label cast
        # to int and then to str (astype(int) turns e.g. 3.0 into 3).
        triples = self.knowledge_dict[str(label.astype(int))]
        # Only the first triple is used.
        # Its three parts are combined with "+", which concatenates them
        # without a separator when they are strings.
        event = triples[0][0] + triples[0][1] + triples[0][2]

        return torch.tensor(image).float(), torch.tensor(label), event

    def __len__(self):
        """Return the number of samples (first axis of the image array)."""
        return self.image_data.shape[0]
-
- # class Pixel_dataset(Dataset):
- # def __init__(self,image_path,label_path):
- # self.image_data = np.load(image_path) # 加载npy数据
- # self.label_data = np.load(label_path)
- #
- # def __getitem__(self, index):
- # pixel = self.image_data[index]
- # label = self.label_data[index]
- #
- # return torch.tensor(pixel), torch.tensor(label)
- #
- # def __len__(self):
- # return len(self.image_data)
-
if __name__ == '__main__':
    # Smoke test: iterate once over the YRD_N12 training split and print
    # every label batch. The root path is machine-specific; adjust locally.

    # Raw string: backslashes are taken literally, so they must not be doubled.
    root = r"C:\Code\WetLand_Code"
    # Pass path components separately so os.path.join picks the native
    # separator instead of mixing "/" with "\".
    image_path = os.path.join(root, "DataSet", "YRD_N12", "X_train.npy")
    label_path = os.path.join(root, "DataSet", "YRD_N12", "Y_train.npy")

    mydataset = MyDataset(image_path, label_path)
    data_loader = DataLoader(dataset=mydataset, batch_size=1, shuffle=True, pin_memory=True)

    for _images, labels in tqdm(data_loader):
        print(labels)
|