|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
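"""
Postprocess: load the inference result .bin files and compute the PSNR of each
predicted frame against its ground-truth frame from the validation dataset.
"""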
- from os import path as osp
- import argparse
- import ast
- import numpy as np
- from src.datasets.dataset import RBPNDatasetTest , create_val_dataset
- from mindspore import context
- from src.util.utils import save_img, save_losses, save_psnr, compute_psnr,PSNR
- import mindspore
-
def _str_to_bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"0" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy, so the flag could never be
    switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line options for the postprocess (PSNR evaluation) script.
parser = argparse.ArgumentParser('Postprocess')
# The help text previously claimed "default: 0" while the actual default is 6.
parser.add_argument("--device_id", type=int, default=6, help="device id, default: 6.")
parser.add_argument("--val_path", type=str, default=r'/mass_data/dataset/Vid4')
parser.add_argument('--upscale_factor', type=int, default=4, choices=[2, 4, 8],
                    help="Super resolution upscale factor")
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--save_eval_path', type=str, default="./Results/eval", help='save eval image path')
parser.add_argument('--data_dir', type=str, default=r'/mass_data/dataset/Vid4')
parser.add_argument('--file_list', type=str, default='foliage.txt')
# Boolean options are parsed with _str_to_bool so that "--flag False" really yields False.
parser.add_argument('--other_dataset', type=_str_to_bool, default=True,
                    help="use other dataset than vimeo-90k")
parser.add_argument('--future_frame', type=_str_to_bool, default=True, help="use future frame")
parser.add_argument('--nFrames', type=int, default=7)
parser.add_argument('--residual', type=_str_to_bool, default=False)
args = parser.parse_args()
# Fixed seed, as in the original script.
mindspore.set_seed(123)
-
def predict(prediction, target):
    """Compute, print and return the PSNR of one prediction against its target.

    Args:
        prediction: Batched prediction; only the first element along axis 0
            is evaluated. Presumably a float32 numpy array of shape
            (1, C, H, W) -- confirm against the caller.
        target: Ground-truth frame exposing ``squeeze()`` and ``asnumpy()``
            (presumably a MindSpore Tensor -- confirm against the dataset).

    Returns:
        The value returned by ``PSNR`` for this prediction/target pair.
    """
    # Drop the batch axis and rescale by 255 before measuring PSNR.
    prediction = prediction[0]
    prediction = prediction * 255.

    # Bring the target to a float32 numpy array on the same scale.
    target = target.squeeze().asnumpy().astype(np.float32)
    target = target * 255.

    # The border shaved before comparison is tied to the upscale factor.
    psnr_predicted = PSNR(prediction, target, shave_border=args.upscale_factor)
    print("psnr:", psnr_predicted)

    # The caller stores and formats this value ("{:.4f}"), so it must be
    # returned; previously the function implicitly returned None.
    return psnr_predicted
-
- #
- # if __name__ == '__main__':
- # rst_path = "./result_Files"
- # object_imageSize = 800
- # dataset = RBPNDatasetTest(args.val_GT_path, args.val_LR_path, args)
- # ds = create_val_dataset(dataset, args)
- # psnr_list = []
- # for i, data in enumerate(ds.create_dict_iterator(output_numpy=True)):
- # gt = data['target_image']
- # gt_img = np.squeeze(gt, axis=0)
- # file_name = osp.join(rst_path, "DBPN_data_bs" + str(args.testBatchSize) + '_' + str(i) + '_0.bin')
- # output = np.fromfile(file_name, np.float32).reshape(3, object_imageSize, object_imageSize)
- # sr_img = unpadding(output, gt_img.shape)
- # save_img_path = osp.join('./310_infer_img', 'SR')
- # save_img(sr_img, str(i), save_img_path)
- # save_img_path = osp.join('./310_infer_img', 'HR')
- # save_img(gt_img, str(i), save_img_path)
- # cur_psnr = compute_psnr(gt_img, sr_img)
- # psnr_list.append(cur_psnr)
- # print("===> Processing: {} compute_psnr:{:.4f}.".format(i, cur_psnr))
- # psnr_mean = np.mean(psnr_list)
- # print("val ending psnr = ", np.mean(psnr_list))
- # print("Generate images success!")
-
-
if __name__ == "__main__":
    # Directory holding the inference outputs, one .bin file per sample.
    rst_path = "./result_Files"
    val_dataset = RBPNDatasetTest(args.val_path, args.nFrames, args.upscale_factor, args.file_list,
                                  args.other_dataset, args.future_frame)
    val_ds = create_val_dataset(val_dataset, args)

    # Output folders, only needed when the save_img calls below are re-enabled.
    save_pre_path = osp.join('./310_infer_image', 'prediction')
    save_gt_path = osp.join('./310_infer_image', 'gt')
    psnr_list = []

    # The execution context does not change between samples, so configure it
    # once before the loop instead of on every iteration.
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id)

    # Sample indices start at 1 to match the numbering used in the .bin file names.
    for i, data in enumerate(val_ds.create_dict_iterator(), 1):
        target = data['target_image']
        file_name = osp.join(rst_path, "RBPN_data_x" + str(args.testBatchSize) + '_' + str(i) + '_0.bin')
        # Load the raw binary prediction; the dtype and shape must match what
        # was written when the file was saved.
        # NOTE(review): the shape is hard-coded to 480x720 frames -- confirm it
        # matches the clip selected through --file_list.
        prediction = np.fromfile(file_name, np.float32).reshape(1, 3, 480, 720)

        # save_img(prediction, str(i), save_pre_path)
        # save_img(target, str(i), save_gt_path)

        psnr = predict(prediction, target)
        psnr_list.append(psnr)
        print("===> Processing: {} compute_psnr:{:.4f}.".format(i, psnr))

    # Average once and reuse it; an empty dataset reports nan without
    # triggering numpy's "mean of empty slice" warning.
    psnr_mean = np.mean(psnr_list) if psnr_list else float("nan")
    print("val ending psnr = ", psnr_mean)
    print("Generate images success!")
-
-
-
-
-
-
|