|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Wed Feb 10 14:30:32 2021
-
- @author: root
- """
-
- from pcloud_functs import pcread,lowerResolution #无框架
- import numpy as np
- import os, sys
- import torch
- from models import tfModel10 as tfModel10_test
- from ac_functs import ac_model2 #无框架
- from usefuls import show_time_spent,compare_Locations,get_dir_size #无框架
- import globz #无框架
- globz.init()
- from datetime import datetime
- from config_utils import get_model_info #无框架
- import argparse
-
def parse_args(argv=None):
    """Parse the command-line options of the encode/decode script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is how the script calls it.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output`` and
        ``ckpt_dir``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Other inputs used during development:
    #   'compressed/28_airplane_0270'
    #   '/userhome/PCGCv1/testdata/8iVFB/redandblack_vox10_1550.ply'
    parser.add_argument("--input", default='compressed/redandblack_vox10_1550',
                        help="Input point-cloud filename.")
    parser.add_argument("--output", default='output_root/',
                        help="Output directory (bitstreams are written to its 'bss/' subdirectory).")
    parser.add_argument("--ckpt_dir", type=str,
                        default='train_logs/20220401-155532/epoch_72.pth',
                        dest="ckpt_dir",
                        help="Path to the model checkpoint (.pth) file.")
    return parser.parse_args(argv)
-
args = parse_args()

# Codec variant; it is passed to get_model_info() when the model is set up.
model_type = 'fNNOC'
if model_type not in ('NNOC', 'fNNOC', 'fNNOC1', 'fNNOC2', 'fNNOC3'):
    # Explicit check instead of assert, which is stripped under `python -O`.
    raise ValueError('unsupported model_type: ' + model_type)
GPU = 1          # 0 forces CPU execution
decode = 1       # 1: after encoding, decode the bitstream and compare it with the input
filepath = args.input
nlevel_down = 0  # number of times to downsample the input to compress a lower resolution
output_root = args.output

# Hide the GPUs *before* choosing the device so that GPU=0 really selects the
# CPU; changing CUDA_VISIBLE_DEVICES after torch has probed CUDA may be ignored.
if not GPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
device = torch.device('cuda' if GPU and torch.cuda.is_available() else 'cpu')
-
import importlib

# Resolve the encoder/decoder implementation for this model type
# (e.g. enc_functs_file == 'enc_functs_fast45d').
enc_functs_file, log_id, ctx_type = get_model_info(model_type)
# Equivalent to `from <enc_functs_file> import ENCODE_DECODE`, without exec()
# on a dynamically built source string.
ENCODE_DECODE = importlib.import_module(enc_functs_file).ENCODE_DECODE
ckpt_path = args.ckpt_dir

# NOTE(review): this overrides the ctx_type returned by get_model_info(); the
# network is always built with ctx_type=100. Confirm this is intended.
ctx_type = 100
nn_model = tfModel10_test(ctx_type).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
ckpt = torch.load(ckpt_path, map_location=device)
nn_model.load_state_dict(ckpt['model'])
# Inference only: switch off training-mode layers (dropout, batch-norm
# statistics) so the encode and decode passes compute the same outputs.
nn_model.eval()
#################################################################
# Bitstream files go in <output_root>/bss/. os.path.join works whether or not
# --output ends with a slash; the trailing '' keeps the final separator because
# bs_dir is later concatenated directly with file names (e.g. 'AC.dat').
bs_dir = os.path.join(output_root, 'bss', '')
# makedirs also creates output_root (and any missing parents) in one call.
os.makedirs(bs_dir, exist_ok=True)
##################################################################
# Load the input point cloud as integer coordinates.
GT = pcread(filepath).astype('int')
# Octree depth = number of bits needed for the largest coordinate, i.e. the
# smallest L with max(GT) < 2**L. The previous ceil(log2(max)) was one level
# short when max is an exact power of two (512 -> 9 instead of 10) and
# undefined for max == 0; bit_length() is exact integer arithmetic.
ori_level = int(np.max(GT)).bit_length()
# Lower-resolution input for debugging: the level drops by one per
# lowerResolution() call.
ori_level -= nlevel_down
for _ in range(nlevel_down):
    GT = lowerResolution(GT)
print('input level:' + str(ori_level))  # e.g. 10 for vox10 data
-
# --- Encode -------------------------------------------------------------------
acbspath = bs_dir + 'AC.dat'
# Create the (arithmetic) encoder object. The third argument appears to select
# encode (1) vs decode (0); the decode step reopens the same file with 0.
ac_model = ac_model2(2, acbspath, 1)
# The model arguments (nn_model and the None slot, presumably where an older
# framework passed a session) depend on the deep-learning framework in use.
_, time_spente = ENCODE_DECODE(1, bs_dir, nn_model, ac_model, None, ori_level, GT)

npts = GT.shape[0]
# Bitstream size per input point; "bpv" presumably means bits per voxel,
# assuming get_dir_size() reports bits — TODO confirm.
bpv = get_dir_size(bs_dir) / npts
for summary_line in ('bpv: ' + str(bpv),
                     'filepath:' + filepath,
                     'input level:' + str(ori_level)):
    print(summary_line)
-
# --- Optional decode + verification -------------------------------------------
if decode:
    # Reopen the bitstream for decoding (third argument 0; the encoder used 1).
    ac_model = ac_model2(2, acbspath, 0)
    dec_GT, time_spentd = ENCODE_DECODE(0, bs_dir, nn_model, ac_model, None, ori_level)
    # Must stay inside the `if`: dec_GT only exists once decoding has run.
    # NOTE(review): TP/FP/FN are computed but never reported.
    TP, FP, FN = compare_Locations(dec_GT, GT)

# Repeat the summary so it appears after any decoder output.
print('bpv: ' + str(bpv))
print('filepath:' + filepath)
print('input level:' + str(ori_level))

print('enc:')
show_time_spent(time_spente)

if decode:
    print('dec:')
    show_time_spent(time_spentd)
|