|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Export mindir."""
- import argparse
- import numpy as np
- from mindspore import Tensor, export, load_checkpoint
-
- from src.regnet import regnet20
-
- parser = argparse.ArgumentParser(description='RegNet checkpoint export')
- parser.add_argument('--batch_size', type=int, default=64)
- parser.add_argument('--im_size', type=int, default=32)
- parser.add_argument('--class_num', type=int, default=10)
- parser.add_argument('--checkpoint_path', type=str, default='RegNet.ckpt')
- parser.add_argument('--export_file_name', type=str, default='pretrained_model')
- args = parser.parse_args()
- batch_size = args.batch_size
- im_size = args.im_size
- class_num = args.class_num
- checkpoint_path = args.checkpoint_path
- export_file_name = args.export_file_name
-
- regnet = regnet20(batch_size=batch_size, im_size=im_size, class_num=class_num)
-
- load_checkpoint(checkpoint_path, net=regnet)
- input_ckpt = np.random.uniform(0.0, 1.0, size=[batch_size, 3, im_size, im_size]).astype(np.float32)
- export(regnet, Tensor(input_ckpt), file_name=export_file_name, file_format='MINDIR')
|