|
- import os.path
- import torch.nn
- from torchvision import transforms
- import imageio
- from PIL import Image
- import numpy as np
- import cv2
- from datetime import datetime
- from Model.HrNet.hrnet_T import HighResolutionNet
- from utils.dataset.TT_Dataset_new import MyDataset
- from torch.utils.data import DataLoader
-
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # def estimate(y_label, y_pred):
- # y_pred[y_label==0]=0
- # # 准确率
- # acc = np.mean(np.equal(y_label, y_pred) + 0)
- #
- # return acc, y_pred
- #
- # def model_predict(model, img_data, lab_data, img_size):
- #
- # row, col, dep = img_data.shape
- #
- # if row % img_size != 0 or col % img_size != 0:
- # print('{}: Need padding the predict image...'.format(datetime.now().strftime('%c')))
- # # Compute the height and width of the padded image
- # padding_h = (row // img_size + 1) *img_size
- # padding_w = (col // img_size + 1) *img_size
- # else:
- # print('{}: No need padding the predict image...'.format(datetime.now().strftime('%c')))
- # # Height and width of the image without padding
- # padding_h = (row // img_size) *img_size
- # padding_w = (col // img_size) *img_size
- #
- # # 初始化一个 0 矩阵,将图像的值赋值到 0 矩阵的对应位置
- # padding_img = np.zeros((padding_h, padding_w, dep), dtype='float32')
- # padding_img[:row, :col, :] = img_data[:row, :col, :]
- #
- # #初始化一个 0 矩阵,用于将预测结果的值赋值到 0 矩阵的对应位置
- # padding_pre = np.zeros((padding_h, padding_w), dtype='uint8')
- #
- # # 对 img_size * img_size 大小的图像进行预测
- # count = 0 # 用于计数
- # for i in list(np.arange(0, padding_h, img_size)):
- # if (i + img_size) > padding_h:
- # continue
- # for j in list(np.arange(0, padding_w, img_size)):
- # if (j + img_size) > padding_w:
- # continue
- #
- # # 取 img_size 大小的图像,在第一维添加维度,变成四维张量,用于模型预测
- # img_data_ = padding_img[i:i+img_size, j:j+img_size, :]
- # toTensor = transforms.ToTensor()
- # img_data_ = toTensor(img_data_)
- # img_data_ = img_data_[np.newaxis, :, :, :]
- # # img_data_ = np.transpose(img_data_, (0, 3, 1, 2))
- #
- # # 预测,对结果进行处理
- # y_pre = model.forward(img_data_)
- # # y_pre = model.predict(img_data_)
- # y_pre = np.squeeze(y_pre, axis = 0)
- # y_pre = torch.argmax(y_pre, axis = 0)
- # # y_pre = y_pre.astype('uint8')
- #
- # # 将预测结果的值赋值到 0 矩阵的对应位置
- # padding_pre[i:i+img_size, j:j+img_size] = y_pre[:img_size, :img_size]
- #
- # count += 1 # 每预测一块就+1
- #
- #
- # print('\r{}: Predited {:<5d}({:<5d})'.format(datetime.now().strftime('%c'), count, int((padding_h/img_size)*(padding_w/img_size))), end='')
- #
- # # 计算准确率
- # acc, y_pred = estimate(lab_data, padding_pre[:row, :col]+1)
- #
- # return acc, y_pred
-
# --- Parameters ---------------------------------------------------------------
# Tile images, their label files, the trained checkpoint, and the directory
# the stitched prediction is written to.
imagePath, labelPath = (
    r"D:\code\torch\Predict\image\train_image",
    r"D:\code\torch\Predict\image\train_label",
)
modelPath = r"D:\code\torch\Train\save_model\hrnet\270-0.04784.pth"
savePath = r"D:\code\torch\Predict\result\HRNet"

# --- Data ---------------------------------------------------------------------
# One sample per step, in dataset order (no shuffling); the stitching loop
# relies on batch_size=1.
dataset = MyDataset(imagePath, labelPath)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

# --- Model --------------------------------------------------------------------
# Weights are mapped onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model = HighResolutionNet()
model.load_state_dict(
    torch.load(modelPath, map_location=torch.device("cpu"))
)
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)
- #
-
# Run the model tile by tile, print per-tile pixel accuracy, and stitch the
# predicted class maps into one mosaic that is written under ``savePath``.
#
# NOTE(review): the tile size, two tiles per mosaic row and the stop-at-id-7
# rule assume the dataset yields eight 256x256 tiles whose ids 0..7 are laid
# out row-major, two per row -- confirm against MyDataset.
TILE_SIZE = 256
TILES_PER_ROW = 2
LAST_TILE_ID = 7

# (4 * 256) x (2 * 256) = 1024 x 512 canvas of predicted class indices.
# Kept float64 so the saved TIFF matches earlier outputs.
# NOTE(review): uint8 would be smaller if downstream readers accept it.
lab = np.zeros(
    ((LAST_TILE_ID // TILES_PER_ROW + 1) * TILE_SIZE, TILES_PER_ROW * TILE_SIZE)
)

# Inference only: no_grad() avoids building an autograd graph for every tile.
with torch.no_grad():
    for image, label, _image_down, tile_id in dataloader:
        # With batch_size=1 the collated id is a one-element tensor; use a
        # plain int for arithmetic and slicing (and avoid shadowing ``id``).
        tile_id = int(tile_id)
        print(tile_id)

        output = model(image)
        # Drop the batch axis, then take the highest-scoring class per pixel.
        y_pre = torch.argmax(output[0], dim=0).numpy()

        # Fraction of pixels whose predicted class equals the label.
        acc = np.mean(label.numpy() == y_pre)
        print(acc)

        # Even ids fill the left half of their mosaic row, odd ids the right.
        row, col = divmod(tile_id, TILES_PER_ROW)
        lab[row * TILE_SIZE:(row + 1) * TILE_SIZE,
            col * TILE_SIZE:(col + 1) * TILE_SIZE] = y_pre

        if tile_id == LAST_TILE_ID:
            # imwrite fails if the output directory does not exist yet.
            os.makedirs(savePath, exist_ok=True)
            path2 = os.path.join(savePath, "hrnet_data_label.tif")
            imageio.imwrite(path2, lab)
            print("save")
            break
    else:
        # The loader was exhausted before the last tile, so nothing was written.
        print("tile {} never reached; mosaic not saved".format(LAST_TILE_ID))
-
- # print(acc_all,acc_all/len(dataloader))
-
-
- # output = model.forward(image)
- # acc, output = model_predict(model, image, label, img_size=image_size)
- # output = output.numpy()
- # output = output.argmax(dim = 0)
- # print(f"准确率: {acc}")
- # imageio.imwrite( savePath,output)
-
-
-
-
-
|