|
- import argparse
- import cv2
- import glob
- import os
- import shutil
- import torch
-
- from basicsr.archs.basicvsr_arch import BasicVSR
- from basicsr.data.data_util import read_img_seq
- from basicsr.utils.img_util import tensor2img
-
-
def inference(imgs, imgnames, model, save_path):
    """Run the model on a clip of frames and save each restored frame as PNG.

    Args:
        imgs (Tensor): input frames; assumed shape (1, t, c, h, w) — the batch
            dim is added by the caller via ``unsqueeze(0)``; TODO confirm.
        imgnames (list[str]): basenames used for the output filenames, one per
            frame, zipped against the model outputs.
        model: BasicVSR network, already in eval mode on the target device.
        save_path (str): existing directory where '<name>_BasicVSR.png' files
            are written via ``cv2.imwrite``.
    """
    with torch.no_grad():
        outputs = model(imgs)
    # save imgs
    # BUGFIX: use squeeze(0) rather than a bare squeeze(). squeeze() removes
    # ALL size-1 dims, so for a single-frame clip (t == 1) it also dropped the
    # temporal dim and list(outputs) then iterated over channels, not frames.
    outputs = outputs.squeeze(0)
    outputs = list(outputs)
    for output, imgname in zip(outputs, imgnames):
        output = tensor2img(output)
        cv2.imwrite(os.path.join(save_path, f'{imgname}_BasicVSR.png'), output)
-
-
def main():
    """Run BasicVSR inference on a folder of frames or on a video file.

    Command-line arguments:
        --model_path: pretrained BasicVSR checkpoint ('params' key is loaded).
        --input_path: input frame folder; if it is not a directory it is
            treated as a video file and frames are extracted with ffmpeg into
            a temporary folder (deleted afterwards).
        --save_path: output directory for restored frames (created if absent).
        --interval: maximum number of frames fed to the model per forward pass
            (limits GPU memory usage).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSR_REDS4.pth')
    parser.add_argument(
        '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
    parser.add_argument('--save_path', type=str, default='results/BasicVSR', help='save image path')
    parser.add_argument('--interval', type=int, default=15, help='interval size')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model
    model = BasicVSR(num_feat=64, num_block=30)
    model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
    model.eval()
    model = model.to(device)

    os.makedirs(args.save_path, exist_ok=True)

    # extract images from video format files
    input_path = args.input_path
    use_ffmpeg = False
    if not os.path.isdir(input_path):
        use_ffmpeg = True
        video_name = os.path.splitext(os.path.split(args.input_path)[-1])[0]
        input_path = os.path.join('./BasicVSR_tmp', video_name)
        os.makedirs(input_path, exist_ok=True)
        # BUGFIX: the output pattern must live inside input_path. The previous
        # command had a stray space ('{input_path} /frame%08d.png'), so ffmpeg
        # wrote frames to '/frame%08d.png' at the filesystem root and the glob
        # below found nothing.
        # NOTE(review): args.input_path is interpolated unquoted into a shell
        # string — paths with spaces or shell metacharacters will break;
        # consider subprocess.run([...], shell=False) with an argument list.
        os.system(f'ffmpeg -i {args.input_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {input_path}/frame%08d.png')

    # load data and inference
    imgs_list = sorted(glob.glob(os.path.join(input_path, '*')))
    num_imgs = len(imgs_list)
    if num_imgs <= args.interval:  # too many images may cause CUDA out of memory
        imgs, imgnames = read_img_seq(imgs_list, return_imgname=True)
        imgs = imgs.unsqueeze(0).to(device)
        inference(imgs, imgnames, model, args.save_path)
    else:
        # process the sequence in chunks of at most `interval` frames
        for idx in range(0, num_imgs, args.interval):
            interval = min(args.interval, num_imgs - idx)
            imgs, imgnames = read_img_seq(imgs_list[idx:idx + interval], return_imgname=True)
            imgs = imgs.unsqueeze(0).to(device)
            inference(imgs, imgnames, model, args.save_path)

    # delete ffmpeg output images
    if use_ffmpeg:
        shutil.rmtree(input_path)
-
-
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()
|