|
-
- """export net together with checkpoint into air/mindir models"""
- import os
- import os.path as osp
- import argparse
- import numpy as np
- from mindspore import Tensor, context, export
- import argparse
- import ast
- import time
- from src.model.rbpn import Net as RBPN
- from mindspore import load_checkpoint, load_param_into_net, context
- from src.util.utils import save_img, save_losses, save_psnr, compute_psnr , init_weights ,PSNR
- import numpy as np
- import mindspore
- from src.loss.generatorloss import GeneratorLoss
-
-
-
# Command-line options for exporting an RBPN checkpoint. The parser is built at
# import time, but arguments are only parsed inside main(), so importing this
# module has no side effects beyond defining it.
parser = argparse.ArgumentParser(description='rbpn export')
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_path", type=str, default='./weights/65_RBPN.ckpt',
                    help="path of checkpoint file")
parser.add_argument("--file_name", type=str, default="rbpn3", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'],
                    help="file format")
parser.add_argument('--scale', type=int, default=4, help='super resolution scale')
parser.add_argument("--device_id", type=int, default=2, help="device id, default: 2.")
# Only RBPN is built below; the option is kept so existing command lines still parse.
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--n_frames', type=int, default=7,
                    help='number of input frames (target frame + neighbors)')
parser.add_argument('--lr_height', type=int, default=120, help='height of the low-resolution input')
parser.add_argument('--lr_width', type=int, default=180, help='width of the low-resolution input')


def _input_shapes(batch_size, n_frames, height, width):
    """Return the (target, neighbor, flow) input shapes fed to RBPN at export.

    The network receives the target frame plus ``n_frames - 1`` neighbor frames
    (3 channels each) and one 2-channel flow field per neighbor.

    Raises:
        ValueError: if ``n_frames`` leaves no neighbor frames.
    """
    n_neighbors = n_frames - 1
    if n_neighbors < 1:
        raise ValueError("n_frames must be at least 2, got {}".format(n_frames))
    return ([batch_size, 3, height, width],
            [batch_size, n_neighbors, 3, height, width],
            [batch_size, n_neighbors, 2, height, width])


def _dummy_input(shape):
    """Return a float32 Tensor of ``shape`` filled with random values.

    NOTE(review): presumably only the shape/dtype matter for export; the
    distribution (mean -1.0, std 1.0) is kept as in the original script.
    """
    return Tensor(np.random.normal(-1.0, 1.0, size=shape).astype(np.float32))


def _output_path(file_name, out_dir='mindir_path'):
    """Return ``<cwd>/<out_dir>/<file_name>_model``, creating the directory if needed."""
    dir_path = osp.join(os.getcwd(), out_dir)
    # exist_ok avoids the check-then-create race of exists()/makedirs().
    os.makedirs(dir_path, exist_ok=True)
    return osp.join(dir_path, "{}_model".format(file_name))


def main(argv=None):
    """Load the RBPN checkpoint and export it in the requested file format.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parser.parse_args(argv)
    print(args)
    mindspore.set_seed(123)
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id)

    print("=======> load model ckpt")
    params = load_checkpoint(args.ckpt_path)
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3, n_resblock=5,
                 nFrames=args.n_frames, scale_factor=args.scale)
    load_param_into_net(model, params)
    model.set_train(False)

    # Inputs are drawn in the order target, neighbors, flow (same RNG order as before).
    shapes = _input_shapes(args.batch_size, args.n_frames, args.lr_height, args.lr_width)
    inputs = [_dummy_input(shape) for shape in shapes]

    export(model, *inputs, file_name=_output_path(args.file_name), file_format=args.file_format)
    print('export successfully!')


if __name__ == "__main__":
    main()
|