|
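# The Paddle model must be saved as a static graph (a .pdmodel file holding the
# model structure plus a .pdiparams file holding the parameters).
# Save it with paddle.jit.save() before running this script.
# Reference: https://www.johngo689.com/38791/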
-
- import paddle
- import argparse, os, sys, stat
- from paddle.static import InputSpec
-
-
# Command-line interface: which saved model to convert, and the static input
# shape [n, c, h, w] used when summarising and exporting it.
parser = argparse.ArgumentParser(description='Paddle Lenet Example')

# The model name is mandatory: without it there is nothing to load, and the
# script would otherwise crash later on a None value.
parser.add_argument('--model',
                    type=str,
                    required=True,
                    help='model file name under /tmp/dataset, e.g. lenet.pdmodel'
                    )
parser.add_argument('--n',
                    type=int,
                    default=64,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )
-
DATASET_DIR = '/tmp/dataset'
OUTPUT_DIR = '/tmp/output'


def resolve_model_paths(model_file, dataset_dir=DATASET_DIR, output_dir=OUTPUT_DIR):
    """Return the (load prefix, export prefix) pair for a model file name.

    ``paddle.jit.load`` and ``paddle.onnx.export`` take a path prefix without
    an extension, so only the final extension of *model_file* (for example
    ``.pdmodel``) is removed. Stems that themselves contain dots
    (``resnet50_v1.5.pdmodel``) and names such as ``./lenet.pdmodel`` are
    kept intact.

    Args:
        model_file: model file name, relative to *dataset_dir*.
        dataset_dir: directory the model is loaded from.
        output_dir: directory the ONNX model is written to.

    Returns:
        ``(model_path, output_path)`` as normalised path prefixes.

    Raises:
        ValueError: if *model_file* is empty.
    """
    if not model_file:
        raise ValueError('model file name must not be empty')
    stem = os.path.splitext(model_file)[0]
    model_path = os.path.normpath(os.path.join(dataset_dir, stem))
    output_path = os.path.normpath(os.path.join(output_dir, stem))
    return model_path, output_path


if __name__ == "__main__":
    args = parser.parse_args()
    model_path, output_path = resolve_model_paths(args.model)

    # List the available files before loading, so a mistyped --model is easy
    # to diagnose even if paddle.jit.load fails below.
    for entry in sorted(os.listdir(DATASET_DIR)):
        print(entry)
    print('find all your model files above.')
    print('---------------------------------------')

    # One spec with shape [n, c, h, w] is shared by the summary and the export.
    input_spec = InputSpec(shape=[args.n, args.c, args.h, args.w], dtype='float32')
    print('input data shape:', input_spec)
    print('---------------------------------------')

    print('loading model:', args.model)
    model = paddle.jit.load(model_path)
    paddle.summary(model, input_spec)

    # Make sure the destination directory exists before exporting into it.
    export_dir = os.path.dirname(output_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    paddle.onnx.export(model, output_path, input_spec=[input_spec])
    print('exported ONNX model:', output_path + '.onnx')
|