|
- import os
- import sys
- import argparse
- import numpy as np
- import moxing as mox
- import json
- from config import config
-
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import context
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore import load_checkpoint, load_param_into_net
- from mindspore.train import Model
- from mindspore.nn.metrics import Accuracy
- from mindspore.context import ParallelMode
- from mindspore.communication.management import init, get_rank
- import mindspore.ops as ops
- import time
- from upload import UploadOutput
-
- from src.data.dataset import create_openi_dataset
- from src.DETR import build_model
- from src.tools.cell import WithLossCell, WithGradCell
- from src.tools.average_meter import AverageMeter
-
-
### Copy multiple datasets from obs to training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    """Download every dataset listed in *multi_data_url* into *data_dir*.

    Args:
        multi_data_url: JSON string — a list of records, each carrying
            ``dataset_name`` and ``dataset_url`` (OBS location).
        data_dir: local directory that receives one sub-folder per dataset.

    A failed copy is logged and skipped (best-effort, as in the original).
    After all copies are attempted, an empty marker file
    ``/cache/download_input.txt`` is written so that other ranks in a
    multi-card job can tell the download phase has finished.
    """
    # --multi_data_url is json data, need to do json parsing for multi_data_url
    multi_data_json = json.loads(multi_data_url)
    for dataset in multi_data_json:
        path = data_dir + "/" + dataset["dataset_name"]
        # exist_ok makes this idempotent if several processes race here.
        os.makedirs(path, exist_ok=True)
        try:
            mox.file.copy_parallel(dataset["dataset_url"], path)
            print("Successfully Download {} to {}".format(dataset["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                dataset["dataset_url"], path) + str(e))
    #Set a cache file to determine whether the data has been copied to obs.
    #If this file exists during multi-card training, there is no need to copy the dataset multiple times.
    with open("/cache/download_input.txt", 'w'):
        pass
    # os.path.exists never raises, so report the result directly instead of
    # wrapping it in a try/except that could never fire.
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
-
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy a single checkpoint file from OBS down to the local image path.

    Failures are logged and swallowed — the caller decides whether a
    missing checkpoint is fatal.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as err:
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(err))
    return
-
### Copy the output to obs###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local training output folder back to its OBS location.

    Best-effort: an upload failure is logged, never raised.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as err:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(err))
    return
-
def DownloadFromQizhi(multi_data_url, data_dir):
    """Set up the MindSpore context and stage datasets from OBS.

    Single-card: this process downloads everything itself (GRAPH_MODE).
    Multi-card: only the first card of each 8-card node (rank % 8 == 0)
    downloads, and every rank blocks until the marker file written by
    MultiObsToEnv appears.

    NOTE(review): reads the module-level ``args`` namespace for
    ``device_target`` — must be called after argument parsing.
    """
    # Default to single-card when RANK_SIZE is not exported (e.g. local
    # debugging); the original int(None) would raise TypeError here.
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.PYNATIVE_MODE,
                            device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        #If the cache file does not exist, it means that the copy data has not been completed,
        #and Wait for 0th card to finish copying data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training output to OBS exactly once per node.

    Single-card: this process uploads.  Multi-card: only the first card of
    each 8-card node (rank % 8 == 0) uploads, to avoid duplicate transfers.
    """
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        # RANK_ID is only guaranteed in a distributed launch, so read it
        # inside the multi-card branch (the original read it up front and
        # crashed with TypeError on single-card runs without RANK_ID).
        # Default to 0 so a missing variable still yields one upload.
        local_rank = int(os.getenv('RANK_ID', '0'))
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return
-
# Command-line interface for the DETR training entry on the Qizhi platform.
# OBS-side locations (data/train/ckpt URLs) default to local cache paths.
parser = argparse.ArgumentParser(description='MindSpore DETR')
parser.add_argument('--data_url', default='/cache/data/',
                    help='path to training/inference dataset folder')
parser.add_argument('--train_url', default='/cache/output/',
                    help='output folder to save/load')
parser.add_argument('--ckpt_url', default='/cache/checkpoint.ckpt',
                    help='model to save/load')
parser.add_argument('--multi_data_url', default='/cache/data/',
                    help='path to multi dataset')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
parser.add_argument('--epoch_size', type=int, default=5,
                    help='Training epochs.')
-
if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    # Local (in-image) staging paths; the OBS-side sources/destinations
    # come from the parsed arguments above.
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    ckpt_url = '/cache/checkpoint.ckpt'
    try:
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)
    except Exception as e:
        print("path already exists")

    ###Copy ckpt file from obs to training image
    ObsUrlToEnv(args.ckpt_url, ckpt_url)

    ###Copy data from obs to training image
    # Also sets the MindSpore context and, for multi-card jobs, calls init().
    DownloadFromQizhi(args.multi_data_url, data_dir)

    ###The dataset path is used here:data_dir +"/train"
    # NOTE(review): RANK_SIZE is assumed to be exported by the platform
    # launcher (int(None) would raise otherwise), and get_rank() normally
    # requires init() to have run — confirm this holds for single-card jobs.
    device_num = int(os.getenv('RANK_SIZE'))
    rank = get_rank()

    # COCO-style training dataset; image folder and annotation file are
    # hard-coded to the staged COCO 2017 layout under /cache/data.
    dataset = create_openi_dataset(
        config,
        "/cache/data/train2017/train2017",
        "/cache/data/annotations_trainval2017/annotations/instances_train2017.json",
        batch_size=4,
        device_num=device_num,
        rank_id=rank)

    dataset_size = dataset.get_dataset_size()
    print("Create COCO dataset done!")
    print(f"COCO dataset num: {dataset_size}")

    # model
    # build_model returns the DETR network, its loss criterion and the
    # post-processors; only the network is used below.
    net, criterion, postprocessors = build_model(config)
    params = net.trainable_params()

    # Load the checkpoint staged earlier by ObsUrlToEnv.
    ckpt = load_checkpoint(ckpt_url)
    # Kept for reference: key-remapping used when loading a bare backbone
    # checkpoint (keys need the "backbone.backbone." prefix).
    # new_ckpt = {}
    # for k in ckpt.keys():
    #     new_key = "backbone.backbone." + k
    #     new_ckpt[new_key] = ckpt[k]
    #     print("loaded weight", new_key)

    # strict_load=True: parameter names/shapes must match; the returned
    # value lists the parameters that could NOT be loaded.
    unloaded = load_param_into_net(net, ckpt, strict_load=True)
    print('backbone weights loaded')

    if not unloaded:
        print("all net weights loaded.")
    else:
        for u in unloaded:
            print(u, " unloaded")

    # ------------------------------------------------------------------
    # The remainder of the original training pipeline is intentionally
    # disabled (commented out): fp16 cast, piecewise-constant LRs with a
    # separate backbone LR, AdamW with dynamic loss scaling, the manual
    # training loop, per-epoch checkpointing on rank 0, and the final
    # upload to OBS via UploadToQizhi.  Left in place for reference.
    # ------------------------------------------------------------------
    # data_dtype = ms.float16
    # net.to_float(data_dtype)
    # net.set_train()

    # # lr and optimizer
    # lr = nn.piecewise_constant_lr(
    #     [dataset_size * config.lr_drop, dataset_size * config.epochs],
    #     [config.lr, config.lr * 0.1]
    # )
    # lr_backbone = nn.piecewise_constant_lr(
    #     [dataset_size * config.lr_drop, dataset_size * config.epochs],
    #     [config.lr_backbone, config.lr_backbone * 0.1]
    # )

    # backbone_params = list(filter(lambda x: 'backbone' in x.name, net.trainable_params()))
    # no_backbone_params = list(filter(lambda x: 'backbone' not in x.name, net.trainable_params()))
    # param_dicts = [
    #     {'params': backbone_params, 'lr': lr_backbone, 'weight_decay': config.weight_decay},
    #     {'params': no_backbone_params, 'lr': lr, 'weight_decay': config.weight_decay}
    # ]
    # optimizer = nn.AdamWeightDecay(param_dicts)

    # # init mindspore model
    # scale_sense = nn.DynamicLossScaleUpdateCell(loss_scale_value=512, scale_factor=2, scale_window=1000)

    # net_with_loss = WithLossCell(net, criterion)
    # net_with_grad = WithGradCell(net_with_loss, optimizer, scale_sense, config.clip_max_norm)

    # print("Create DETR network done!")

    # # callbacks
    # loss_meter = AverageMeter()
    # data_loader = dataset.create_dict_iterator()

    # for e in range(config.start_epoch, config.epochs):
    #     for i, data in enumerate(data_loader):

    #         start_time = time.time()
    #         img_data = data['image'].astype(data_dtype)
    #         mask = data['mask'].astype(data_dtype)
    #         boxes = data['boxes']
    #         labels = data['labels']
    #         valid = data['valid']
    #         loss = net_with_grad(img_data, mask, boxes, labels, valid)

    #         loss_meter.update(loss.asnumpy())
    #         end_time = time.time()

    #         if i % (dataset_size//50) == 0:
    #             fps = config.batch_size / (end_time - start_time)
    #             print('epoch[{}/{}], iter[{}/{}], loss:{:.4f}, fps:{:.2f} imgs/sec, lr:[{}/{}]'.format(
    #                 e, config.epochs,
    #                 i, dataset_size,
    #                 loss_meter.average(),
    #                 fps,
    #                 lr_backbone[e * dataset_size + i], lr[e * dataset_size + i]
    #             ), flush=True)
    #             loss_meter.reset()

    #     if rank == 0: # save ckpt on device 0.
    #         ckpt_path = os.path.join(train_dir, f'detr_epoch_{e}.ckpt')
    #         ms.save_checkpoint(net, ckpt_path)
    # UploadToQizhi(train_dir,args.train_url)
|