|
- import os
- from mindspore import Model, load_checkpoint, load_param_into_net
- from mindspore import Tensor
- import mindspore.dataset.vision.py_transforms as py_transforms
- from utils import Adder
- from src.data.data_load import test_dataloader, train_dataloader
- #from skimage.metrics import peak_signal_noise_ratio
- import time
- from mindspore import context
- import mindspore.ops as ops
- import numpy as np
- import math
-
-
def _image_name(data):
    """Return the first entry of ``data["name"]`` as a ``str``.

    ``_eval`` iterates with batch size 1, so only element 0 is used. The
    container type of the "name" column is not visible here (it may be a
    MindSpore Tensor, a NumPy array, a sequence or a plain string), so it is
    normalized defensively; bytes are decoded as UTF-8.
    """
    name = data["name"]
    if hasattr(name, "asnumpy"):  # MindSpore Tensor -> NumPy array
        name = name.asnumpy()
    first = np.asarray(name).reshape(-1)[0]
    if isinstance(first, bytes):
        first = first.decode("utf-8")
    return str(first)


def _to_uint8_hwc(pred):
    """Convert a batched float prediction into a uint8 HWC array for PIL.

    Clips to [0, 1], adds ``0.5 / 255`` so the truncating uint8 cast rounds
    to nearest, drops the leading batch axis (batch size is 1 in ``_eval``)
    and moves channels last.
    """
    img = np.clip(pred, 0, 1) + 0.5 / 255
    img = np.squeeze(img, axis=0)
    # NOTE(review): assumes the network output is NCHW — TODO confirm.
    img = np.transpose(img, (1, 2, 0))
    if img.shape[-1] == 1:  # PIL expects (H, W) for single-channel images
        img = img[..., 0]
    return (img * 255).astype(np.uint8)


def _eval(network, args):
    """Evaluate ``network`` on the test set; print average PSNR and time.

    Loads the checkpoint ``args.test_model`` into ``network``, runs
    ``Model.predict`` on every sample produced by
    ``test_dataloader(args.data_dir, batch_size=1, ...)``, takes element
    ``[2]`` of the prediction, and accumulates PSNR against ``data["label"]``
    together with the wall-clock time of each ``predict`` call.

    Args:
        network: network to evaluate; its parameters are overwritten from
            ``args.test_model``.
        args: namespace with ``test_model``, ``data_dir``, ``save_image`` and
            ``result_dir``. Optional ``device_id`` (default 1) and
            ``device_target`` (default "Ascend") keep the previously
            hard-coded values when absent.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=getattr(args, "device_target", "Ascend"))
    context.set_context(device_id=int(getattr(args, "device_id", 1)))
    param_dict = load_checkpoint(args.test_model)
    load_param_into_net(network, param_dict)
    model = Model(network)

    dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=16)
    iterator = dataloader.create_dict_iterator()
    time_adder = Adder()
    psnr_adder = Adder()

    if args.save_image:
        # PIL's save() does not create missing parent directories.
        os.makedirs(args.result_dir, exist_ok=True)

    # NOTE(review): there is no warm-up pass; in GRAPH_MODE the first
    # predict() call likely includes graph compilation and inflates the
    # average time — confirm before comparing latencies.
    for data in iterator:
        input_img = data["input"]
        label_img = data["label"]

        start = time.perf_counter()
        pre = model.predict(input_img)[2]
        elapsed = time.perf_counter() - start
        time_adder(elapsed)

        pred_numpy = pre.asnumpy()
        label_numpy = label_img.asnumpy()

        if args.save_image:
            save_name = os.path.join(args.result_dir, _image_name(data))
            # ToPIL is a transform class: instantiate it, then apply it.
            pred = py_transforms.ToPIL()(_to_uint8_hwc(pred_numpy))
            pred.save(save_name)

        # NOTE(review): PSNR here uses its default peak of 255 on the
        # unclipped prediction; if the data lies in [0, 1] this overstates
        # PSNR — confirm the data range.
        psnr_adder(PSNR(pred_numpy, label_numpy))

    print('==========================================================')
    print('The average PSNR is %.2f dB' % (psnr_adder.average()))
    print("Average time: %f" % time_adder.average())
-
def PSNR(img1, img2, pixel_max=255):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Computes ``20 * log10(pixel_max / sqrt(MSE))`` where MSE is the mean
    squared difference over all elements.

    Args:
        img1: first image (array-like).
        img2: second image (array-like, same or broadcastable shape).
        pixel_max: largest possible pixel value. Defaults to 255, the
            previously hard-coded value; pass 1.0 for images in [0, 1].

    Returns:
        PSNR in dB as a float, or 100 when the images are identical (MSE of
        zero), preserving the original sentinel.
    """
    # Promote to float64 first: with uint8 inputs both the subtraction and
    # the squaring would otherwise wrap around modulo 256.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))
|