|
- # Copyright 2020-21 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """train MaskRcnn and get checkpoint files."""
-
- import os
- import time
- import moxing as mox
-
- import mindspore.common.dtype as mstype
- from mindspore import context, Tensor
- from mindspore.communication.management import init, get_rank
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
- from mindspore.train import Model
- from mindspore.context import ParallelMode
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.nn import Momentum
- from mindspore.common import set_seed
-
- from src.model_utils.config import config
- from src.model_utils.moxing_adapter import moxing_wrapper
- from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
- from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1
- from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
- from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
- from src.lr_schedule import dynamic_lr
-
-
- ### Copy single dataset from obs to training image###
def ObsToEnv(obs_data_url, data_dir):
    """Copy a dataset folder from OBS to a local directory, then write a marker file.

    Args:
        obs_data_url: OBS URL passed as the source to ``mox.file.copy_parallel``.
        data_dir: Local destination directory.

    The copy is best-effort: a failure is printed, not raised. The marker file
    ``/cache/download_input.txt`` is created afterwards in either case, so any
    process polling for it is released even when the copy failed.
    """
    marker_path = "/cache/download_input.txt"
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    # Create an empty marker file signalling that the copy from OBS was attempted.
    # During multi-card training the other cards wait for this file instead of
    # copying the dataset again themselves.
    with open(marker_path, 'w'):
        pass
    # os.path.exists reports a missing file by returning False rather than raising,
    # so the failure message has to live in an else branch to be reachable.
    if os.path.exists(marker_path):
        print("download_input succeed")
    else:
        print("download_input failed")
- ### Copy the output to obs###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local output folder to OBS and report the outcome on stdout.

    Args:
        train_dir: Local directory passed as the source to ``mox.file.copy_parallel``.
        obs_train_url: Destination OBS URL.

    Any exception raised during the copy is caught and printed rather than
    propagated to the caller.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as err:
        failure_prefix = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url)
        print(failure_prefix + str(err))
def DownloadFromQizhi(obs_data_url, data_dir):
    """Set up the MindSpore context and fetch the dataset from OBS.

    Args:
        obs_data_url: OBS URL of the dataset folder.
        data_dir: Local directory the dataset is copied into.

    The number of devices is read from the ``RANK_SIZE`` environment variable.
    With one device the dataset is copied directly. With several devices the
    data-parallel context is configured, communication is initialised, and only
    ranks whose ``RANK_ID`` is a multiple of 8 copy the data while every rank
    waits for the marker file ``/cache/download_input.txt`` to appear.
    """
    marker_path = "/cache/download_input.txt"
    rank_total = int(os.getenv('RANK_SIZE'))

    if rank_total == 1:
        ObsToEnv(obs_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    elif rank_total > 1:
        # Bind this process to its device and initialise multi-card communication.
        context.set_context(
            mode=context.GRAPH_MODE,
            device_target=config.device_target,
            device_id=int(os.getenv('ASCEND_DEVICE_ID')),
        )
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            device_num=rank_total,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True,
        )
        init()

        # The dataset only has to be copied once per group of 8 ranks.
        rank_index = int(os.getenv('RANK_ID'))
        if rank_index % 8 == 0:
            ObsToEnv(obs_data_url, data_dir)

        # Block until the marker file shows that the copying rank has finished.
        while True:
            if os.path.exists(marker_path):
                break
            time.sleep(1)
def UploadToQizhi(train_dir, obs_train_url):
    """Upload the training output folder to OBS from a single designated card.

    Args:
        train_dir: Local directory holding the training outputs.
        obs_train_url: Destination OBS URL.

    The number of devices is read from the ``RANK_SIZE`` environment variable.
    A single-device run always uploads. A multi-device run uploads only from
    ranks whose ``RANK_ID`` is a multiple of 8.
    """
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1:
        # RANK_ID is only needed for multi-card runs; reading it here keeps a
        # single-card run from failing when the variable is not set.
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
- ### Copy ckpt file from obs to inference image###
- ### mox.file.copy_parallel operates on folders. To copy a single file,
- ### use mox.file.copy instead, as the function below does.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy a single checkpoint file from OBS to a local path.

    Args:
        obs_ckpt_url: OBS URL of the checkpoint file.
        ckpt_url: Local destination file path.

    Any exception raised during the copy is caught and printed rather than
    propagated to the caller.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as err:
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url)
        print(failure_prefix + str(err))
-
def create_mindrecord_dir(prefix, mindrecord_dir, mindrecord_file):
    """Create the MindRecord directory, convert the dataset, and wait for its index file.

    Args:
        prefix: File-name prefix handed to ``data_to_mindrecord_byte_image``.
        mindrecord_dir: Directory that should exist before conversion starts.
        mindrecord_file: Path of the first MindRecord file; the function returns
            once ``mindrecord_file + ".db"`` exists.

    Raises:
        Exception: If ``config.coco_root`` is not a directory (COCO dataset), or
            if ``config.IMAGE_DIR`` / ``config.ANNO_PATH`` is missing (any other
            dataset).
    """
    # exist_ok avoids the check-then-create race when several processes reach
    # this point at the same time.
    os.makedirs(mindrecord_dir, exist_ok=True)

    if config.dataset == "coco":
        if not os.path.isdir(config.coco_root):
            raise Exception("coco_root not exits.")
        print("Create Mindrecord.")
        data_to_mindrecord_byte_image("coco", True, prefix, file_num=4)
        print("Create Mindrecord Done, at {}".format(mindrecord_dir))
    else:
        if not (os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH)):
            raise Exception("IMAGE_DIR or ANNO_PATH not exits.")
        print("Create Mindrecord.")
        data_to_mindrecord_byte_image("other", True, prefix)
        print("Create Mindrecord Done, at {}".format(mindrecord_dir))

    # Poll every 5 seconds until the ".db" companion of the first file appears.
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)
-
def load_pretrained_ckpt(net, load_path, device_target='Ascend'):
    """Load a checkpoint into ``net``, renaming 'vovnet' parameters to 'backbone'.

    Args:
        net: Network that receives the parameters.
        load_path: Path of the checkpoint file to read.
        device_target: Accepted for interface compatibility; not used here.

    Returns:
        The same ``net`` object. Nothing is loaded unless
        ``config.pretrain_epoch_size`` equals 0.
    """
    if config.pretrain_epoch_size != 0:
        return net

    renamed_params = {}
    for old_key, parameter in load_checkpoint(load_path).items():
        backbone_key = old_key.replace('vovnet', 'backbone')
        # Keep the parameter's own name in sync with the key it is stored under.
        parameter.name = backbone_key
        renamed_params[backbone_key] = parameter

    load_param_into_net(net, renamed_params)
    return net
-
- set_seed(1)
-
- def train_maskrcnn_mobilenetv1():
- """Stage data and checkpoint under /cache, train Mask R-CNN (MobileNetV1), then upload the outputs."""
- # Fixed local working paths: dataset copy, training outputs, downloaded checkpoint file.
- data_dir = '/cache/data'
- train_dir = '/cache/output'
- ckpt_url = '/cache/checkpoint.ckpt'
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
- if not os.path.exists(train_dir):
- os.makedirs(train_dir)
- ###Copy dataset from obs to inference image
- DownloadFromQizhi(config.data_url, data_dir)
- ###Copy ckpt file from obs to inference image
- ObsUrlToEnv(config.ckpt_url, ckpt_url)
-
- # NOTE(review): os.path.join with a single argument returns it unchanged,
- # so this assignment looks like a no-op.
- config.mindrecord_dir = os.path.join(config.mindrecord_dir)
- print("Start training for maskrcnn_mobilenetv1! config:\n", config)
- # Choose rank and device_num: distributed training initialises communication
- # and data-parallel mode; otherwise a single device with rank 0 is assumed.
- # NOTE(review): DownloadFromQizhi above already calls init() when RANK_SIZE > 1;
- # confirm that calling init() again here is intended.
- # NOTE(review): rank is not assigned in the distributed branch when
- # device_target is neither "Ascend" nor "GPU".
- if not config.do_eval and config.run_distribute:
- device_num = get_device_num()
- if config.device_target == "Ascend":
- init()
- rank = get_rank_id()
- context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- gradients_mean=True)
- elif config.device_target == "GPU":
- init()
- rank = get_rank()
- context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- gradients_mean=True)
- else:
- rank = 0
- device_num = 1
-
- print("Start create dataset!")
-
- # It will generate mindrecord file in config.mindrecord_dir,
- # and the file name is MaskRcnn.mindrecord0, 1, ... file_num.
- prefix = "MaskRcnn.mindrecord"
- mindrecord_dir = config.mindrecord_dir
- mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
- # NOTE(review): MindRecord creation is disabled below, so the files are
- # expected to exist already under config.mindrecord_dir.
- # create_mindrecord_files(rank, mindrecord_file, mindrecord_dir, prefix)
-
- if not config.only_create_dataset:
- # When creating the MindDataset, use the first mindrecord file, such as MaskRcnn.mindrecord0.
- dataset = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size,
- device_num=device_num, rank_id=rank)
-
- dataset_size = dataset.get_dataset_size()
- print("Create dataset done, and total images num: ", dataset_size)
-
- net = Mask_Rcnn_Mobilenetv1(config=config)
- net = net.set_train()
-
- # ckpt_url is the fixed local path assigned at the top of this function,
- # so the emptiness check below is always true.
- load_path = ckpt_url
- if load_path != "":
- load_pretrained_ckpt(net, load_path)
-
- loss = LossNet()
- # Learning-rate schedule from dynamic_lr, started at
- # pretrain_epoch_size * dataset_size steps.
- lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
- mstype.float32)
- opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
- weight_decay=config.weight_decay, loss_scale=config.loss_scale)
-
- net_with_loss = WithLossCell(net, loss)
- # Distributed runs additionally pass reduce_flag/mean/degree to TrainOneStepCell
- # (presumably to reduce gradients across devices; verify in src.network_define).
- if config.run_distribute:
- net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
- mean=True, degree=device_num)
- else:
- net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
-
- time_cb = TimeMonitor(data_size=dataset_size)
- loss_cb = LossCallBack(rank_id=rank)
- cb = [time_cb, loss_cb]
- # Optionally save a checkpoint every 6 * dataset_size steps into a per-rank
- # directory under train_dir.
- if config.save_checkpoint:
- ckptconfig = CheckpointConfig(save_checkpoint_steps=6*dataset_size,
- keep_checkpoint_max=config.keep_checkpoint_max)
- save_checkpoint_path = os.path.join(train_dir, 'ckpt_' + str(rank) + '/')
- ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
- cb += [ckpoint_cb]
-
- model = Model(net)
- model.train(config.epoch_size, dataset, callbacks=cb)
- # Upload the contents of train_dir to config.train_url.
- UploadToQizhi(train_dir, config.train_url)
-
- # Start training only when this file is run as a script, not when it is imported.
- if __name__ == '__main__':
- train_maskrcnn_mobilenetv1()
|