|
import sys
sys.path.append('core')

import argparse
import os

import imageio
import numpy as np
import tqdm
from PIL import Image

import mindspore as ms
import mindspore.numpy as msnp
import mindspore.ops as ops
from mindspore import Tensor, context, load_checkpoint, load_param_into_net

import core.ms_datasets as ms_datasets
from core.ms_datasets import *
from core.network import RAFTGMA
from core.utils import flow_viz
from core.utils import frame_utils
from core.utils.ms_utils import InputPadder
# from core.utils.ms_utils import load_pytorch_state_dict
# import torch
-
def validate_chairs(model, iters=6):
    """Perform evaluation on the FlyingChairs (validation) split.

    Args:
        model: RAFTGMA network, called as ``model(img1, img2, iters=..., test_mode=True)``
            and returning ``(_, _, flow_pr)`` where ``flow_pr`` is a MindSpore tensor.
        iters (int): number of flow refinement iterations.

    Returns:
        dict: ``{'chairs_epe': mean end-point error over all pixels}``.
    """
    model.set_train(mode=False)
    epe_list = []

    val_dataset = ms_datasets.FlyingChairs(split='validation')
    for val_id in tqdm.tqdm(range(len(val_dataset))):
        image1, image2, flow_gt, _ = val_dataset[val_id]
        # Add a leading batch dimension.
        image1 = image1[None]
        image2 = image2[None]
        # ndarray -> MindSpore tensor
        image1 = ms.Tensor(image1).astype(ms.float32)
        image2 = ms.Tensor(image2).astype(ms.float32)
        _, _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        # BUG FIX: the original used torch.from_numpy/torch.sum, but `import torch`
        # is commented out in this file -> NameError. Compute the per-pixel
        # end-point error in NumPy instead, matching validate_sintel/validate_kitti.
        flow_pr = flow_pr.asnumpy()
        epe = np.sqrt(np.sum((flow_pr[0] - flow_gt) ** 2, axis=0))
        epe_list.append(epe.reshape(-1))

    epe = np.mean(np.concatenate(epe_list))
    print("Validation Chairs EPE: %f" % epe)
    return {'chairs_epe': epe}
-
def validate_things(model, iters=6):
    """Perform evaluation on the FlyingThings3D (validation) split.

    Evaluates both the clean and final render passes.

    Args:
        model: RAFTGMA network, called as ``model(img1, img2, iters=..., test_mode=True)``
            and returning ``(_, _, flow_pr)``.
        iters (int): number of flow refinement iterations.

    Returns:
        dict: mean end-point error keyed by dstype
            ('frames_cleanpass' / 'frames_finalpass').
    """
    model.set_train(mode=False)
    results = {}

    for dstype in ['frames_cleanpass', 'frames_finalpass']:
        epe_list = []
        val_dataset = ms_datasets.FlyingThings3D(dstype=dstype, split='validation')
        print(f'Dataset length {len(val_dataset)}')
        for val_id in tqdm.tqdm(range(len(val_dataset))):
            image1, image2, flow_gt, _ = val_dataset[val_id]
            image1 = image1[None]
            image2 = image2[None]

            # Pad to a size the network accepts, then unpad the prediction.
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)
            image1 = ms.Tensor(image1).astype(ms.float32)
            image2 = ms.Tensor(image2).astype(ms.float32)
            _, _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            # BUG FIX: the original used torch.from_numpy/.cpu()/torch.sum, but
            # `import torch` is commented out in this file -> NameError.
            # Use NumPy, matching validate_sintel/validate_kitti.
            flow = padder.unpad(flow_pr.asnumpy()[0])

            epe = np.sqrt(np.sum((flow - flow_gt) ** 2, axis=0))
            epe_list.append(epe.reshape(-1))

        epe_all = np.concatenate(epe_list)

        epe = np.mean(epe_all)
        px1 = np.mean(epe_all < 1)
        px3 = np.mean(epe_all < 3)
        px5 = np.mean(epe_all < 5)

        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
        # BUG FIX: was np.mean(epe_list) over a ragged list of per-image arrays
        # (fails on modern NumPy and weights images unequally); store the
        # pixel-mean EPE that was just printed.
        results[dstype] = epe

    return results
-
def validate_sintel(args, model, iters=6):
    """Perform validation using the Sintel (train) split.

    Evaluates both the 'clean' and 'final' render passes.

    Args:
        args: parsed CLI namespace; only ``args.dataset_path`` (Sintel root) is used.
        model: RAFTGMA network, called as ``model(img1, img2, iters=..., test_mode=True)``
            and returning ``(_, _, flow_pr)``.
        iters (int): number of flow refinement iterations.

    Returns:
        dict: mean end-point error keyed by dstype ('clean' / 'final').
    """
    results = {}
    for dstype in ['clean', 'final']:
        val_dataset = ms_datasets.MpiSintel(split='training', dstype=dstype, root=args.dataset_path)
        epe_list = []

        for val_id in tqdm.tqdm(range(len(val_dataset))):
            image1, image2, flow_gt, _ = val_dataset[val_id]

            # Add a leading batch dimension.
            image1 = image1[None]
            image2 = image2[None]

            # Pad to a size the network accepts, then unpad the prediction.
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            # ndarray -> MindSpore tensor
            image1 = ms.Tensor(image1).astype(ms.float32)
            image2 = ms.Tensor(image2).astype(ms.float32)

            _, _, flow_pr = model(image1, image2, iters=iters, test_mode=True)

            flow_pr = flow_pr.asnumpy()
            flow = padder.unpad(flow_pr[0])

            # Per-pixel end-point error, flattened for global statistics.
            epe = np.sqrt(np.sum((flow - flow_gt) ** 2, axis=0))
            epe_list.append(epe.reshape(-1))

        epe_all = np.concatenate(epe_list)

        epe = np.mean(epe_all)
        px1 = np.mean(epe_all < 1)
        px3 = np.mean(epe_all < 3)
        px5 = np.mean(epe_all < 5)

        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
        # BUG FIX: was np.mean(epe_list) over a ragged list of per-image arrays
        # (fails on modern NumPy and weights images unequally); store the
        # pixel-mean EPE that was just printed.
        results[dstype] = epe

    return results
-
-
-
def validate_sintel_occ(model, iters=6):
    """Perform validation on the Sintel (train) split with occlusion masks.

    Reports overall EPE plus EPE restricted to occluded / non-occluded pixels.

    Args:
        model: RAFTGMA network, called as ``model(img1, img2, iters=..., test_mode=True)``.
        iters (int): number of flow refinement iterations.

    Returns:
        dict: mean end-point error keyed by dstype ('albedo' / 'clean' / 'final').
    """
    # BUG FIX: the original body was unported PyTorch code (`model.eval()`,
    # `.cuda()`, `torch.sum`) that cannot run here since `import torch` is
    # commented out; rewritten to the MindSpore/NumPy pattern used by the
    # other validators in this file.
    model.set_train(mode=False)
    results = {}
    for dstype in ['albedo', 'clean', 'final']:
        val_dataset = ms_datasets.MpiSintel(split='training', dstype=dstype, occlusion=True)
        epe_list = []
        epe_occ_list = []
        epe_noc_list = []

        for val_id in range(len(val_dataset)):
            image1, image2, flow_gt, _, occ, _ = val_dataset[val_id]
            image1 = image1[None]
            image2 = image2[None]

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            image1 = ms.Tensor(image1).astype(ms.float32)
            image2 = ms.Tensor(image2).astype(ms.float32)

            # NOTE(review): the original unpacked two outputs here, but every
            # other validator in this file unpacks three — aligned to three.
            # Confirm against the RAFTGMA test_mode return signature.
            _, _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            flow = padder.unpad(flow_pr.asnumpy()[0])

            epe = np.sqrt(np.sum((flow - flow_gt) ** 2, axis=0))
            epe_list.append(epe.reshape(-1))

            # assumes `occ` is a per-pixel boolean occlusion mask — TODO confirm
            occ = np.asarray(occ).astype(bool)
            epe_noc_list.append(epe[~occ])
            epe_occ_list.append(epe[occ])

        epe_all = np.concatenate(epe_list)

        epe_noc = np.concatenate(epe_noc_list)
        epe_occ = np.concatenate(epe_occ_list)

        epe = np.mean(epe_all)
        px1 = np.mean(epe_all < 1)
        px3 = np.mean(epe_all < 3)
        px5 = np.mean(epe_all < 5)

        epe_occ_mean = np.mean(epe_occ)
        epe_noc_mean = np.mean(epe_noc)

        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
        print("Occ epe: %f, Noc epe: %f" % (epe_occ_mean, epe_noc_mean))
        # BUG FIX: was np.mean(epe_list) over a ragged list of per-image arrays;
        # store the pixel-mean EPE that was just printed.
        results[dstype] = epe

    return results
-
def validate_kitti(model, iters=6):
    """Perform validation using the KITTI-2015 (train) split.

    Args:
        model: RAFTGMA network, called as ``model(img1, img2, iters=..., test_mode=True)``
            and returning ``(_, _, flow_pr)``.
        iters (int): number of flow refinement iterations.

    Returns:
        dict: ``{'kitti_epe': mean EPE over valid pixels,
                 'kitti_f1': outlier percentage (KITTI F1 metric)}``.
    """
    model.set_train(mode=False)
    val_dataset = ms_datasets.KITTI(split='training')

    out_list = []
    epe_list = []
    for sample_idx in tqdm.tqdm(range(len(val_dataset))):
        img_a, img_b, flow_gt, valid_gt = val_dataset[sample_idx]
        # Add a leading batch dimension.
        img_a, img_b = img_a[None], img_b[None]

        # Pad to network-friendly dimensions; predictions are unpadded below.
        padder = InputPadder(img_a.shape, mode='kitti')
        img_a, img_b = padder.pad(img_a, img_b)

        tensor_a = ms.Tensor(img_a).astype(ms.float32)
        tensor_b = ms.Tensor(img_b).astype(ms.float32)

        _, _, flow_pr = model(tensor_a, tensor_b, iters=iters, test_mode=True)

        flow = padder.unpad(flow_pr.asnumpy()[0])

        # Per-pixel end-point error and ground-truth flow magnitude.
        epe_map = np.sqrt(((flow - flow_gt) ** 2).sum(axis=0))
        mag_map = np.sqrt((flow_gt ** 2).sum(axis=0))

        epe_flat = epe_map.reshape(-1)
        mag_flat = mag_map.reshape(-1)
        valid_mask = valid_gt.reshape(-1) >= 0.5

        # KITTI outlier definition: EPE > 3px AND EPE > 5% of flow magnitude.
        outliers = (epe_flat > 3.0) & ((epe_flat / mag_flat) > 0.05)
        epe_list.append(np.mean(epe_flat[valid_mask]))
        out_list.append(outliers[valid_mask])

    epe = np.mean(np.array(epe_list))
    f1 = 100 * np.mean(np.concatenate(out_list))

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti_epe': epe, 'kitti_f1': f1}
-
def separate_inout_sintel_occ():
    """Split Sintel occlusion maps into in-frame / out-of-frame components.

    Iterates the Sintel 'clean' training split and, for each sample, builds a
    mask of pixels whose ground-truth flow lands outside the frame, the union
    of that mask with the occlusion map, and the in-frame-only occlusions.
    The mask-writing code remains commented out, as in the original.
    """
    dstype = 'clean'
    val_dataset = ms_datasets.MpiSintel(split='training', dstype=dstype, occlusion=True)

    for val_id in range(len(val_dataset)):
        image1, image2, flow_gt, _, occ, occ_path = val_dataset[val_id]
        # BUG FIX: the original used torch's `.size()` and torch.meshgrid, but
        # `import torch` is commented out in this file -> NameError. Build the
        # (x, y) coordinate grid with NumPy instead; channel order matches the
        # original torch `stack(coords[::-1])`, i.e. coords[0]=x, coords[1]=y.
        _, h, w = image1.shape
        ygrid, xgrid = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        coords = np.stack([xgrid, ygrid], axis=0).astype(np.float32)

        # Where each pixel lands in frame 2 according to the ground-truth flow.
        coords_img_2 = coords + flow_gt
        out_of_frame = (coords_img_2[0] < 0) | (coords_img_2[0] > w) | (coords_img_2[1] < 0) | (coords_img_2[1] > h)
        # assumes `occ` is a per-pixel boolean occlusion mask — TODO confirm
        occ = np.asarray(occ).astype(bool)
        occ_union = out_of_frame | occ
        in_frame = occ_union ^ out_of_frame

        # Generate union of occlusions and out of frame
        # path_list = occ_path.split('/')
        # path_list[-3] = 'occ_plus_out'
        # dir_path = os.path.join('/', *path_list[:-1])
        # img_path = os.path.join('/', *path_list)
        # if not os.path.exists(dir_path):
        #     os.makedirs(dir_path)
        #
        # imageio.imwrite(img_path, occ_union.astype(np.uint8) * 255)

        # Generate out-of-frame
        # path_list = occ_path.split('/')
        # path_list[-3] = 'out_of_frame'
        # dir_path = os.path.join('/', *path_list[:-1])
        # img_path = os.path.join('/', *path_list)
        # if not os.path.exists(dir_path):
        #     os.makedirs(dir_path)
        #
        # imageio.imwrite(img_path, out_of_frame.astype(np.uint8) * 255)

        # Generate in-frame occlusions
        # path_list = occ_path.split('/')
        # path_list[-3] = 'in_frame_occ'
        # dir_path = os.path.join('/', *path_list[:-1])
        # img_path = os.path.join('/', *path_list)
        # if not os.path.exists(dir_path):
        #     os.makedirs(dir_path)
        #
        # imageio.imwrite(img_path, in_frame.astype(np.uint8) * 255)
-
-
-
-
-
-
if __name__ == '__main__':
    # Command-line interface for the evaluation script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', default='/home/zjc/sintel')
    parser.add_argument('--device_target', default='Ascend')
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--dataset', help="dataset for evaluation")
    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--num_heads', default=1, type=int,
                        help='number of heads in attention and aggregation')
    parser.add_argument('--position_only', default=False, action='store_true',
                        help='only use position-wise attention')
    parser.add_argument('--position_and_content', default=False, action='store_true',
                        help='use position and content-wise attention')
    parser.add_argument('--mixed_precision', default=True, help='use mixed precision')
    parser.add_argument('--model_name')

    # Ablations
    parser.add_argument('--replace', default=False, action='store_true',
                        help='Replace local motion feature with aggregated motion features')
    parser.add_argument('--no_alpha', default=False, action='store_true',
                        help='Remove learned alpha, set it to 1')
    parser.add_argument('--no_residual', default=False, action='store_true',
                        help='Remove residual connection. Do not add local features with the aggregated features.')

    args = parser.parse_args()

    # This mode needs no model; run it and leave before any device setup.
    if args.dataset == 'separate':
        separate_inout_sintel_occ()
        sys.exit()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    model = RAFTGMA(args.__dict__)

    # Restore weights from a MindSpore checkpoint only; other formats are
    # intentionally skipped (see the commented PyTorch-loading path below).
    # param_dict = load_checkpoint(args.model)
    # load_param_into_net(model, param_dict)
    if args.model[-4:] == 'ckpt':
        ckpt_params = load_checkpoint(args.model)
        load_param_into_net(model, ckpt_params)
    # else:
    #     model = load_pytorch_state_dict(model, args.model)
    model.set_train(mode=False)
    print(f"Loaded checkpoint at {args.model}")

    # create_sintel_submission(model, warm_start=True)
    # create_sintel_submission_vis(model, warm_start=True)
    # create_kitti_submission(model)
    # create_kitti_submission_vis(model)

    # Dispatch to the requested benchmark.
    if args.dataset == 'chairs':
        validate_chairs(model, iters=args.iters)
    elif args.dataset == 'things':
        validate_things(model, iters=args.iters)
    elif args.dataset == 'sintel':
        validate_sintel(args, model, iters=args.iters)
    elif args.dataset == 'sintel_occ':
        validate_sintel_occ(model, iters=args.iters)
    elif args.dataset == 'kitti':
        validate_kitti(model, iters=args.iters)
|