|
- import os
- import time
-
- import mindspore
- import mindspore.nn as nn
- from mindspore.communication.management import init, get_rank
- from mindspore import save_checkpoint, context, load_checkpoint, load_param_into_net
- from mindspore.context import ParallelMode
- from src.datasets.dataset import RBPNDataset, create_train_dataset
- from src.loss.generatorloss import GeneratorLoss
- from model.rbpn import Net as RBPN
- from mindspore.nn.dynamic_lr import piecewise_constant_lr
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
- from trainonestepgen import TrainOnestepGen
- # from trainonestep import TrainOnestepGen
- from mindspore.ops import functional as F
- from src.util.utils import init_weights
- import numpy as np
- import argparse
- import ast
- import moxing as mox
- import zipfile
-
-
# Run-environment selector: 'debug' targets the ModelArts notebook sandbox,
# anything else targets the regular training-job container.
environment = 'train'
workroot = '/home/ma-user/work' if environment == 'debug' else '/home/work/user-job-dir'
print('current work mode:' + environment + ', workroot:' + workroot)
-
# ---------------------------------------------------------------------------
# Command-line interface.
# NOTE(review): boolean options now use ast.literal_eval (consistent with
# --eval_flag below) instead of type=bool — bool('False') is True, so the
# old parser silently treated any "False" passed on the CLI as True.
# Also fixed: --valDataset choices contained the typo "vimoe", which made
# the documented default "vimeo" rejected when given explicitly.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='RBPN-mindspore')

parser.add_argument('--batchSize', type=int, default=4, help='training batch size')
parser.add_argument('--testBatchSize', type=int, default=5, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=210, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate. Default=0.0001')
parser.add_argument('--threads', type=int, default=8, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--other_dataset', type=ast.literal_eval, default=False, help="use other dataset than vimeo-90k")
parser.add_argument('--future_frame', type=ast.literal_eval, default=True, help="use future frame")
parser.add_argument('--nFrames', type=int, default=7)
parser.add_argument('--patch_size', type=int, default=64, help='0 to use original frame size')
parser.add_argument('--data_augmentation', type=ast.literal_eval, default=True)
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--residual', type=ast.literal_eval, default=False)
parser.add_argument('--pretrained_sr', default='69_RBPN.ckpt', help='sr pretrained base model')
parser.add_argument('--pretrained', type=ast.literal_eval, default=False)
parser.add_argument('--save_folder', default='weights/', help='Location to save checkpoint models')

parser.add_argument('--Results', default='Results/gen/', help='eval image')
parser.add_argument("--valDataset", type=str, default="vimeo", choices=["vimeo", "vid4"], help="eval dataset type")

# data resource configuration
parser.add_argument("--large_dataset", type=int, default=0, help="use large dataset, default: false.")
parser.add_argument("--large_file", type=int, default=0, help="use large dataset file, default: false.")
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default=workroot + '/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default=workroot + '/model/')
parser.add_argument('--eval_flag', type=ast.literal_eval, default=True,
                    help="The flag means whether to eval while training")
parser.add_argument('--upscale_factor', type=int, default=4, choices=[2, 4, 8],
                    help="Super resolution upscale factor")
parser.add_argument('--snapshots', type=int, default=10, help='Snapshots')
parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch')
parser.add_argument("--run_distribute", type=int, default=1, help="run distribute, default: false.")
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: CPU); to use an NPU '
         'on the OpenI platform, add the run argument device_target=Ascend')
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
parser.add_argument("--device_num", type=int, default=8, help="number of device, default: 0.")
parser.add_argument("--rank", type=int, default=0, help="rank id, default: 0.")
# additional parameters
parser.add_argument('--sens', type=float, default=1024.0)
-
# Parse once at import time so the module-level state below can use the values.
args = parser.parse_args()
mindspore.set_seed(args.seed)  # fix all MindSpore RNG seeds for reproducibility
# Mutable module-level training state shared with train().
epoch_loss = []  # per-epoch mean generator loss, appended by train()
# NOTE(review): best_avgpsnr / eval_mean_psnr are never updated in this file —
# presumably placeholders for an eval-while-training path; verify before relying on them.
best_avgpsnr = 0
eval_mean_psnr = []
save_loss_path = 'results/genloss/'  # destination for loss-curve artifacts
if not os.path.exists(save_loss_path):
    os.makedirs(save_loss_path)
-
-
def train(trainoneStep, trainds, eval_flag=False):
    """Run the RBPN generator training loop.

    Args:
        trainoneStep: cell wrapping the loss network + optimizer; calling it
            with one batch runs a forward/backward step and returns the loss.
        trainds: MindSpore dataset yielding dicts with keys 'input_image',
            'target_image', 'neighbor_image' and 'flow_image'.
        eval_flag: kept for interface compatibility; evaluation during
            training is not implemented here.

    Side effects: appends the per-epoch mean loss to the module-level
    ``epoch_loss`` list, saves a checkpoint under ``train_dir`` every
    ``args.snapshots`` epochs, and uploads ``train_dir`` to ``obs_train_url``
    via moxing (best effort — upload failures are logged, not raised).
    """
    trainoneStep.set_train()
    trainoneStep.set_grad()
    steps = trainds.get_dataset_size()
    train_loader = trainds.create_dict_iterator()
    for epoch in range(args.start_iter, args.nEpochs + 1):
        e_loss = 0
        t0 = time.time()
        for iteration, batch in enumerate(train_loader, 1):
            # `input` renamed to lr_image: the original shadowed the builtin.
            lr_image = batch['input_image']
            target = batch['target_image']
            neighbor_tensor = batch['neighbor_image']
            flow_tensor = batch['flow_image']

            loss = trainoneStep(target, lr_image, neighbor_tensor, flow_tensor)

            e_loss += loss.asnumpy()
            print('Epoch[{}]({}/{}): loss: {:.4f}'.format(epoch, iteration, steps, loss.asnumpy()))

        mean = e_loss / steps
        epoch_loss.append(mean)
        t1 = time.time()
        print("Epoch {} Complete: Avg. Loss: {:.4f}|| Time: {} min {}s.".format(epoch, mean, int((t1 - t0) / 60),
                                                                                int(int(t1 - t0) % 60)))
        step_time = (t1 - t0) / steps
        print('per step needs time:{:.2f}ms'.format(step_time * 1000))

        # NOTE(review): with start_iter=1 this saves at epochs 9, 19, ... 209.
        # If saving at multiples of `snapshots` was intended, the condition
        # should be `epoch % args.snapshots == 0`; behavior kept as-is.
        if (epoch + 1) % (args.snapshots) == 0:
            print('===> Saving model')
            print('train_dir:', train_dir)
            # Default DEVICE_ID to '0' so single-device runs without the env
            # var don't crash concatenating None into the path.
            save_checkpoint_path = train_dir + '/device_' + os.getenv('DEVICE_ID', '0') + '/'
            if not os.path.exists(save_checkpoint_path):
                os.makedirs(save_checkpoint_path)

            model_name = 'rbpn_epoch%d.ckpt' % (epoch)
            ckpt_dir_path = os.path.join(save_checkpoint_path, model_name)
            print("ckpt position:", ckpt_dir_path)
            print("********************train_dir_list:", os.listdir(train_dir))
            save_checkpoint(trainoneStep.network, ckpt_dir_path)

            # Best-effort upload of the checkpoint dir to OBS; training
            # continues even if the transfer fails.
            try:
                mox.file.copy_parallel(train_dir, obs_train_url)
                print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
            except Exception as e:
                print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
-
-
- if __name__ == "__main__":
- args = parser.parse_args()
- print(args)
- context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
- if args.run_distribute:
- print("distribute")
- device_id = int(os.getenv("DEVICE_ID"))
- device_num = args.device_num
- context.set_context(device_id=device_id)
- init()
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
- device_num=device_num)
- rank = get_rank()
- else:
- device_id = args.device_id
- # device_id = int(os.getenv("DEVICE_ID"))
- context.set_context(device_id=device_id)
-
- home = os.path.dirname(os.path.realpath(__file__))
-
-
- if args.large_file ==1:
- file_list = home + '/sep_trainlist.txt'
- else:
- file_list = home + '/fast12.txt'
- ######################## 将数据集从obs拷贝到训练镜像中 (固定写法)########################
- # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径
- data_dir = os.path.join(home, 'data') # 数据集存放路径
- train_dir = os.path.join(home, 'checkpoints') # 模型存放路径
- # 初始化数据存放目录
- if not os.path.exists(data_dir):
- os.mkdir(data_dir)
- # 初始化模型存放目录
- obs_train_url = args.train_url
- if not os.path.exists(train_dir):
- os.mkdir(train_dir)
-
-
- ######################## 将数据集从obs拷贝到训练镜像中 (固定写法)########################
- # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径,以下写法是将数据拷贝到/home/work/user-job-dir/data/目录下,可修改为其他目录
- # 创建数据存放的位置
- if environment == 'train':
- obs_data_url = args.data_url
- # 将数据拷贝到训练环境
- try:
- mox.file.copy_parallel(obs_data_url, data_dir)
- print("Successfully Download {} to {}".format(obs_data_url,
- data_dir))
- except Exception as e:
- print('moxing download {} to {} failed: '.format(
- obs_data_url, data_dir) + str(e))
-
-
- path = home
- datanames = os.listdir(path)
- list = []
- for i in datanames:
- list.append(i)
- print("********************list:", list)
-
- # 数据集选择
- if args.large_dataset==1:
- zip_out_dir = home + '/data/vimeo_septuplet/sequences'
- else:
- zip_out_dir = home + '/data/vimeo1/sequences'
-
-
- print("Preparing Data")
- start_time = time.perf_counter()
- train_dataset = RBPNDataset(zip_out_dir, args.nFrames, args.upscale_factor, args.data_augmentation,
- file_list, args.other_dataset, args.patch_size, args.future_frame)
- train_ds = create_train_dataset(train_dataset, args)
- train_steps = train_ds.get_dataset_size()
-
- end_time = time.perf_counter()
- print("preparing data use: {}min".format((end_time - start_time) / 60))
- # model
-
- print('===>Building model ', args.model_type)
- model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3, n_resblock=5, nFrames=args.nFrames,
- scale_factor=args.upscale_factor)
- init_weights(model, 'KaimingNormal', 0.02)
- # init_weights(model, 'normal', 0.02)
- print('====>start training')
-
- if args.pretrained:
- ckpt = os.path.join(home, args.pretrained_sr)
- # ckpt = args.pretrained_sr
- print('=====> load params into generator')
- params = load_checkpoint(ckpt)
- load_param_into_net(model, params)
- print('=====> finish load generator')
-
- lossNetwork = GeneratorLoss(model)
-
- # milestone = [int(args.nEpochs / 2) * train_steps, args.nEpochs * train_steps]
- # learning_rates = [args.lr, args.lr / 10.0]
- # lr = piecewise_constant_lr(milestone, learning_rates)
-
- milestone = [int(args.nEpochs / 3)*train_steps , int(args.nEpochs/3)*2*train_steps , args.nEpochs*train_steps]
- learning_rates = [args.lr, args.lr / 10.0 , args.lr /100.0]
- lr = piecewise_constant_lr(milestone, learning_rates)
-
- optimizer = nn.Adam(model.trainable_params(), lr, loss_scale=args.sens)
-
-
-
-
- trainonestepNet = TrainOnestepGen(lossNetwork, optimizer, sens=args.sens )
-
- train(trainonestepNet, train_ds, args.eval_flag)
- print(train_dataset)
-
- if environment == 'train':
- try:
- mox.file.copy_parallel(train_dir, obs_train_url)
- print("Successfully Upload {} to {}".format(train_dir,
- obs_train_url))
- except Exception as e:
- print('moxing upload {} to {} failed: '.format(train_dir,
- obs_train_url) + str(e))
|