|
- import os
-
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import FixedLossScaleManager, Model
- from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
-
- from optimizer import get_optimizer
- from losses import get_loss_func
- from postprocess import create_postprocesser
- from metrics import get_metrics
- from scheduler import get_scheduler
- from utils import EvalCallback
-
-
def apply_eval(eval_param):
    """Evaluate a model and return a single metric value from the result.

    Args:
        eval_param: Mapping with the keys ``"model"`` (an object exposing
            ``eval(dataset)``), ``"dataset"`` (passed unchanged to
            ``model.eval``) and ``"metrics_name"`` (the key selected from
            the evaluation result).

    Returns:
        ``result[eval_param["metrics_name"]]``, where ``result`` is whatever
        ``model.eval(dataset)`` returned.
    """
    # All three keys are looked up (in this order) before evaluation starts,
    # so a missing key fails fast without running the model.
    model, dataset, metric_key = (eval_param[key] for key in ("model", "dataset", "metrics_name"))
    scores = model.eval(dataset)
    return scores[metric_key]
-
-
class Trainer:
    """Build MindSpore ``Model`` objects from a config and run training/evaluation.

    Args:
        network: The network to train/evaluate (passed to ``Model`` and to
            ``get_optimizer``).
        config: Nested configuration mapping. Sections read by this class:
            ``Global``, ``Optimizer``, ``Loss``, ``PostProcess``, ``Metric``
            and ``Scheduler``.
        rank_id: Rank of this process; forwarded to ``EvalCallback``.
        train_dataset: Dataset consumed by :meth:`train`.
        eval_dataset: Dataset consumed by :meth:`evaluate`, and by the
            evaluation callback during :meth:`train` when it is non-empty.
    """

    def __init__(self, network, config, rank_id, train_dataset=None, eval_dataset=None):
        super().__init__()
        self.network = network
        self.config = config
        self.rank_id = rank_id
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

    def _build_loss_and_metrics(self):
        """Create the loss function, post-processor and metric from config.

        Sets ``self.loss_func``, ``self.postprocesser`` and ``self.metrics``.
        The metric receives the post-processor as its ``decoder``.
        """
        self.loss_func = get_loss_func(self.config["Loss"])
        self.postprocesser = create_postprocesser(self.config["PostProcess"])
        self.metrics = get_metrics(self.config["Metric"], decoder=self.postprocesser)

    def _build_train_model(self):
        """Create ``self.model``, the training ``Model``.

        (Re)creates ``self.loss_func``, ``self.postprocesser``, ``self.metrics``,
        ``self.lr_scheduler`` and ``self.optimizer`` from config. The loss-scale
        manager is dynamic when ``Optimizer.dynamic_loss_scale`` is truthy and
        fixed otherwise. Both start from ``Optimizer.operator.loss_scale``.
        """
        # Read every config section up front so a missing section fails
        # before any component is built.
        global_config = self.config["Global"]
        optim_config = self.config["Optimizer"]["operator"]
        scheduler_config = self.config["Scheduler"]
        dynamic_loss_scale = self.config["Optimizer"]["dynamic_loss_scale"]

        self._build_loss_and_metrics()

        # The LR schedule is sized by the number of batches per epoch.
        steps_per_epoch = self.train_dataset.get_dataset_size()
        self.lr_scheduler = get_scheduler(scheduler_config, steps_per_epoch)
        self.optimizer = get_optimizer(self.network, optim_config, lr=self.lr_scheduler)

        if dynamic_loss_scale:
            loss_scale_manager = ms.amp.DynamicLossScaleManager(init_loss_scale=optim_config["loss_scale"],
                                                                scale_factor=2,
                                                                scale_window=1000)
        else:
            loss_scale_manager = FixedLossScaleManager(
                loss_scale=optim_config["loss_scale"], drop_overflow_update=False)

        self.model = Model(self.network,
                           loss_fn=self.loss_func,
                           optimizer=self.optimizer,
                           metrics={'CRNNAccuracy': self.metrics},
                           amp_level=global_config["amp_level"],
                           loss_scale_manager=loss_scale_manager)

    def _build_eval_model(self):
        """Create ``self.eval_model``, the evaluation ``Model``.

        Reuses the loss function and metric already built by
        :meth:`_build_train_model` when present; otherwise builds them from
        config first.
        """
        if getattr(self, "loss_func", None) is None:
            self._build_loss_and_metrics()

        # set_train(False) switches self.network itself (the same object that
        # self.model wraps for training) into evaluation mode.
        # NOTE(review): this relies on MindSpore re-enabling training mode when
        # Model.train runs; confirm for the MindSpore version in use.
        self.eval_model = Model(self.network.set_train(False), self.loss_func,
                                metrics={'CRNNAccuracy': self.metrics})

    def _get_eval_callback(self, config):
        """Return an ``EvalCallback`` that runs :func:`apply_eval` during training.

        Args:
            config: The ``Global`` config section. It supplies
                ``eval_interval`` and ``eval_start_epoch``.

        The callback evaluates ``self.eval_model`` on ``self.eval_dataset``,
        reads the ``"CRNNAccuracy"`` entry of the result, and saves the best
        checkpoint as ``best_acc.ckpt`` under ``<save_model_dir>/ckpt``.
        """
        save_ckpt_path = self._set_output_dir()
        eval_param_dict = {"model": self.eval_model,
                           "dataset": self.eval_dataset,
                           "metrics_name": "CRNNAccuracy"}
        return EvalCallback(apply_eval, eval_param_dict, self.rank_id, interval=config["eval_interval"],
                            eval_start_epoch=config["eval_start_epoch"], save_best_ckpt=True,
                            ckpt_directory=save_ckpt_path, best_ckpt_name="best_acc.ckpt",
                            metrics_name="acc")

    def _set_output_dir(self):
        """Return the checkpoint directory path, ``<Global.save_model_dir>/ckpt``.

        Only computes the path; it does not create the directory.
        """
        return os.path.join(self.config["Global"]["save_model_dir"], 'ckpt')

    def train(self):
        """Build the training model and train for ``Global.epoch_num`` epochs.

        If a non-empty eval dataset was supplied, the only callback is an
        evaluation callback (see :meth:`_get_eval_callback`). Otherwise loss
        and timing are reported via ``LossMonitor`` and ``TimeMonitor``.
        Dataset sink mode is disabled.

        Raises:
            ValueError: If no training dataset was supplied.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer.train() requires a train_dataset")
        self._build_train_model()

        global_config = self.config["Global"]
        num_epoch = global_config["epoch_num"]
        step_size = self.train_dataset.get_dataset_size()

        # NOTE(review): Global.eval_interval is also passed to EvalCallback as
        # its ``interval``. Here it is used as LossMonitor's print period in
        # steps, capped at one epoch's worth of steps. Confirm the intended unit.
        eval_interval = min(global_config["eval_interval"], step_size)

        if self.eval_dataset is not None and self.eval_dataset.get_dataset_size() > 0:
            self._build_eval_model()
            callbacks = [self._get_eval_callback(global_config)]
        else:
            callbacks = [LossMonitor(per_print_times=eval_interval), TimeMonitor(data_size=step_size)]

        self.model.train(num_epoch, self.train_dataset,
                         callbacks=callbacks, dataset_sink_mode=False)

    def evaluate(self):
        """Evaluate the network on ``self.eval_dataset`` (sink mode disabled).

        Returns:
            The result of ``Model.eval``. Its ``"CRNNAccuracy"`` entry is the
            metric configured by this class.

        Raises:
            ValueError: If no evaluation dataset was supplied.
        """
        if self.eval_dataset is None:
            raise ValueError("Trainer.evaluate() requires an eval_dataset")
        self._build_eval_model()
        return self.eval_model.eval(self.eval_dataset, dataset_sink_mode=False)
|