|
- '''
- Author: fuchy@stu.pku.edu.cn
- Description: The encoder helper
- FilePath: /compression/encoderTool.py
- '''
-
- #%%
- import numpy as np
- # import torch
- import tensorflow as tf
- import time
- import os
- from networkTool import bptt,expName,levelNumK,MAX_OCTREE_LEVEL
- from dataset import default_loader as matloader
- import numpyAc #这个不移植了,编解码用到了再转换
- import tqdm
- bpttRepeatTime = 1
-
def generate_square_subsequent_mask(sz):
    """Build the additive causal attention mask for a sequence of length ``sz``.

    Args:
        sz: side length of the square mask.

    Returns:
        A ``[sz, sz]`` tensor holding ``-inf`` strictly above the main
        diagonal and ``0`` on and below it.
    """
    ones = tf.ones([sz, sz])
    upper_with_diag = tf.linalg.band_part(ones, 0, -1)
    diag_only = tf.linalg.band_part(ones, 0, 0)
    # Upper triangle minus the diagonal leaves 1s only strictly above the diagonal.
    strictly_upper = upper_with_diag - diag_only
    return tf.where(tf.equal(strictly_upper, 1), float('-inf'), 0)
- #%%
-
def _append_rotated_head(tensor, bptt, batch_size):
    """Append the first ``bptt`` rows of ``tensor``, batch columns rotated left by one.

    After the rotation, column ``b`` of the appended rows holds the head of
    column ``b + 1`` (the last column wraps around to column 0).
    """
    rotated_columns = list(range(1, batch_size)) + [0]
    head = tf.gather(tensor[0:bptt], rotated_columns, axis=1)
    return tf.concat([tensor, head], axis=0)


def dataPreProcess(oct_seq, bptt, batch_size, oct_len):
    """Rearrange octree data for batch processing.

    The input array is NOT modified; the shift of feature 0 is applied to a
    private copy.

    Args:
        oct_seq: numpy array indexed as ``[node, ancestor, feature]``.
            Feature 0 is decremented by one before use (the original comment
            describes this as mapping ``[1, 255]`` to ``[0, 254]``).
            The second axis is assumed to have ``levelNumK`` entries so that
            it can be concatenated with the zero padding - TODO confirm
            against callers.
        bptt: number of zero rows prepended as context for the first node,
            and number of head rows appended again at the end.
        batch_size: number of batch columns the sequence is folded into.
        oct_len: number of real nodes; IDs ``0 .. oct_len - 1`` are generated
            for them (presumably equal to ``oct_seq.shape[0]``).

    Returns:
        ``(dataID, padingdata)``:
        ``dataID`` is an int64 tensor ``[rows, batch_size]`` holding the node
        index of every position, with ``-1`` marking padding.
        ``padingdata`` is an int64 tensor
        ``[rows, batch_size, levelNumK, FeatDim]`` with the matching data.
    """
    # Copy first so the caller's array is left untouched.
    oct_seq = oct_seq.copy()
    oct_seq[:, :, 0] = oct_seq[:, :, 0] - 1
    data = tf.convert_to_tensor(oct_seq, dtype=tf.int64)
    FeatDim = data.shape[-1]

    # Leading zero rows serve as the context of the first real node.
    padingdata = tf.concat(
        [tf.zeros([bptt, levelNumK, FeatDim], dtype=tf.int64), data], axis=0)
    # Trailing zero rows make the row count divisible by batch_size
    # (a full extra batch_size rows are added when it already is divisible).
    padingsize = batch_size - padingdata.shape[0] % batch_size
    padingdata = tf.concat(
        [padingdata, tf.zeros([padingsize, levelNumK, FeatDim], dtype=tf.int64)],
        axis=0)
    # Fold into batch_size contiguous segments, then put the row axis first.
    padingdata = tf.transpose(
        tf.reshape(padingdata, (batch_size, -1, levelNumK, FeatDim)),
        perm=[1, 0, 2, 3])

    # Same layout for the node indices; -1 marks every padding position.
    flat_ids = tf.concat(
        [
            tf.ones([bptt], dtype=tf.int64) * -1,
            tf.range(oct_len, dtype=tf.int64),
            tf.ones([padingsize], dtype=tf.int64) * -1,
        ],
        axis=0)
    dataID = tf.transpose(tf.reshape(flat_ids, (batch_size, -1)), perm=[1, 0])

    # Extend every column with the first bptt rows of the following column.
    padingdata = _append_rotated_head(padingdata, bptt, batch_size)
    dataID = _append_rotated_head(dataID, bptt, batch_size)
    return dataID, padingdata
-
def encodeNode(pro, octvalue):
    """Estimate the coding cost of one node and check the top-1 prediction.

    Args:
        pro: 1-D array of probabilities; entry ``k`` belongs to value ``k + 1``.
        octvalue: the actual value of the node, in ``[1, 255]``.

    Returns:
        ``(bits, hit)``: ``bits`` is ``-log2(pro[octvalue - 1] + 1e-07)`` and
        ``hit`` is ``1`` when the most probable value equals ``octvalue``,
        otherwise ``0``.
    """
    assert 1 <= octvalue <= 255
    # Index k corresponds to value k + 1, hence the shift by one.
    predicted = np.argmax(pro) + 1
    bits = -np.log2(pro[octvalue - 1] + 1e-07)
    hit = int(octvalue == predicted)
    return bits, hit
def compress(oct_data_seq, outputfile, model, actualcode=True, print=print, showRelut=False):
    """Compress an octree sequence with the model and optionally write the bitstream.

    Args:
        oct_data_seq: numpy array ``[n tree nodes, K ancestors, C features]``.
            Per the original description the features are oct code, level,
            octant (index among the eight children) and position (xyz).
            The array is copied; the caller's data is not modified.
        outputfile: name of the bin file written when ``actualcode`` is true.
        model: callable invoked as ``model(input, src_mask, [], training=False)``.
        actualcode: whether to actually perform entropy coding.
        print: logging callable used for the warning message (the name
            shadows the builtin on purpose; it is part of the interface).
        showRelut: show a tqdm progress bar while running the model.

    Returns:
        ``(binsz, oct_len, elapsed, binszList, octNumList)``:
        ``binsz`` - bin size in bits: the value returned by the arithmetic
        coder when ``actualcode`` is true, otherwise the estimate.
        ``oct_len`` - number of nodes in the sequence.
        ``elapsed`` - total forward time of the model in seconds.
        ``binszList`` - cumulative estimated bits recorded at every level
        change plus the final total; when more than 7 entries exist the first
        7 are dropped (the original description calls this depth 8 ~ max level).
        ``octNumList`` - node counts matching ``binszList``.
    """
    levelID = oct_data_seq[:, -1, 1].copy()
    oct_data_seq = oct_data_seq.copy()

    if levelID.max() > MAX_OCTREE_LEVEL:
        print('**warning!!**,to clip the level>{:d}!'.format(MAX_OCTREE_LEVEL))

    # Values to be coded: feature 0 of the last ancestor slot of every node.
    oct_seq = oct_data_seq[:, -1:, 0].astype(int)
    # Shift by one node: row i receives row i+1's leading ancestor slots and
    # the features 1:3 of its last slot, while last-slot feature 0 stays.
    oct_data_seq[:-1, 0:-1, :] = oct_data_seq[1:, 0:-1, :]
    oct_data_seq[:-1, -1, 1:3] = oct_data_seq[1:, -1, 1:3]
    oct_len = len(oct_seq)

    batch_size = 1  # 1 for safety encoder

    assert batch_size * bptt < oct_len

    dataID, padingdata = dataPreProcess(oct_data_seq, bptt, batch_size, oct_len)
    # Number of iterations between flushes of the probability buffer; change
    # this according to the GPU memory size (2**12 for 24G).
    MAX_GPU_MEM_It = 2 ** 13
    # Plain Python int so the buffer shape below is a list of ints.
    max_data_id = int(tf.reduce_max(dataID).numpy())
    MAX_GPU_MEM = min(bptt * MAX_GPU_MEM_It, max_data_id) + 2

    # One row of 255 probabilities per node of the current window; the last
    # row is where the outputs of padding positions are written.
    pro = tf.zeros([MAX_GPU_MEM, 255])

    padingLength = padingdata.shape[0]
    src_mask = generate_square_subsequent_mask(bptt)
    step = bptt // bpttRepeatTime
    elapsed = 0
    proBit = []
    offset = 0
    trange = tqdm.trange if showRelut else range
    for n, i in enumerate(trange(0, padingLength - bptt, step)):
        seq_len = min(bptt, padingLength - 1 - i)
        # Original note: [bptt, batch_sz, kparent, [oct, level, octant]].
        model_input = tf.cast(padingdata[i:i + seq_len], dtype=tf.int64)

        if n % MAX_GPU_MEM_It == 0 and n:
            # Flush the rows filled so far (nodeID is the previous
            # iteration's) and rebase the following IDs onto the buffer start.
            filled = tf.reduce_max(nodeID).numpy() + 1
            proBit.append(pro[:filled].numpy())
            offset = offset + filled
        # IDs of the nodes predicted by the last `step` outputs of this chunk.
        nodeID = dataID[i + seq_len - step + 1:i + seq_len + 1] - offset

        if model_input.shape[0] != bptt:
            src_mask = generate_square_subsequent_mask(model_input.shape[0])
        start_time = time.time()
        output = model(model_input, src_mask, [], training=False)
        elapsed = elapsed + time.time() - start_time

        output = tf.reshape(output[bptt - step:], (-1, 255))

        # Padding positions have negative IDs; scatter indices must not be
        # negative, so they are redirected to the last row of the buffer.
        nodeID = tf.where(nodeID < 0, pro.shape[0] - 1, nodeID)
        # Scatter indices must be 2-D ([n, 1]) and as long as p. Duplicates
        # are allowed, but every index must lie inside pro.
        nodeID = tf.reshape(nodeID, (-1, 1))
        p = tf.nn.softmax(output, axis=1)
        pro = tf.tensor_scatter_nd_update(pro, nodeID, p)

    # Flush what is left in the buffer after the last chunk.
    proBit.append(pro[:tf.reduce_max(nodeID).numpy() + 1].numpy())
    proBit = np.vstack(proBit)

    # Estimate the bitrate at each level.
    bit = 0
    templevel = 1
    binszList = []
    octNumList = []
    for i in range(oct_len):
        octvalue = int(oct_seq[i, -1])
        bit0, _ = encodeNode(proBit[i], octvalue)
        bit += bit0
        if templevel != levelID[i]:
            # Level changed: record the running totals (node i included).
            templevel = levelID[i]
            binszList.append(bit)
            octNumList.append(i + 1)
    binszList.append(bit)
    octNumList.append(oct_len)
    binsz = bit  # estimated bin size

    if actualcode:
        codec = numpyAc.arithmeticCoding()
        out_dir = os.path.dirname(outputfile)
        # A bare file name has no directory part; nothing to create then.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        _, binsz = codec.encode(
            proBit[:oct_len, :],
            oct_seq.astype(np.int16).squeeze(-1) - 1,
            outputfile)

    if len(binszList) <= 7:
        return binsz, oct_len, elapsed, np.array(binszList), np.array(octNumList)
    return binsz, oct_len, elapsed, np.array(binszList[7:]), np.array(octNumList[7:])
- # %%
def main(fileName, model, actualcode=True, showRelut=True, printl=print):
    """Load one data file, compress it and optionally report statistics.

    Args:
        fileName: path of the data file handed to ``matloader``.
        model: model forwarded to ``compress``.
        actualcode: whether ``compress`` performs actual entropy coding.
        showRelut: print the result summary (and show the progress bar).
        printl: print-like callable used for all output.

    Returns:
        Bits per octree node, i.e. ``binsz / oct_len``.
    """
    matDataPath = fileName
    cell, mat = matloader(matDataPath)
    FeatDim = levelNumK
    # Keep the last FeatDim ancestor slots and the first six features.
    # Original note: oct_data_seq.shape is e.g. (3391, 4, 6), oct len 3391.
    oct_data_seq = np.transpose(mat[cell[0, 0]]).astype(int)[:, -FeatDim:, 0:6]

    # Original note: p.shape is e.g. (10474, 3), ptNum 10474.
    p = np.transpose(mat[cell[1, 0]]['Location'])
    ptNum = p.shape[0]
    ptName = os.path.basename(matDataPath)
    # Strip the real extension instead of assuming it is four characters long.
    outputfile = expName + "/data/" + os.path.splitext(ptName)[0] + ".bin"
    binsz, oct_len, elapsed, binszList, octNumList = compress(
        oct_data_seq, outputfile, model, actualcode, printl, showRelut)
    if showRelut:
        printl("ptName: ", ptName)
        printl("time(s):", elapsed)
        printl("ori file", matDataPath)
        printl("ptNum:", ptNum)
        printl("binsize(b):", binsz)
        printl("bpip:", binsz / ptNum)

        # NOTE: this changes numpy's global print options for the process.
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        printl("pre sz(b) from Q8:", (binszList))
        printl("pre bit per oct from Q8:", (binszList / octNumList))
        printl('octNum', octNumList)
        printl("bit per oct:", binsz / oct_len)
        printl("oct len", oct_len)

    return binsz / oct_len
|