|
- import os
- # import torch
- # import logging
- # import cv2
- # from PIL import Image
- import imageio.v2 as imageio
- import numpy as np
- # import torch.utils.data as data
- # from os.path import join, exists
- import math
- # import random
- # import sys
- # import json
- # import random
- # from subnet.basics import *
- from subnet.ms_ssim_mindspore import ms_ssim
- from augmentation import random_flip_np, random_crop_and_pad_image_and_labels_np
- from mindspore import Tensor
- # import mindspore.ops as ops
- # import mindspore as ms
-
# Maps the short UVG sequence name (as listed in the sequence file) to the
# directory name of the corresponding cropped 1920x1024 video, which is where
# UVGDataSet looks up the coded reference frames.
name_ref = dict(
    Beauty="Beauty_1920x1024_120fps_420_8bit_YUV",
    HoneyBee="HoneyBee_1920x1024_120fps_420_8bit_YUV",
    ReadySteadyGo="ReadySteadyGo_1920x1024_120fps_420_8bit_YUV",
    YachtRide="YachtRide_1920x1024_120fps_420_8bit_YUV",
    Bosphorus="Bosphorus_1920x1024_120fps_420_8bit_YUV",
    Jockey="Jockey_1920x1024_120fps_420_8bit_YUV",
    ShakeNDry="ShakeNDry_1920x1024_120fps_420_8bit_YUV",
)

# Channel counts that size the quantisation-noise arrays produced by
# DataSet.__getitem__ (z noise, feature noise and motion-vector noise).
out_channel_N = 64
out_channel_M = 96
out_channel_mv = 128
-
def CalcuPSNR(target, ref, data_range=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Args:
        target: Array-like image.
        ref: Array-like image, broadcast-compatible with ``target``.
        data_range: Peak signal value. Defaults to 1.0, matching images that
            were scaled into [0, 1].

    Returns:
        ``20 * log10(data_range / rmse)`` as a Python float, or ``inf`` when
        the two inputs are identical (zero error).
    """
    # Promote to float64 before subtracting so unsigned-integer inputs
    # (e.g. raw uint8 frames) cannot wrap around.
    diff = np.asarray(ref, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = float(np.mean(np.square(diff)))
    if mse == 0.0:
        # Identical inputs: avoid dividing by a zero RMSE.
        return float('inf')
    return 20 * math.log10(data_range / math.sqrt(mse))
-
class UVGDataSet():
    """UVG evaluation set split into 12-frame groups.

    Each item pairs one coded reference frame (read from ``refdir``) with the
    original frames of the same 12-frame group. The reference frame's quality
    (PSNR and MS-SSIM against the first original frame) is returned alongside
    the per-sequence bits-per-pixel figure from :meth:`getbpp`.

    Attributes set by ``__init__``:
        refmsssim: dict cache of reference MS-SSIM values, keyed by a string
            built from the first-frame path and the reference-frame path.
        save: True when ``refmsssim`` holds entries that were not loaded from
            ``save_refmsssim`` (the caller may want to persist the dict).
        ref / refbpp / input: parallel lists of reference-frame path,
            reference bpp and list of original-frame paths, one per group.
    """

    def __init__(self, root="/userhome/DVC/PyTorch/data/UVG/images/",
                 filelist="/userhome/DVC/PyTorch/data/UVG/originalv.txt", refdir='H265L23',
                 testfull=False, save_refmsssim='/userhome/DVC/MindSpore/refmsssim.npy',
                 ref_root='/userhome/DVC/PyTorch/data/UVG/videos_crop_1920x1024/'):
        """Index the reference frames and original frames on disk.

        Args:
            root: Directory holding one sub-directory of ``imNNN.png`` frames
                per sequence.
            filelist: Text file with one sequence name per line.
            refdir: Name of the coded-reference sub-directory; also selects
                the bpp table in :meth:`getbpp`.
            testfull: When True use every complete 12-frame group of each
                sequence, otherwise only the first group.
            save_refmsssim: Path of an ``.npy`` file holding a pickled dict of
                cached reference MS-SSIM values; loaded if it exists.
            ref_root: Directory holding the per-sequence video folders (named
                via ``name_ref``) that contain ``refdir``.

        Raises:
            IndexError: If ``filelist`` names more sequences than
                :meth:`getbpp` provides bpp values for.
        """
        if os.path.exists(save_refmsssim):
            # NOTE(security): allow_pickle=True unpickles the file's contents;
            # only point this at a cache file you trust.
            self.refmsssim = np.load(save_refmsssim, allow_pickle=True).item()
            self.save = False
        else:
            self.refmsssim = {}
            self.save = True

        with open(filelist) as f:
            # Ignore blank lines (typically a trailing newline): they would
            # otherwise be treated as a sequence named ''.
            sequences = [name for name in (line.rstrip() for line in f) if name]

        self.ref = []
        self.refbpp = []
        self.input = []
        self.hevcclass = []

        all_ibpp = self.getbpp(refdir)
        if len(sequences) > len(all_ibpp):
            raise IndexError(
                'filelist names %d sequences but only %d bpp values are defined for %s'
                % (len(sequences), len(all_ibpp), refdir))

        # bpp values are positional: the i-th value belongs to the i-th sequence.
        for seq, seq_ibpp in zip(sequences, all_ibpp):
            png_count = sum(1 for name in os.listdir(os.path.join(root, seq))
                            if name.endswith('.png'))
            group_count = png_count // 12 if testfull else 1
            for group in range(group_count):
                first = group * 12 + 1
                # Reference frames use 4-digit numbering, original frames 3-digit.
                refpath = os.path.join(ref_root, name_ref[seq], refdir,
                                       'im' + str(first).zfill(4) + '.png')
                if not os.path.exists(refpath):
                    continue
                candidates = (os.path.join(root, seq, 'im' + str(first + j).zfill(3) + '.png')
                              for j in range(12))
                self.input.append([path for path in candidates if os.path.exists(path)])
                self.ref.append(refpath)
                self.refbpp.append(seq_ibpp)

    def getbpp(self, ref_i_folder):
        """Return the reference-frame bpp list for the given reference folder.

        Values are ordered Beauty, HoneyBee, ReadySteadyGo, YachtRide,
        Bosphorus, Jockey, ShakeNDry (the order expected in the sequence
        list file). Tables other than ``'H265L23'`` are empty and must be
        filled in after generating the corresponding reference frames.

        Raises:
            SystemExit: With status 1 when ``ref_i_folder`` is unknown or its
                table is still empty.
        """
        bpp_tables = {
            'H265L20': [],  # fill in after generating crf=20
            'H265L23': [0.724, 0.567, 0.550, 0.501, 0.471, 0.360, 0.581],  # crf=23
            'H265L26': [],  # fill in after generating crf=26
            'H265L29': [],  # fill in after generating crf=29
        }
        if ref_i_folder not in bpp_tables:
            print('cannot find ref : ', ref_i_folder)
            # Exit with a failing status; the bare exit() helper reports success.
            raise SystemExit(1)
        print('use ' + ref_i_folder)
        Ibpp = bpp_tables[ref_i_folder]
        if len(Ibpp) == 0:
            print('You need to generate I frames and fill the bpps above!')
            raise SystemExit(1)
        return Ibpp

    def __len__(self):
        """Return the number of indexed 12-frame groups."""
        return len(self.ref)

    def __getitem__(self, index):
        """Load one group.

        Returns:
            Tuple ``(input_images, ref_image, refbpp, refpsnr, refmsssim)``:
            ``input_images`` stacks every original frame of the group except
            the first (CHW, float32 in [0, 1], cropped to multiples of 64);
            ``ref_image`` is the coded reference frame in the same format;
            ``refpsnr`` / ``refmsssim`` compare the first original frame with
            ``ref_image`` (both ``None`` if the group has no original frames).
            ``refmsssim`` is always a Python float.
        """
        ref_image = imageio.imread(self.ref[index]).transpose(2, 0, 1).astype(np.float32) / 255.0
        # Crop height and width down to multiples of 64.
        h = (ref_image.shape[1] // 64) * 64
        w = (ref_image.shape[2] // 64) * 64
        ref_image = np.array(ref_image[:, :h, :w])  # CHW

        input_images = []
        refpsnr = None
        refmsssim = None
        for filename in self.input[index]:
            input_image = (imageio.imread(filename).transpose(2, 0, 1)[:, :h, :w]
                           ).astype(np.float32) / 255.0  # CHW
            if refpsnr is not None:
                input_images.append(input_image)
                continue

            # First frame of the group: only used to score the reference frame.
            refpsnr = CalcuPSNR(input_image, ref_image)
            key = '_'.join(filename.split('/')[-2:]) + '-' + '_'.join(self.ref[index].split('/')[-3:])
            key = key.replace('.png', '')
            if key not in self.refmsssim:
                # Not cached (no cache file, or the file lacks this entry):
                # compute it once and remember that the cache changed.
                score = ms_ssim(Tensor.from_numpy(input_image[np.newaxis, :]),
                                Tensor.from_numpy(ref_image[np.newaxis, :]),
                                data_range=1.0).asnumpy()
                self.refmsssim[key] = score.item()
                self.save = True
            refmsssim = self.refmsssim[key]

        input_images = np.array(input_images)
        return input_images, ref_image, self.refbpp[index], refpsnr, refmsssim
-
class DataSet():
    """Frame pairs from a Vimeo-septuplet style listing.

    Each item is a frame together with the frame two indices earlier in the
    same clip, randomly cropped/padded and flipped, plus three arrays of
    uniform noise in [-0.5, 0.5).
    """

    def __init__(self, rootdir="/mnt/cloud_disk/ssk/data/vimeo_septuplet/sequences/",
                 test_txt="/mnt/cloud_disk/ssk/data/vimeo_septuplet/test.txt", im_height=256, im_width=256):
        """Build the list of (input, reference) frame paths.

        Args:
            rootdir: Directory the paths in ``test_txt`` are relative to.
            test_txt: Text file with one frame path per line.
            im_height: Crop height passed to the augmentation.
            im_width: Crop width passed to the augmentation.
        """
        self.image_input_list, self.image_ref_list = self.get_vimeo(rootdir=rootdir, filefolderlist=test_txt)
        self.im_height = im_height
        self.im_width = im_width
        print("dataset find image: ", len(self.image_input_list))

    def get_vimeo(self, rootdir="/mnt/cloud_disk/ssk/data/vimeo_septuplet/sequences/",
                  filefolderlist="/mnt/cloud_disk/ssk/data/vimeo_septuplet/test.txt"):
        """Return ``(input_paths, ref_paths)`` for every usable listed frame.

        A listed frame ``...imK.png`` is kept only if both it and
        ``...im(K-2).png`` exist on disk. The frame number is read from the
        single character before ``.png``. Blank lines are ignored.
        """
        with open(filefolderlist) as f:
            entries = [line.rstrip() for line in f]

        fns_train_input = []
        fns_train_ref = []
        for entry in entries:
            if not entry:
                # A blank line would join to rootdir itself, which exists and
                # would then fail the frame-number parse below.
                continue
            y = os.path.join(rootdir, entry)
            if not os.path.exists(y):
                continue
            # Images are not opened here to verify they decode; per the
            # original author's note, the corrupt files were replaced.
            refnumber = int(y[-5:-4]) - 2
            refname = y[0:-5] + str(refnumber) + '.png'
            if os.path.exists(refname):
                fns_train_input.append(y)
                fns_train_ref.append(refname)

        return fns_train_input, fns_train_ref

    def __len__(self):
        """Return the number of frame pairs."""
        return len(self.image_input_list)

    def __getitem__(self, index):
        """Load, augment and return one frame pair with fresh noise arrays.

        Returns:
            ``(input_image, ref_image, quant_noise_feature, quant_noise_z,
            quant_noise_mv)``. The images are float32 CHW arrays scaled by
            1/255 before augmentation; the noise arrays are float32 with
            spatial sizes ``im // 16``, ``im // 64`` and ``im // 16``.
        """
        input_image = imageio.imread(self.image_input_list[index]).astype(np.float32) / 255.0
        ref_image = imageio.imread(self.image_ref_list[index]).astype(np.float32) / 255.0

        # HWC -> CHW
        input_image = input_image.transpose(2, 0, 1)
        ref_image = ref_image.transpose(2, 0, 1)

        input_image, ref_image = random_crop_and_pad_image_and_labels_np(
            input_image, ref_image, [self.im_height, self.im_width])
        input_image, ref_image = random_flip_np(input_image, ref_image)

        h16, w16 = self.im_height // 16, self.im_width // 16
        h64, w64 = self.im_height // 64, self.im_width // 64
        # Drawn in the order feature, z, mv.
        quant_noise_feature = np.random.uniform(-0.5, 0.5, (out_channel_M, h16, w16)).astype(np.float32)
        quant_noise_z = np.random.uniform(-0.5, 0.5, (out_channel_N, h64, w64)).astype(np.float32)
        quant_noise_mv = np.random.uniform(-0.5, 0.5, (out_channel_mv, h16, w16)).astype(np.float32)

        return input_image, ref_image, quant_noise_feature, quant_noise_z, quant_noise_mv
-
if __name__ == "__main__":
    # Smoke test: walk every 12-frame group of the UVG set and print the
    # reference MS-SSIM (element 4 of each item).
    test_dataset = UVGDataSet(testfull=True)
    print(len(test_dataset))
    # `sample` rather than `input`, so the builtin is not shadowed.
    for batch_idx, sample in enumerate(test_dataset):
        print(batch_idx)
        print(sample[4])
    # To persist the MS-SSIM values gathered above, uncomment:
    # print(test_dataset.refmsssim)
    # np.save('/userhome/DVC/MindSpore/refmsssim.npy', test_dataset.refmsssim)
|