|
- '''
- Author: fuchy@stu.pku.edu.cn
- Description: this file encodes point cloud
- FilePath: /compression/encoder.py
- All rights reserved.
- '''
- import tensorflow as tf
- from numpy import mod
- from Preparedata.data import dataPrepare #无框架
- from encoderTool import main
- from networkTool import CPrintl,expName
- from octAttention import model
- import glob,datetime,os
- import pt as pointCloud #无框架
- import argparse
-
def parse_args(argv=None):
    """Parse the encoder's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets other code (or tests) parse arguments without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``input`` and ``ckpt_dir``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Example inputs: 'compressed/28_airplane_0270',
    # '/userhome/PCGCv1/testdata/8iVFB/redandblack_vox10_1550.ply'
    parser.add_argument("--input", default='', help="Input filename.")
    # NOTE(review): this default names a PyTorch .pth file, but the value is
    # handed to tf.train.latest_checkpoint(), which expects a directory that
    # contains a TensorFlow checkpoint -- confirm the intended default.
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default='/userhome/PCGCv1/pytorch2/ckpts/hyper_mgpu2/epoch_13_12599.pth',
        help='checkpoint directory (searched with tf.train.latest_checkpoint)')
    return parser.parse_args(argv)
-
############## warning ###############
# decoder.py relies on the model that is built and restored here when this
# module is imported, so these statements must stay at module scope -- do not
# move them into the __main__ block or anywhere else.
args = parse_args()

# One dummy forward pass before load_weights(), presumably so the (assumed
# tf.keras) model creates its variables and the checkpoint can be restored
# into them. NOTE(review): the [1024, 32, 4, 6] input shape and the mask size
# 1024 must match what octAttention.model expects -- confirm there.
input_data = tf.ones([1024, 32, 4, 6], dtype=tf.int32)
input_mask = model.generate_square_subsequent_mask(1024)
output = model(input_data, input_mask, [], training=False)

# --ckpt_dir is a checkpoint *directory* name. latest_checkpoint() returns
# None when it finds no checkpoint there; fail loudly instead of passing None
# on to load_weights().
ckpt_path = tf.train.latest_checkpoint(args.ckpt_dir)
if ckpt_path is None:
    raise FileNotFoundError(
        f"No TensorFlow checkpoint found in {args.ckpt_dir!r}; "
        "--ckpt_dir must be a directory containing a checkpoint")
model.load_weights(ckpt_path)
-
########### Objects ##############
# Point-cloud files to encode; currently just the single --input path.
list_orifile = [args.input]

# Input files larger than this (300 MiB) are skipped.
MAX_INPUT_BYTES = 300 * (1024 ** 2)

if __name__ == "__main__":
    # CPrintl prints to the console and also saves the output to this log file.
    printl = CPrintl(expName + '/encoderPLY_tf.txt')
    printl('_' * 50, 'OctAttention V tf', '_' * 50)
    printl(datetime.datetime.now().strftime('%Y-%m-%d:%H:%M:%S'))
    for oriFile in list_orifile:
        printl(oriFile)
        # Skip missing inputs (e.g. the default --input '') instead of
        # crashing in os.path.getsize(), consistent with the size check below.
        if not os.path.isfile(oriFile):
            printl('file not found!')
            continue
        if os.path.getsize(oriFile) > MAX_INPUT_BYTES:
            printl('too large!')
            continue
        # Quantization steps to encode with; only qs=1 is used at present.
        for qs in [1]:
            # NOTE(review): dataPrepare always receives an empty ptNamePrefix;
            # the original computed a name from the input file but never
            # passed it -- confirm whether that was intended.
            # Set `rotation=True` here when processing MVUB data.
            matFile, DQpt, refPt = dataPrepare(
                oriFile, saveMatDir='./Data/testPly', qs=qs,
                ptNamePrefix='', rotation=False)
            # actualcode=False: the bin file will not be generated.
            main(matFile, model, actualcode=True, printl=printl)
            # Optional quality check: compares the points before quantization
            # (refPt) with the quantized-then-dequantized points (DQpt).
            # print('_' * 50, 'pc_error', '_' * 50)
            # pointCloud.pcerror(refPt, DQpt, None, '-r 1023', None).wait()
|