|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Evaluation script"""
-
- import argparse
- import os
- import warnings
- from time import time
-
- from mindspore import context
- from mindspore import dataset as de
- from mindspore import load_checkpoint
- from mindspore import load_param_into_net
-
- from src.core.eval_utils import get_official_eval_result
- from src.predict import predict
- from src.predict import predict_kitti_to_anno
- from src.utils import get_config
- from src.utils import get_model_dataset
- from src.utils import get_params_for_net
-
- warnings.filterwarnings('ignore')  # Blanket filter: ignore every Python warning (no category/module restriction) from here on.
-
-
def run_evaluate(args):
    """Evaluate a checkpoint on the evaluation dataset and print the result.

    Loads the config at ``args.cfg_path``, restores the checkpoint at
    ``args.ckpt_path`` into the network, runs the network over every batch
    of the evaluation dataset, converts the predictions to annotations and
    prints whatever ``get_official_eval_result`` returns.

    Args:
        args: Parsed command-line namespace providing the attributes
            ``cfg_path``, ``ckpt_path`` and ``device_target``.

    Returns:
        None. The evaluation result is printed to stdout.
    """
    cfg = get_config(args.cfg_path)

    # The device index comes from the DEVICE_ID environment variable
    # (0 when it is not set); the device type comes from the CLI.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=args.device_target,
        device_id=device_id,
    )

    model_cfg = cfg['model']
    eval_input_cfg = cfg['eval_input_reader']
    center_limit_range = model_cfg['post_center_limit_range']
    batch_size = eval_input_cfg['batch_size']
    class_names = list(eval_input_cfg['class_names'])

    # NOTE(review): the second argument presumably selects the evaluation
    # (non-training) variant of the model/dataset -- confirm in src.utils.
    pointpillarsnet, eval_dataset, box_coder = get_model_dataset(cfg, False)

    # NOTE(review): get_params_for_net presumably adapts checkpoint parameter
    # names to this network -- confirm in src.utils.
    params = load_checkpoint(args.ckpt_path)
    load_param_into_net(pointpillarsnet, get_params_for_net(params))

    ds = de.GeneratorDataset(
        eval_dataset,
        column_names=eval_dataset.data_keys,
        python_multiprocessing=True,
        num_parallel_workers=6,
        max_rowsize=100,
        shuffle=False,
    )
    # Keep the final partial batch so that every sample is evaluated.
    ds = ds.batch(batch_size, drop_remainder=False)
    data_loader = ds.create_dict_iterator(num_epochs=1)

    gt_annos = [info["annos"] for info in eval_dataset.kitti_infos]
    dt_annos = []

    log_freq = 100
    len_dataset = len(eval_dataset)
    start = time()

    for i, data in enumerate(data_loader):
        # 'bev_map' is optional in a batch; False is passed when it is absent.
        outputs = pointpillarsnet(
            data["voxels"],
            data["num_points"],
            data["coordinates"],
            data.get('bev_map', False),
        )

        # The network yields (box, cls) or (box, cls, dir_cls, ...).
        preds = {'box_preds': outputs[0], 'cls_preds': outputs[1]}
        if len(outputs) != 2:
            preds['dir_cls_preds'] = outputs[2]
        preds = predict(data, preds, model_cfg, box_coder)

        dt_annos += predict_kitti_to_anno(
            preds,
            data,
            class_names,
            center_limit_range,
        )

        if i % log_freq == 0 and i > 0:
            time_used = time() - start
            # Batches 0..i are finished at this point; cap at the dataset
            # size because the last batch may be smaller than batch_size.
            processed = min((i + 1) * batch_size, len_dataset)
            print(f'processed: {processed}/{len_dataset} imgs, '
                  f'time elapsed: {time_used} s')

    result = get_official_eval_result(
        gt_annos,
        dt_annos,
        class_names,
    )
    print(result)
-
-
if __name__ == '__main__':
    # Command-line entry point: declare the three supported options from a
    # table (order preserved), parse them and hand the namespace to
    # run_evaluate.
    parser = argparse.ArgumentParser()
    for option, settings in (
            ('--cfg_path', {'required': True, 'help': 'Path to config file.'}),
            ('--ckpt_path', {'required': True, 'help': 'Path to checkpoint.'}),
            ('--device_target', {'default': 'GPU', 'help': 'device target'}),
    ):
        parser.add_argument(option, **settings)

    parse_args = parser.parse_args()
    run_evaluate(parse_args)
|