|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Mon Mar 1 15:36:29 2021
- @author: root
- """
- import numpy as np
- import sys
- from pcloud_functs import pcshow,pcread,lowerResolution,inds2vol,vol2inds,dilate_Loc #无框架
- # sys.path.append('/home/emre/Documents/kodlar/Reference-arithmetic-coding-master/python/')
- from usefuls import in1d_index,plt_imshow,write_ints,read_ints,write_bits,read_bits,dec2bin2,bin2dec2,ints2bs,bs2ints #无框架
- import arithmeticcoding as arc #无框架
- # import tensorflow.compat.v1 as tf1
- import torch
- import globz
- import time
- from ac_functs import ac_model2 #无框架
- # from dec2bin import dec2bin
- from runlength import RLED #无框架
-
# Device used for the context-model forward passes: CUDA when available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Running total of seconds spent in neural-network inference (reset by ENCODE_DECODE).
dltime = 0.0
def N_BackForth(sBBi):  # checked with matlab output
    """Move one resolution level down and back up, enumerating all 8 children.

    Every point is mapped to its parent voxel (coordinates halved and
    floored). For each distinct parent cube, all 8 possible child patterns
    are generated.

    Args:
        sBBi: (n, 3) array of voxel coordinates (the sorted point cloud, as
            produced by ``np.unique``).

    Returns:
        (8 * m, 3) int array of the children of the m distinct parents.
        Rows are unique and in lexicographic order (``np.unique(axis=0)``).
    """
    # Parent voxel of every point at the next-lower resolution.
    parents = np.unique(np.floor(np.asarray(sBBi) / 2).astype('int'), axis=0)

    # The 8 child offsets inside a parent cube (x varies fastest).
    PatEl = np.array([[0, 0, 0],
                      [1, 0, 0],
                      [0, 1, 0],
                      [1, 1, 0],
                      [0, 0, 1],
                      [1, 0, 1],
                      [0, 1, 1],
                      [1, 1, 1]])

    # (m, 1, 3) + (1, 8, 3) -> (m, 8, 3): generates all children at once
    # instead of growing an array with np.concatenate inside a loop.
    children = (parents[:, np.newaxis, :] * 2 + PatEl[np.newaxis, :, :]).reshape(-1, 3)
    return np.unique(children, axis=0)
-
def _section_table(Locs, lSSL):
    """Return the per-y-section row ranges of a y-sorted (n, 3) location array.

    Row ``y`` of the result is ``[start, stop, length]`` of the contiguous run
    of rows of ``Locs`` whose column 1 equals ``y`` (``stop`` inclusive). The
    row is all zeros when no such run exists. ``Locs`` must be sorted by
    column 1, and every value in it must be < ``lSSL``.
    """
    table = np.zeros((lSSL, 3), dtype='int')
    ys, starts, counts = np.unique(Locs[:, 1], return_index=True, return_counts=True)
    table[ys, 0] = starts
    table[ys, 1] = starts + counts - 1
    table[ys, 2] = counts
    return table


def get_temps_dests2(ctx_type, ENC=True, nn_model='dec', ac_model='dec_and_enc', maxesL='dec_and_enc', sess=None, for_train=False, bs_dir=None, save_SSL=True, level=0, ori_level=0, dSSLs=0):
    """Encode or decode the occupancy of one resolution level, one y-section at a time.

    Typical call (from ENCODE_DECODE)::

        get_temps_dests2(nn_model.w1.shape[0], ENC, nn_model=nn_model,
                         ac_model=ac_model, maxesL=maxes1-mins1+32+[80,0,80],
                         sess=sess, bs_dir=bs_dir, save_SSL=True,
                         level=level, ori_level=ori_level, dSSLs=dSSLs)

    Every y-section ``icPC`` that has rows in ``globz.LocM`` and whose
    section bit in ``SSL2`` is set is processed as follows:
      1. Each row of ``globz.LocM`` in that section gets a boolean context of
         four wsize x wsize windows (flattened column-major):
           - the two most recently coded sections' occupancy;
           - this section's LocM mask;
           - the next section's LocM mask.
      2. ``nn_model`` turns the contexts into probabilities.
      3. Each voxel's occupancy bit is arithmetic-encoded (ENC) or decoded.

    Args:
        ctx_type: context length (number of model inputs). The window side is
            ``wsize = sqrt(ctx_type // 4)``, so ctx_type must equal
            4 * wsize**2 with wsize odd.
        ENC: True to encode, False to decode.
        nn_model: torch module mapping an (nT, ctx_type) float tensor to
            per-row symbol probabilities. Presumably two columns, for
            symbols 0 and 1 — confirm.
        ac_model: arithmetic coder exposing ``encode_symbol`` / ``decode_symbol``.
        maxesL: (maxX, maxY, maxZ) extent of the shifted coordinates.
        sess, for_train, save_SSL: unused; kept for call compatibility.
        bs_dir: path prefix; ``rSSL.dat`` is written when encoding level == ori_level.
        level, ori_level: current and finest resolution level.
        dSSLs: decoder only; 2-D array whose row ``level`` holds the
            per-y section-occupancy bits.

    Globals:
        Reads ``globz.LocM`` (y-sorted) and, when encoding, ``globz.Loc``
        (y-sorted ground truth). Writes every occupied voxel as
        ``[x, y, z]`` into ``globz.Loc``; the decoder reallocates it first.
        Adds the inference time to the module-level ``dltime``.

    Returns:
        0 when encoding. When decoding, the (n, 3) int array of decoded
        (x, y, z) voxels in y-section order.
    """
    global dltime
    iBBr = 0  # number of occupied voxels found so far
    wsize = int(np.sqrt(ctx_type // 4))  # e.g. 5
    swsize = wsize ** 2                   # e.g. 25
    swsize2 = 2 * swsize
    swsize3 = 3 * swsize
    b = (wsize - 1) // 2                  # window half-width, e.g. 2

    # Offsets of a wsize x wsize window flattened in column-major ('F')
    # order: the z (row) offset varies fastest, the x (column) offset slowest.
    offs = np.arange(-b, b + 1)
    dz = np.tile(offs, wsize)
    dx = np.repeat(offs, wsize)

    maxX, maxY, maxZ = maxesL[0], maxesL[1], maxesL[2]
    SectSize = (maxZ, maxX)  # every y-section is a (z, x) image
    lSSL = maxY + 10

    # %% Which y-sections hold true points
    if ENC:
        # globz.Loc: this level's locations (lrGTs[level] - mins1 + 32), deduplicated and sorted by y first.
        StartStopLength = _section_table(globz.Loc, lSSL)
        SSL2 = StartStopLength[:, 2] > 0
        if level == ori_level:
            ssbits = ''.join(str(int(ssbit)) for ssbit in SSL2)
            # Only bits [32:-9] are run-length coded. The decoder re-inserts
            # 32 leading and 9 trailing zeros (presumably always 0 because the
            # caller shifts coordinates by +32 and lSSL = maxY + 10).
            RLED(ssbits[32:-9], lSSL - 41, lSSL - 41, 1, bs_dir + 'rSSL.dat')
    else:
        SSL2 = dSSLs[level, :]
    ncPC = np.max(np.where(SSL2)[0])  # last section to visit

    # %% Find sections in LocM
    StartStopLengthM = _section_table(globz.LocM, lSSL)

    if not ENC:
        globz.Loc = np.zeros((globz.LocM.shape[0], 3), 'int')

    BWTrue1 = np.zeros(SectSize, 'bool')  # occupancy of the most recently coded section
    BWTrue2 = np.zeros(SectSize, 'bool')  # occupancy of the section coded before that
    for icPC in range(ncPC + 1):
        if icPC % 50 == 0:
            print('icPC:' + str(icPC))
        startM, stopM, nT = StartStopLengthM[icPC]
        if not (nT > 0 and SSL2[icPC]):
            continue

        # %% 0. Mark the TRUE points on BWTrue (encoder only; the decoder fills it while decoding)
        BWTrue = np.zeros(SectSize, 'int')
        if ENC:
            rows = globz.Loc[StartStopLength[icPC, 0]:StartStopLength[icPC, 1] + 1]
            BWTrue[rows[:, 2], rows[:, 0]] = 1

        # %% 0.1 LocM mask of the NEXT section (a length check, not a stop-index check)
        BWTrue1M = np.zeros(SectSize, 'bool')
        if icPC < ncPC and StartStopLengthM[icPC + 1, 2] > 0:
            nxt = globz.LocM[StartStopLengthM[icPC + 1, 0]:StartStopLengthM[icPC + 1, 1] + 1]
            BWTrue1M[nxt[:, 2], nxt[:, 0]] = True

        # LocM positions of this section as (z, x), and their mask.
        SLocM = globz.LocM[startM:stopM + 1][:, [2, 0]]
        BWTrueM = np.zeros(SectSize, 'bool')
        BWTrueM[SLocM[:, 0], SLocM[:, 1]] = True

        # Contexts of all voxels of the section, gathered at once with
        # (nT, swsize) integer index arrays. Sized per section, so there is
        # no fixed row limit.
        Z = SLocM[:, 0:1] + dz
        X = SLocM[:, 1:2] + dx
        Temp = np.zeros((nT, ctx_type), 'bool')
        Temp[:, 0:swsize] = BWTrue2[Z, X]
        Temp[:, swsize:swsize2] = BWTrue1[Z, X]
        Temp[:, swsize2:swsize3] = BWTrueM[Z, X]
        Temp[:, swsize3:] = BWTrue1M[Z, X]

        start_time = time.time()
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            net_in = torch.from_numpy(Temp).float().to(device)
            Tprobs = nn_model(net_in).cpu().detach().numpy()
        dltime += time.time() - start_time

        # Probabilities -> integer frequencies for the arithmetic coder.
        freqs = np.ceil(Tprobs * (2 ** 14)).astype('int')
        # The coder is stateful, so symbols must be coded one by one, in order.
        for iiBB in range(nT):
            iz, ix = SLocM[iiBB]
            freqstable = arc.SimpleFrequencyTable(list(freqs[iiBB]))
            if ENC:
                symb = BWTrue[iz, ix]
                ac_model.encode_symbol(freqstable, symb)
            else:  # DECODER
                symb = ac_model.decode_symbol(freqstable)
            BWTrue[iz, ix] = symb
            if symb:
                globz.Loc[iBBr, :] = [ix, icPC, iz]
                iBBr += 1

        BWTrue2 = BWTrue1
        BWTrue1 = BWTrue

    if ENC and level == ori_level:
        # At the finest level, append 64 zero symbols with a flat [10, 10] table.
        # NOTE(review): presumably padding so the decoder never reads past the
        # real data — confirm.
        freqs = arc.SimpleFrequencyTable([10, 10])
        for _ in range(64):
            ac_model.encode_symbol(freqs, 0)
    if ENC:
        return 0
    return globz.Loc[0:iBBr, :]
-
def _sort_by_y(points):
    """Return the unique rows of an (n, 3) (x, y, z) array, ordered by (y, x, z).

    Columns stay in (x, y, z) order and duplicate rows are dropped. Unlike
    writing the sorted rows back in place, this also works when ``points``
    contains duplicates.
    """
    return np.unique(points[:, [1, 0, 2]], axis=0)[:, [1, 0, 2]]


def ENCODE_DECODE(ENC, bs_dir, nn_model, ac_model, sess, ori_level=0, GT=0):
    """Encode a point cloud into, or decode it from, the bitstream files under ``bs_dir``.

    Example calls::

        ENCODE_DECODE(1, bs_dir, nn_model, ac_model, None, ori_level, GT)  # encode
        ENCODE_DECODE(0, bs_dir, nn_model, ac_model, None, ori_level)      # decode

    Side information goes to ``side_info.bs``:
      - the 6 global min/max coordinates, ``ori_level`` bits each;
      - the 64 voxel bits of the level-2 4x4x4 volume;
      - a trailing '1'.

    Levels 3..ori_level are then coded in turn by ``get_temps_dests2``. At
    each level the coded voxels are ``dilate_Loc`` of the level below,
    shifted by ``-mins1 + 32``.

    Args:
        ENC: truthy to encode, falsy to decode.
        bs_dir: path prefix of the bitstream files.
        nn_model: context model; ``nn_model.w1.shape[0]`` is used as the context length.
        ac_model: arithmetic encoder/decoder. When encoding,
            ``end_encoding()`` is called after level ``ori_level``.
        sess: passed through to get_temps_dests2.
        ori_level: finest resolution level; each min/max coordinate is
            stored with this many bits.
        GT: encoder only; (n, 3) array of (x, y, z) integer points.

    Returns:
        ``(dec_GT, time_spent)``. ``dec_GT`` is the decoded (n, 3) point
        array when decoding and 0 when encoding. ``time_spent`` is the
        wall-clock seconds.
    """
    start = time.time()
    global dltime
    dltime = 0.0
    nintbits = ori_level * np.ones((6,), int)
    nsibits = np.sum(nintbits)
    lrGTs = dict()
    dSSLs = 0

    if ENC:
        minmaxesG = np.concatenate((np.min(GT, 0), np.max(GT, 0)))
        sibs = ints2bs(minmaxesG, nintbits)  # the 6 integers as a bit string

        # Resolution pyramid lrGTs[ori_level], ..., lrGTs[2] via lowerResolution.
        lrGTs[ori_level] = GT
        lrGT = np.copy(GT)
        for il in range(ori_level - 2):
            lrGT = lowerResolution(lrGT)
            lrGTs[ori_level - il - 1] = lrGT

        # Level 2 is sent uncoded as the 64 voxels of a 4x4x4 volume.
        lowest_bs = inds2vol(lrGTs[2], [4, 4, 4]).flatten().astype(int)
        lowest_str = ''.join(str(bit) for bit in lowest_bs[:64])
        write_bits(sibs + lowest_str + '1', bs_dir + 'side_info.bs')
    else:
        sibs = read_bits(bs_dir + 'side_info.bs')
        minmaxesG = bs2ints(sibs[0:nsibits], nintbits)
        lowest_str = sibs[nsibits:-1]  # drop the trailing '1'
        lowest_bs = np.array([int(bit) for bit in lowest_str[:64]], dtype=int)
        lrGTs[2] = vol2inds(lowest_bs.reshape([4, 4, 4]))

        # Min/max of every level, derived from the global min/max.
        lrmm = np.copy(minmaxesG[np.newaxis, :])
        lrmms = np.zeros((ori_level + 1, 6), int)
        lrmms[ori_level] = lrmm
        for il in range(ori_level - 2):
            lrmm = lowerResolution(lrmm)
            lrmms[ori_level - il - 1, :] = lrmm

        # Section-occupancy bits of the finest level. The encoder run-length
        # coded only bits [32:-9]; the zero head/tail is re-inserted here.
        lSSL = lrmms[ori_level, 4] - lrmms[ori_level, 1] + 32 + 10
        ssbits = 32 * '0' + RLED('', lSSL - 41, lSSL - 41, 0, bs_dir + 'rSSL.dat') + 9 * '0'
        # At least 4500 columns (the historical width), but wide enough for any y-extent.
        dSSLs = np.zeros((ori_level + 1, max(4500, len(ssbits))), int)
        dSSLs[ori_level, :len(ssbits)] = [int(bit) for bit in ssbits]

        # Derive each coarser level's section bits from the level above.
        for level in range(ori_level, 3, -1):
            add = lrmms[level][1] % 2  # parity of this level's min y
            inds = lowerResolution(np.where(dSSLs[level])[0] + add - 32) + 32
            dSSLs[level - 1, inds] = 1

    for level in range(3, ori_level + 1):
        if ENC:
            mins1 = np.min(lrGTs[level], 0)
            maxes1 = np.max(lrGTs[level], 0)
            Location = _sort_by_y(lrGTs[level] - mins1 + 32)
        else:
            mins1 = lrmms[level, 0:3]
            maxes1 = lrmms[level, 3:6]
        maxesL = maxes1 - mins1 + 32 + [80, 0, 80]
        globz.LocM = _sort_by_y(dilate_Loc(lrGTs[level - 1]) - mins1 + 32)
        if ENC:
            globz.Loc = Location

        dec_Loc = get_temps_dests2(nn_model.w1.shape[0], ENC, nn_model=nn_model, ac_model=ac_model,
                                   maxesL=maxesL, sess=sess, bs_dir=bs_dir, save_SSL=True,
                                   level=level, ori_level=ori_level, dSSLs=dSSLs)
        if not ENC:
            lrGTs[level] = dec_Loc + mins1 - 32
        if ENC and level == ori_level:
            ac_model.end_encoding()

    time_spent = time.time() - start
    nmins = int(time_spent // 60)
    nsecs = int(np.round(time_spent - nmins * 60))
    print('time spent: ' + str(nmins) + 'm ' + str(nsecs) + 's')
    print("dltime:", dltime)
    dec_GT = 0 if ENC else lrGTs[ori_level]
    return dec_GT, time_spent
|