import glob

import h5py
import numpy as np
import tensorflow as tf

from networkTool import trainDataRoot, levelNumK

# Training files are matched by these substrings (dataset names or the .mat
# extension), not by strict file extensions.
IMG_EXTENSIONS = [
    'MPEG',
    'MVUB',
    'PCL',
    'mat'
]

def is_image_file(filename):
    # Substring match: True if the filename contains any marker in IMG_EXTENSIONS.
    return any(extension in filename for extension in IMG_EXTENSIONS)

def default_loader(path):
    # Load a MATLAB v7.3 .mat file (HDF5 format). 'patchFile' is a cell array;
    # each cell holds the octree node data of one point cloud.
    mat = h5py.File(path, 'r')
    cell = mat['patchFile']
    return cell, mat
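
# 'dataLenPerFile' (used by TFdataset below) is the average number of tree
# points per .mat file and must be recomputed for every new dataset. A minimal
# sketch of that computation, assuming the same 'patchFile' layout that
# parse_file reads; computeDataLenPerFile is an illustrative name, not part of
# the original code:
def computeDataLenPerFile(root):
    totalPoints = 0
    fileNum = 0
    for filename in sorted(glob.glob(root)):
        if not is_image_file(filename):
            continue
        cell, mat = default_loader(filename)
        for i in range(cell.shape[1]):
            # h5py exposes MATLAB arrays transposed, so the point count is the
            # last axis of the raw dataset (assumption based on parse_file).
            totalPoints += mat[cell[0, i]].shape[-1]
        fileNum += 1
    return totalPoints / fileNum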

class TFdataset():
    def __init__(self, root, epochs, TreePoint, batch_size, dataLenPerFile=391583.14392244595, transform=None, loader=default_loader):
        dataNames = []
        for filename in sorted(glob.glob(root)):
            if is_image_file(filename):
                dataNames.append('{}'.format(filename))
        self.root = root
        self.dataNames = sorted(dataNames)
        self.transform = transform
        self.loader = loader
        self.index = 0      # read position inside the current buffer
        self.datalen = 0    # number of tree points in the current buffer
        self.dataBuffer = []
        self.fileIndx = 0   # index of the next .mat file to load
        self.TreePoint = TreePoint
        self.fileLen = len(self.dataNames)
        assert self.fileLen > 0, 'no file found!'
        self.dataLenPerFile = dataLenPerFile  # average tree points per file; recompute for every new dataset (see computeDataLenPerFile above)
        self.dataset = tf.data.Dataset.from_tensor_slices(np.arange(self.fileLen))
        self.dataset = self.dataset.map(lambda x: tf.py_function(self.parse_file, [x], [tf.int32]))
        # self.dataset = self.dataset.shuffle(buffer_size=self.fileLen)
        self.dataset = self.dataset.batch(batch_size, drop_remainder=True)  # drop the final incomplete batch
        self.dataset = self.dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        self.dataset = self.dataset.repeat(epochs)  # choose epochs according to the dataset size and batch size
        self.iterator = iter(self.dataset)

    def parse_file(self, index):
        # Refill the buffer until it holds at least TreePoint nodes.
        # One .mat file yields several dataset indices (e.g. 18 per file).
        while self.index + self.TreePoint > self.datalen:
            # Files are consumed in our own sequential order; tf.data is only
            # borrowed as a batching/prefetching wrapper around this loop.
            filename = self.dataNames[self.fileIndx]
            print(filename)
            if self.dataBuffer:
                a = [self.dataBuffer[0][self.index:].copy()]  # keep the unread tail
            else:
                a = []
            cell, mat = self.loader(filename)
            for i in range(cell.shape[1]):
                # data shape: [ptNum, Kparent, 6], where the last axis holds
                # Seq[1], Level[1], Octant[1], Pos[3], e.g. 123456*7*6
                data = np.transpose(mat[cell[0, i]])
                data[:, :, 0] = data[:, :, 0] - 1  # shift the 1-based MATLAB occupancy codes to 0-based
                a.append(data[:, -levelNumK:, :])  # keep only the last levelNumK ancestor levels

            self.dataBuffer = []
            self.dataBuffer.append(np.vstack(tuple(a)))

            self.datalen = self.dataBuffer[0].shape[0]
            self.fileIndx += 1  # step = 1, so contiguous .mat files are loaded
            self.index = 0
            if self.fileIndx >= self.fileLen:
                self.fileIndx = int(index) % self.fileLen  # wrap around; int() unwraps the eager tensor passed in by tf.py_function

        # Cut one TreePoint-long slice from the buffer.
        img = []
        img.append(self.dataBuffer[0][self.index:self.index + self.TreePoint])

        self.index += self.TreePoint

        if self.transform is not None:
            img = self.transform(img)
        return tf.convert_to_tensor(img)

    def get_next(self):
        # Return None when the iterator is exhausted instead of raising StopIteration.
        return next(self.iterator, None)
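
# Shape sketch (inferred from parse_file, assuming levelNumK = 4 as in the
# smoke test below): parse_file yields one tensor of shape
# (1, TreePoint, levelNumK, 6), so get_next() returns a 1-tuple whose element
# has shape (batch_size, 1, TreePoint, levelNumK, 6).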

if __name__ == "__main__":

    TreePoint = 1024 * 16  # number of consecutive occupancy codes per sample; keep TreePoint*batch_size divisible by the downstream batch and sequence sizes
    batchSize = 128
    train_set = TFdataset(root='../Data/Obj/train1/*.mat', epochs=10, TreePoint=TreePoint, batch_size=batchSize, dataLenPerFile=391583.14392244595)  # train_test dataLenPerFile=401278.2
    try:
        while True:
            data = train_set.get_next()
            if data:
                print(len(data))
                print(data[0].shape)
                # Regroup the batch into longer sequences with a smaller batch size.
                batchSize = 32
                train_data = tf.transpose(tf.reshape(data[0], (batchSize, -1, 4, 6)), perm=[1, 0, 2, 3])
                bptt = 1024
                print(train_data.shape[0])
            else:
                print("None")
                break
    except StopIteration:
        print("end!")