|
- # Copyright 2022 Huawei Technologies Co., Ltd
- # Copyright 2022 Aerospace Information Research Institute,
- # Chinese Academy of Sciences.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """monitor of swintransformerv2"""
- import time
- import numpy as np
- import os
- from mindspore.common.tensor import Tensor
- from mindspore.train.callback._callback import Callback
-
-
- class StateMonitor(Callback):
- """StateMonitor"""
-
- def __init__(self, data_size, tot_batch_size=None,
- eval_interval=None, eval_offset=None,
- eval_engine=None, logger=None):
- super(StateMonitor, self).__init__()
- self.data_size = data_size
- self.tot_batch_size = tot_batch_size
- self.epoch_num = 0
- self.loss = 0
- self.eval_interval = eval_interval
- self.eval_offset = eval_offset
- self.eval_engine = eval_engine
- self.best_acc = -1
- self.best_acc_top5 = -1
- self.best_i2t_recall = -1
- self.best_t2i_recall = -1
- self.mean_fps = 0.0
- self.print = print
- self.epoch_time = 0
- if logger is not None:
- self.print = logger
-
- def step_end(self, run_context):
- """step end"""
- cb_params = run_context.original_args()
- loss = cb_params.net_outputs
-
- if isinstance(loss, (tuple, list)):
- if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
- loss = loss[0]
-
- if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
- loss = np.mean(loss.asnumpy())
-
- self.loss = loss
-
- def epoch_begin(self, run_context):
- # pylint: disable=W0613
- self.epoch_time = time.time()
-
- def epoch_end(self, run_context):
- # pylint: disable=W0613
- """epoch end of StateMonitor"""
- epoch_seconds = (time.time() - self.epoch_time)
- per_step_seconds = epoch_seconds / self.data_size
-
- print_str = "epoch[{}]".format(self.epoch_num)
- print_str += ', epoch time: {:.2f}s'.format(epoch_seconds)
- print_str += ', per step time: {:.4f}s'.format(per_step_seconds)
- print_str += ', loss={:.6f}'.format(self.loss)
-
- if self.tot_batch_size is not None:
- fps = self.tot_batch_size * self.data_size / epoch_seconds
- self.mean_fps = (self.mean_fps * self.epoch_num + fps) / (self.epoch_num + 1)
- print_str += ', fps={:.2f}'.format(fps)
-
- if (self.epoch_num + 1) % self.eval_interval == self.eval_offset:
- eval_start = time.time()
- self.eval_engine.eval()
- output = self.eval_engine.get_result()
- eval_seconds = time.time() - eval_start
- if output is not None:
- if isinstance(output, list):
- print_str += ', top1 accuracy={:.6f}'.format(float(output[0]))
- print_str += ', top5 accuracy={:.6f}'.format(float(output[1]))
- print_str += ', i2t_recall={:.6f}'.format(float(output[2]))
- print_str += ', t2i_recall={:.6f}'.format(float(output[3]))
- print_str += ', eval_cost={:.2f}'.format(eval_seconds)
-
- if float(output[0]) > self.best_acc:
- self.best_acc = float(output[0])
- if float(output[1]) > self.best_acc_top5:
- self.best_acc_top5 = float(output[1])
- if float(output[2]) > self.best_i2t_recall:
- self.best_i2t_recall = float(output[2])
- if float(output[3]) > self.best_t2i_recall:
- self.best_t2i_recall = float(output[3])
- else:
- print_str += ', accuracy={:.6f}'.format(float(output))
- print_str += ', eval_cost={:.2f}'.format(eval_seconds)
-
- if float(output) > self.best_acc:
- self.best_acc = float(output)
- print_str += ', best_accuracy={:.6f}'.format(float(self.best_acc))
-
- if int(os.getenv('RANK_ID')) == 0:
- self.print(print_str)
- self.epoch_num += 1
-
-
- def build_finetune_callback_2(args, eval_engine, time_cb, ckpoint_cb, LM, uploadOutput):
- """build finetune callback"""
- ckpt_config = args.callback.ckpt_config
- train_config = args.train_config
- dataset_config = args.finetune_dataset
- state_cb = StateMonitor(data_size=train_config.per_epoch_size,
- tot_batch_size=train_config.batch_size * args.device_num,
- eval_interval=dataset_config.eval_interval,
- eval_offset=dataset_config.eval_offset,
- eval_engine=eval_engine,
- logger=args.logger.info)
-
- ckpt_append_info = [{"epoch_num": train_config.has_trained_epoches,
- "step_num": train_config.has_trained_steps}]
-
-
-
- callback = [ckpoint_cb, state_cb, LM, uploadOutput]
-
-
- return callback
|