|
- # -*- coding = utf-8 -*-
- '''
- # @time:2023/3/27 11:08
- # Author:DFTL
- # @File:pre_eval.py
- '''
-
- import argparse
- import os.path
-
- import imageio
- import torch
- import Module
- import numpy as np
- import h5py
- from sklearn.decomposition import PCA
- from transformers import BertTokenizer,BertModel
- import json
- from sklearn.metrics import classification_report
-
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Display names for the 12 classes, passed to the classification report.
target_names = ['class {}'.format(class_id) for class_id in range(12)]
-
-
def getKG(model, tokenizer, args):
    """Encode one knowledge triple per class and stack the encodings.

    The JSON file at ``args.knowledge_path`` is read as a dict keyed by the
    class index as a string (``"0"``, ``"1"``, ...). For each class only the
    first triple is used: its three elements are concatenated into one string,
    tokenized, and passed through ``model`` together with a random image.

    Args:
        model: Callable invoked as ``model(image, tokenized_event, ones)``.
        tokenizer: Callable returning an object with a ``.to(device)`` method.
        args: Namespace providing ``knowledge_path`` and ``num_class``.

    Returns:
        The per-class model outputs concatenated along dimension 0. The tensor
        is produced under ``torch.no_grad()`` and carries no autograd history.
    """
    # NOTE(review): the image is random noise, presumably only a placeholder
    # needed to satisfy the model's signature. If the model mixes it into the
    # text encoding, the result differs between runs; confirm against Module.
    image = torch.randn(1, 32, 27, 27).to(device)
    # Loop-invariant third model argument; its meaning is defined by the model.
    ones = torch.ones([1]).to(device)

    with open(args.knowledge_path, 'r', encoding='utf-8') as f:
        knowledge_dict = json.load(f)

    encoded_events = []
    # The encodings are cached and used for inference only, so no graph is built.
    with torch.no_grad():
        for i in range(args.num_class):
            triples = knowledge_dict[str(i)]
            # Only the first triple of each class is encoded.
            event = triples[0][0] + triples[0][1] + triples[0][2]

            event_input = tokenizer(
                event,
                padding='longest',
                truncation=True,
                max_length=25,
                return_tensors="pt",
            ).to(device)

            encoded_events.append(model(image, event_input, ones))

    return torch.cat(encoded_events, 0)
-
def load_Data(path):
    """Load the image cube and ground truth from an HDF5-readable file.

    The file is opened once, read-only, and closed before returning. Both
    datasets are read fully into memory, so the returned arrays stay valid
    after the file is closed.

    Args:
        path: Path to a file containing the datasets ``'HSI'`` and ``'GT'``.

    Returns:
        Tuple ``(X, Y)`` of NumPy arrays read from ``'HSI'`` and ``'GT'``.

    Raises:
        KeyError: If either dataset is missing from the file.
    """
    print("========Loading Data========")
    with h5py.File(path, 'r') as f:
        # ``[()]`` reads the whole dataset into an in-memory NumPy array.
        X = f['HSI'][()]
        Y = f['GT'][()]
    return X, Y
-
def pca_change(X, num_components):
    """Project axis 2 of a 3-D array onto ``num_components`` whitened PCA components.

    Args:
        X: Array indexed as ``(rows, cols, bands)``.
        num_components: Number of principal components to keep.

    Returns:
        Array of shape ``(rows, cols, num_components)``.
    """
    print("========PCA========")
    n_rows, n_cols, n_bands = X.shape[0], X.shape[1], X.shape[2]
    # Every pixel becomes one sample whose features are its band values.
    flat_pixels = np.reshape(X, (-1, n_bands))
    reducer = PCA(n_components=num_components, whiten=True)
    projected = reducer.fit_transform(flat_pixels)
    # Restore the spatial layout with the reduced last axis.
    return np.reshape(projected, (n_rows, n_cols, num_components))
-
def predict(model, patch_size, input_data, kg=None):
    """Classify every pixel of an image from the patch centred on it.

    The image is zero-padded by ``patch_size // 2`` on each spatial side so a
    full window exists around every pixel, including border pixels. Each
    window is passed to ``model.forward_test`` as a batch of one, and the
    argmax over axis 1 of the output is stored as that pixel's class.

    Args:
        model: Object exposing ``forward_test(patch, kg)``.
        patch_size: Nominal window size. The window actually extracted has
            side ``2 * (patch_size // 2) + 1``.
        input_data: Array of shape ``(h, w, bands)``.
        kg: Knowledge tensor forwarded to ``model.forward_test``. When omitted,
            the module-level ``kg`` is used, which keeps the three-argument
            call working.

    Returns:
        ``uint8`` array of shape ``(h, w)`` holding the predicted class indices.

    Raises:
        NameError: If ``kg`` is not given and no module-level ``kg`` exists.
    """
    if kg is None:
        try:
            kg = globals()["kg"]
        except KeyError:
            raise NameError("name 'kg' is not defined") from None

    h, w, _ = input_data.shape

    # Derive the padding from the patch size instead of hard-coding 13, so
    # windows larger than 27 no longer produce negative slice bounds.
    half = patch_size // 2
    window = 2 * half + 1

    # Zero-pad the two spatial axes, then move bands first: (bands, H, W).
    padded = np.pad(input_data, ((half, half), (half, half), (0, 0)), "constant")
    padded = np.transpose(padded, [2, 0, 1])

    result = np.zeros((h, w), dtype='uint8')
    total = h * w
    count = 0

    print("========Start Predicting========")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for row in range(h):
            for col in range(w):
                print("\r", "============{}/{}===========".format(count, total), end="", flush=True)

                # In padded coordinates the window centred on (row, col)
                # starts at (row, col); add a leading batch axis of size 1.
                patch = padded[np.newaxis, :, row:row + window, col:col + window]
                patch = torch.tensor(patch).to(torch.float32).to(device)

                output = model.forward_test(patch, kg)
                pred = np.argmax(output.detach().cpu().numpy(), axis=1)

                # Store a Python scalar rather than a size-1 array.
                result[row, col] = pred.item()

                count += 1

    return result
-
def eval(pre, label):
    """Print a per-class classification report for a predicted label map.

    The name shadows the builtin ``eval``; it is kept so existing callers
    continue to work.

    Args:
        pre: Array of predicted class indices.
        label: Array of ground-truth class indices with the same number of
            elements as ``pre``.
    """
    # classification_report expects (y_true, y_pred). Passing them the other
    # way round swaps precision and recall in the printed report.
    # NOTE(review): target_names lists 12 classes; confirm the label map
    # contains exactly those classes (e.g. no extra background label).
    print(classification_report(label.flatten(), pre.flatten(), target_names=target_names))
-
-
# Script entry point: load data and model, obtain the knowledge tensor,
# predict the whole image and write the label map to disk.
# The body is deliberately kept at module level: `kg` must be a module global
# because `predict` may look it up there.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Model Controller')

    parser.add_argument('--mode', type=str, default='pre_train', help='pre_train/test/train/final_test')

    parser.add_argument('--HSI_path', type=str, default=r"/dataset/NC12/NC12.mat")
    parser.add_argument('--kg_path', type=str, default=r"/dataset/kg.pt")

    parser.add_argument('--modelname', type=str, default=r"")

    parser.add_argument('--knowledge_path', type=str, default='/code/DataSet/Knowledge_v2.json')
    parser.add_argument('--config_path', type=str, default="/code/configs/config_bert.json")

    parser.add_argument('--max_len', type=int, default=300)
    parser.add_argument('--pretrained_bert', type=str, default='/dataset/pre_trained/bert-base-chinese')
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--num_class', type=int, default=12)

    parser.add_argument('--result_path', type=str, default="/result")
    # NOTE: this value is overwritten below with a name derived from --modelname.
    parser.add_argument('--result_name', type=str, default="result.tif")

    args = parser.parse_args()

    # Load the hyperspectral data, move axis 0 to the end, and reduce the
    # last axis to 32 components with PCA.
    input_data, label = load_Data(args.HSI_path)
    input_data = np.transpose(input_data, [1, 2, 0])
    label = np.uint8(label)
    input_data = pca_change(input_data, 32)

    # Build the model and load the trained weights.
    model = Module.MyModule_v8(args).to(device)
    model.load_state_dict(torch.load(r"/model/" + args.modelname, map_location=device))
    # Switch to inference mode so that layers such as dropout or batch norm
    # (if the model has any) behave deterministically.
    model.eval()

    # Load the cached knowledge tensor, or build and cache it on first use.
    if os.path.exists(args.kg_path):
        kg = torch.load(args.kg_path, map_location=device)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert)
        # Detach so the cached tensor carries no autograd history.
        kg = getKG(model, tokenizer, args).detach().to(device)
        torch.save(kg, args.kg_path)

    # Predict the whole image; no gradients are needed for inference.
    with torch.no_grad():
        pre = predict(model, 27, input_data)

    # Name the output after the model file and make sure the folder exists.
    args.result_name = os.path.basename(args.modelname).split('.p')[0] + '.tif'
    os.makedirs(args.result_path, exist_ok=True)
    imageio.imwrite(os.path.join(args.result_path, args.result_name), pre)

    # Uncomment to print the classification report against the ground truth.
    # eval(pre, label)
-
-
-
|