|
- # -*- coding: utf-8 -*-
- """
- Created on Thu Jul 28 09:00:44 2022
-
- @author: sunhuan
- """
-
- # ===================================================================================================
- import torch
- from Model.unet import UNet # ,model_psp, model_Deeplabv3P , , mdoel_segnet, model_Deeplabv3P,, model_unet2 , model_enet
- import numpy as np
- import imageio
- import cv2
- from datetime import datetime
- import os
- import gdal
- import sys
- import xml.etree.ElementTree as ET
- from pathlib import Path
- # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
- # ===================================================================================================
-
- # ===================================================================================================
-
-
- import logging
- from logging import handlers
-
- class Logger(object):
- level_relations = {
- 'debug': logging.DEBUG,
- 'info': logging.INFO,
- 'warning': logging.WARNING,
- 'error': logging.ERROR,
- 'crit': logging.CRITICAL
- } # 日志级别关系映射
-
- def __init__(self, filename, level='info', when='D', backCount=3,
- fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
- self.logger = logging.getLogger(filename)
- format_str = logging.Formatter(fmt) # 设置日志格式
- self.logger.setLevel(self.level_relations.get(level)) # 设置日志级别
- sh = logging.StreamHandler() # 往屏幕上输出
- sh.setFormatter(format_str) # 设置屏幕上显示的格式
- th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
- encoding='utf-8') # 往文件里写入#指定间隔时间自动生成文件的处理器
- # 实例化TimedRotatingFileHandler
- # interval是时间间隔,backupCount是备份文件的个数,如果超过这个个数,就会自动删除,when是间隔的时间单位,单位有以下几种:
- # S 秒
- # M 分
- # H 小时、
- # D 天、
- # W 每星期(interval==0时代表星期一)
- # midnight 每天凌晨
- th.setFormatter(format_str) # 设置文件里写入的格式
- self.logger.addHandler(sh) # 把对象加到logger里
- self.logger.addHandler(th)
-
-
- def makedirs(dir_path):
- """
- 本函数实现以下功能:
- 1、创建路径文件夹
- """
- if not os.path.exists(dir_path):
- os.makedirs(dir_path)
- # print('{}: Folder creation successful: {}'.format(datetime.now().strftime('%c'), dir_path))
- else:
-
- print('{}: Folder already exists: {}'.format(datetime.now().strftime('%c'), dir_path))
-
- return dir_path
-
-
- # ===================================================================================================
-
- def estimate(y_label, y_pred):
- """
- 本函数实现以下功能:
- 1、掩膜
- 2、计算准确率
- """
- # 掩膜
- # y_pred[y_label==0]=0
-
- # 准确率
- acc = np.mean(np.equal(y_label, y_pred) + 0)
-
- return acc, y_pred
-
-
- # ===================================================================================================
-
- def model_predict(model, img_data, img_size):
- """
- 本函数实现以下功能:
- 1、对于一幅高宽较大的图像,实现分块预测,每块的大小是参数 img_size
-
- @parameter:
- model: 模型参数
- img_data:需要预测的图像数据
- lab_data:需要育德的图像的标签
- img_size:预测图像块的大小(不等于 img_data 的大小)
- """
- # 获取预测图像的 shape
- row, col, dep = img_data.shape
-
- # 为了查看信息,没什么用
- if row % img_size != 0 or col % img_size != 0:
- # print('{}: Need padding the predict image...'.format(datetime.now().strftime('%c')))
- # 计算填充后图像的 hight 和 width
- padding_h = (row // img_size + 1) * img_size
- padding_w = (col // img_size + 1) * img_size
- else:
- # print('{}: No need padding the predict image...'.format(datetime.now().strftime('%c')))
- # 不填充后图像的 hight 和 width
- padding_h = (row // img_size) * img_size
- padding_w = (col // img_size) * img_size
-
- # 初始化一个 0 矩阵,将图像的值赋值到 0 矩阵的对应位置
- padding_img = np.zeros((padding_h, padding_w, dep), dtype='float32')
- padding_img[:row, :col, :] = img_data[:row, :col, :]
-
- # 初始化一个 0 矩阵,用于将预测结果的值赋值到 0 矩阵的对应位置
- padding_pre = np.zeros((padding_h, padding_w), dtype='uint8')
-
- # 对 img_size * img_size 大小的图像进行预测
- count = 0 # 用于计数
- for i in list(np.arange(0, padding_h, img_size)):
- if (i + img_size) > padding_h:
- continue
- for j in list(np.arange(0, padding_w, img_size)):
- if (j + img_size) > padding_w:
- continue
-
- # 取 img_size 大小的图像,在第一维添加维度,变成四维张量,用于模型预测
- img_data_ = padding_img[i:i + img_size, j:j + img_size, :]
- img_data_ = img_data_[np.newaxis, :, :, :]
- img_data_ = np.transpose(img_data_, (0, 3, 1, 2))
- img_data_ = torch.Tensor(img_data_)
-
- # 预测,对结果进行处理
- y_pre = model(img_data_)
- # y_pre = model.predict(img_data_)
- y_pre = np.squeeze(y_pre, axis=0)
- y_pre = torch.argmax(y_pre, axis=0)
-
- # 将预测结果的值赋值到 0 矩阵的对应位置
- padding_pre[i:i + img_size, j:j + img_size] = y_pre[:img_size, :img_size]
-
- count += 1 # 每预测一块就+1
- log.logger.info(f'正在处理{count}/{int((padding_h / img_size) * (padding_w / img_size))}')
-
-
- # print('\r{}: Predited {:<5d}({:<5d})'.format(datetime.now().strftime('%c'), count, int((padding_h / img_size) * (padding_w / img_size))), end='')
-
- # 计算准确率
- #acc, y_pred = estimate(lab_data, padding_pre[:row, :col] + 1)
- y_pred = padding_pre[:row, :col] + 1
-
- return y_pred
-
-
- # =========================================================================================
- # def add_proj(img_path,save_path,arr):
- # # filename = r"C:\Users\Administrator\Desktop\GF项目\test\landsat\new_orson_B2345.tif"
- # # 无人机 "C:\Users\Administrator\Desktop\GF项目\wurenji.tif"
- # # result = r"C:\Users\Administrator\Desktop\GF项目\result1.tif"
- # # arr = imageio.imread(result)
- #
- # # 打开图像并创建空间
- # data = gdal.Open(img_path)
- # driver = gdal.GetDriverByName('GTiff')
- #
- # out_tif = driver.Create(save_path, arr.shape[1], arr.shape[0], 1,
- # gdal.GDT_Float32)
- # out_tif.GetRasterBand(1).WriteArray(arr)
- #
- # # 数据集的基本信息
- # # print('Raster Driver : {d}\n'.format(d=driver.ShortName))
- # # print('影像的波段数: ', data.RasterCount)
- # img_width, img_height = data.RasterXSize, data.RasterYSize
- # # print('影像的列,行数: {r}rows * {c}colums'.format(r=img_width, c=img_height))
- #
- # transform = data.GetGeoTransform()
- # if transform:
- # # print('栅格数据的空间参考:{}'.format(transform))
- # # print('影像分辨率:{}m'.format(transform[1]))
- # # 写入空间参考信息
- # out_tif.SetGeoTransform(transform)
- # else:
- # print('未识别出空间参考信息!')
- #
- # # print('栅格数据的空间参考:{}'.format(data.GetGeoTransform())) # 栅格数据的6参数
- #
- # proj = data.GetProjection()
- # if proj:
- # # print('投影信息:{}\n'.format(data.GetProjection())) # 栅格数据的投影
- # # 写入坐标投影
- # out_tif.SetProjection(proj)
- # else:
- # print('未识别出投影信息!')
- #
- # out_tif.FlushCache()
- # del out_tif
-
- # def arr2raster(arr, out_dir, prj=None, trans=None):
- # """
- # 将数组转成栅格文件写入硬盘
- # :param arr: 输入的mask数组 ReadAsArray()
- # :param out_dir: 输出的栅格文件路径
- # :param prj: gdal读取的投影信息 GetProjection(),默认为空
- # :param trans: gdal读取的几何信息 GetGeoTransform(),默认为空
- # :return:
- # """
- # driver = gdal.GetDriverByName('GTiff')
- # out_tif = driver.Create(out_dir, arr.shape[1], arr.shape[0], 1, gdal.GDT_Float32)
- # if prj:
- # out_tif.SetProjection(prj)
- # if trans:
- # out_tif.SetGeoTransform(trans)
- # # 将数组的各通道写入图片
- # out_tif.GetRasterBand(1).WriteArray(arr)
- # out_tif.FlushCache() # 将tif写入硬盘
- # out_tif = None
- # # print("保存tif成功!")
- # =========================================================================================
-
- if __name__ == '__main__':
- """
- 主函数
- """
- """
- 加载图像信息
- """
- img_size = 256
- LR = 1e-4
- input_sizes = (img_size, img_size, 4)
-
- xmlpath = r"E:\yqj\code\GF\coral\test.xml" #sys.argv[1] #
- tree = ET.parse(xmlpath)
- root = tree.getroot()
-
-
- path =root[1][0][2][0][1][0].text# r"E:\yqj\code\GF\Code\lable\langhuajiao\image\image1.tif"
- name = Path(path).stem
- work_path = root[1][0][3][0][0].text # r"E:\yqj\code\GF\Code\lable\langhuajiao\image\DYT.tif" #
- save_path = work_path + "\\" + name +".tif" # r"E:\yqj\code\GF\Code\lable\langhuajiao\image\DYT.tif" #
- model_name = 'UNet'
- classes = 16
- model_path = r"E:\yqj\try\code\torch\GF2\50-0.78368.pth" #
- # 日志文件
- log_path = work_path + "\\" + name + ".log"
- if os.path.exists(work_path) == False:
- os.makedirs(work_path)
-
- f = open(log_path, 'w')
- f.close()
- log = Logger(log_path, level='debug')
- log.logger.info("Start !")
-
- # txtpath = Path(save_path).with_suffix('.txt')
-
-
-
- # 加载预测图像
- if not os.path.exists(path):
- print('{}: Do not find the image: {}'.format(datetime.now().strftime('%c'), path))
- img_data = imageio.imread(path)
-
- #最大最小归一化
- B1, B2, B3, B4 = cv2.split(img_data)
- B1_normalization = ((B1 - np.min(B1)) / (np.max(B1) - np.min(B1)) * 1).astype('float32')
- B2_normalization = ((B2 - np.min(B2)) / (np.max(B2) - np.min(B2)) * 1).astype('float32')
- B3_normalization = ((B3 - np.min(B3)) / (np.max(B3) - np.min(B3)) * 1).astype('float32')
- B4_normalization = ((B4 - np.min(B4)) / (np.max(B4) - np.min(B4)) * 1).astype('float32')
- image_new_data = cv2.merge([B1_normalization, B2_normalization, B3_normalization, B4_normalization])
-
- """
- 加载模型信息
- """
-
- # ===================================================================================================
-
-
- # 加载模型参数
- model = UNet(num_classes=16)
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
- model.eval()
-
-
- # 预测结果
- predict_result = model_predict(model, image_new_data, img_size)
- values = np.unique(predict_result)
- # np.savetxt(txtpath, values, fmt='%d', delimiter=',')
- # file = open(txtpath, "w", encoding='utf-8')
- # file.write(str(values) + ' ')
- # file.close()
-
- dataset = gdal.Open(path) # 提供地理坐标信息和几何信息的栅格底图
- projection = dataset.GetProjection()
- transform = dataset.GetGeoTransform()
-
- arr2raster(predict_result, save_path, prj=projection, trans=transform)
- #
- # # 保存预测结果
- # add_proj(path, save_path, predict_result)
- imageio.imwrite(save_path,predict_result)
- # print('{}: Predict the success of image {}\n'.format(datetime.now().strftime('%c'), save_path))
- log.logger.info("Finish !")
-
- # =========================================================================================
|