|
- # encoding:utf-8
- """Train retinanet and get checkpoint files."""
-
- import os
- import argparse
- import ast
- import moxing as mox
- import mindspore.nn as nn
- from mindspore import context, Tensor
- from mindspore.communication.management import init, get_rank
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
- from mindspore.train import Model
- from mindspore.context import ParallelMode
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.common import set_seed
- from src.config import config
- from src.dataset import create_EfficientDet_datasets
- from src.monitor import Monitor
- from src.lr_schedule import get_lr_cosine
- from src.mind_backbone import EfficientDetBackbone
- from src.efficientdet.loss import FocalLoss
- import os
- import numpy as np
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore.ops import operations as P
- from mindspore.ops import functional as F
- from mindspore.ops import composite as C
- from mindspore.common.tensor import Tensor
- from mindspore.common.parameter import Parameter
- from mindspore.common import dtype as mstype
- from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
- from mindspore.context import ParallelMode
- import mindspore.common.initializer as weight_init
- from mindspore.nn import TrainOneStepCell
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
-
- from mindspore.communication.management import get_group_size
- from mindspore import context, FixedLossScaleManager
-
- import math
-
- set_seed(1)
-
# HyperMap-registered op used by TrainingWrapper: multiplies a gradient by the
# reciprocal of the loss-scale factor, undoing loss scaling before clipping.
grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Return *grad* divided element-wise by *scale*."""
    inv_scale = P.Reciprocal()(scale)
    return grad * inv_scale
-
class TrainingWrapper(nn.Cell):
    """
    Wrap a loss-producing network and an optimizer into a one-step training cell.

    Calling ``construct`` runs a forward pass, builds the backward graph with a
    fixed sensitivity (static loss-scale) value, optionally all-reduces the
    gradients in data/hybrid parallel mode, optionally unscales and clips them
    by global norm, and finally applies the optimizer update.

    Args:
        network (Cell): The training network. Note that the loss function
            should already have been added (the network returns the loss).
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): Backward sensitivity, i.e. the static loss-scale
            factor. Default: 1.0.
        use_global_norm (bool): Whether to unscale gradients by ``sens`` and
            apply global-norm clipping before the optimizer step.
            Default: False.
    """
    def __init__(self, network, optimizer, sens=1.0, use_global_norm=False):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        # Mark the wrapped network so MindSpore builds its backward graph.
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        # Grad op that accepts an explicit sensitivity tensor (sens_param=True).
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.use_global_norm = use_global_norm
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        # Gradients only need a cross-device all-reduce in these modes.
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            # Prefer the explicitly configured device count; otherwise fall
            # back to the communication group size.
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        self.hyper_map = C.HyperMap()

    def construct(self, *args):
        """Run one training step on *args* and return the loss."""
        weights = self.weights
        loss = self.network(*args)
        # Broadcast the scalar loss-scale to the loss's dtype and shape.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # All-reduce (and, if configured, mean) gradients across devices.
            grads = self.grad_reducer(grads)
        if self.use_global_norm:
            # Undo loss scaling (multiply by 1/sens), then clip by global norm.
            grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads)
            grads = C.clip_by_global_norm(grads)
        return F.depend(loss, self.optimizer(grads))
-
-
- def _calculate_fan_in_and_fan_out(tensor):
- """
- _calculate_fan_in_and_fan_out
- """
- dimensions = len(tensor)
- if dimensions < 2:
- raise ValueError("Fan in and fan out can not be computed for tensor"
- " with fewer than 2 dimensions")
- if dimensions == 2: # Linear
- fan_in = tensor[1]
- fan_out = tensor[0]
- else:
- num_input_fmaps = tensor[1]
- num_output_fmaps = tensor[0]
- receptive_field_size = 1
- if dimensions > 2:
- receptive_field_size = tensor[2] * tensor[3]
- fan_in = num_input_fmaps * receptive_field_size
- fan_out = num_output_fmaps * receptive_field_size
- return fan_in, fan_out
-
-
-
def init_weights(model):
    """Initialize all Conv2d weights and biases of *model* in place.

    - Convs whose name contains ``"conv_list"`` or ``"header"`` get a normal
      init with std ``sqrt(1 / fan_in)``; all other convs get He-uniform.
    - Biases of ``"header_cls"`` layers are set to ``-log((1 - 0.01) / 0.01)``
      (the focal-loss prior, so initial foreground probability is ~0.01);
      all other biases are zeroed.

    Args:
        model (nn.Cell): network whose cells are visited via
            ``cells_and_names()``.
    """
    for name, cell in model.cells_and_names():
        if not isinstance(cell, nn.Conv2d):
            continue

        if "conv_list" in name or "header" in name:
            fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
            # sigma is a standard deviation for a normal draw, not a uniform bound.
            sigma = math.sqrt(1. / float(fan_in))
            data = ms.Tensor(np.random.normal(loc=0, scale=sigma, size=cell.weight.shape).astype(np.float32))
            cell.weight.set_data(weight_init.initializer(data, cell.weight.shape))
        else:
            cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))

        if cell.has_bias:
            if "header_cls" in name:
                # Focal-loss bias prior (RetinaNet-style): start classifier
                # outputs near a 1% positive probability to stabilize training.
                bias_value = -np.log((1 - 0.01) / 0.01)
                cell.bias.set_data(weight_init.initializer(bias_value, cell.bias.shape))
            else:
                cell.bias.set_data(weight_init.initializer('zeros', cell.bias.shape))
-
-
class WithLossCell(nn.Cell):
    """Fuse the backbone and the loss into one cell that returns a scalar loss.

    Args:
        backbone (nn.Cell): network returning ``(_, regression,
            classification, anchors)``.
        loss (nn.Cell): loss taking ``(regression, classification, anchors,
            target)`` and returning ``(cls_loss, reg_loss)``.
    """

    def __init__(self, backbone, loss):
        super(WithLossCell, self).__init__()
        self.backbone = backbone
        self.loss = loss

    def construct(self, x, y):
        """Forward *x* through the backbone and return cls + reg loss vs *y*."""
        _, regression, classification, anchors = self.backbone(x)
        cls_loss, reg_loss = self.loss(regression, classification, anchors, y)
        total_loss = cls_loss + reg_loss
        return total_loss
-
-
class TransferCallback(Callback):
    """Periodically mirror the local checkpoint directory to OBS via moxing.

    Args:
        local_train_path (str): local directory that ModelCheckpoint writes to.
        obs_train_path (str): destination OBS path (the job's train_url).
    """

    def __init__(self, local_train_path, obs_train_path):
        super(TransferCallback, self).__init__()
        self.local_train_path = local_train_path
        self.obs_train_path = obs_train_path

    def epoch_end(self, run_context):
        """Upload checkpoints to OBS once every 10 epochs.

        Bug fix: this hook was previously ``step_end``, which re-uploaded the
        whole directory on *every training step* of each 10th epoch.  The
        epoch-based condition shows a once-per-epoch upload was intended.
        """
        cb_params = run_context.original_args()
        current_epoch = cb_params.cur_epoch_num
        # cur_epoch_num starts at 1, so the != 0 guard is belt-and-braces.
        if current_epoch % 10 == 0 and current_epoch != 0:
            mox.file.copy_parallel(self.local_train_path, self.obs_train_path)
-
-
def main():
    """Parse CLI args, stage data from OBS, and train EfficientDet on Ascend.

    Side effects: creates /cache/data/<device_id> and /cache/ckpt, copies the
    mindrecord dataset locally, and (on device 0) uploads checkpoints to OBS.
    """
    parser = argparse.ArgumentParser(description="EfficientDet training")
    parser.add_argument("--distribute", type=ast.literal_eval, default=True,
                        help="Run distributed training (auto-set from RANK_SIZE).")
    parser.add_argument("--workers", type=int, default=8, help="Num parallel workers.")
    parser.add_argument("--data_url", type=str, default=None, help="mindrecord dir")
    parser.add_argument("--train_url", type=str, default=None, help="ckpt output dir in obs")
    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
    parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size, default is 16.")
    parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter weight parameters, default is False.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices="Ascend",
                        help="run platform, only support Ascend.")

    args_opt = parser.parse_args()
    # Bug fix: int(os.getenv('DEVICE_ID'), 0) passed 0 as the *base* argument
    # and raised TypeError whenever DEVICE_ID was unset; use a default value
    # instead, matching the RANK_SIZE handling below.
    device_id = int(os.getenv("DEVICE_ID", "0"))
    device_num = int(os.getenv("RANK_SIZE", 1))

    # The CLI flag is overridden by the actual job size.
    args_opt.distribute = device_num > 1

    if args_opt.run_platform == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
        if args_opt.distribute:
            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
            init()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
    else:
        raise ValueError("Unsupported platform.")

    # Stage the dataset from OBS into device-local cache directories.
    local_data_url = "/cache/data/" + str(device_id)
    mox.file.make_dirs(local_data_url)

    local_train_url = "/cache/ckpt"
    mox.file.make_dirs(local_train_url)

    filename = "EfficientDet.mindrecord0"

    # Copy the pre-generated EfficientDet.mindrecord shards locally.
    mox.file.copy_parallel(args_opt.data_url, local_data_url)
    local_data_path = os.path.join(local_data_url, filename)

    dataset = create_EfficientDet_datasets(local_data_path, repeat_num=1,
                                           num_parallel_workers=args_opt.workers,
                                           batch_size=args_opt.batch_size, device_num=device_num, rank=device_id)
    dataset_size = dataset.get_dataset_size()

    print("Create dataset done!")

    net = EfficientDetBackbone(90, 0, False, True)
    init_weights(net)

    loss = FocalLoss()

    # Honor the --loss_scale flag instead of a hard-coded 1024 (same default).
    loss_scale = float(args_opt.loss_scale)

    # NOTE(review): the --lr flag is ignored here; the cosine schedule is
    # pinned to init_lr=0.012 with a warmup of epoch_size/20 epochs — confirm
    # whether --lr should feed init_lr before changing it.
    lr = Tensor(get_lr_cosine(init_lr=0.012, steps_per_epoch=dataset_size, warmup_epochs=int(args_opt.epoch_size / 20),
                              max_epoch=args_opt.epoch_size, t_max=args_opt.epoch_size, eta_min=0.0))

    opt = nn.Momentum(net.trainable_params(), lr,
                      config.momentum, config.weight_decay, loss_scale=loss_scale)

    net_withloss = WithLossCell(net, loss)
    network = TrainOneStepCell(net_withloss, opt, sens=loss_scale)
    network.set_train()

    model = Model(network, amp_level="O0")

    transferCb = TransferCallback(local_train_url, args_opt.train_url)
    cb = [LossMonitor(), TimeMonitor()]

    config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="EfficientDet", directory=local_train_url, config=config_ck)
    print("============== Starting Training ==============")

    # Only device 0 writes checkpoints and uploads them to OBS.
    if device_id == 0:
        cb += [ckpt_cb, transferCb]

    model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)

    print("============== End Training ==============")
-
-
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
|