|
-
-
- import os
- import time
-
- import mindspore
- import mindspore.nn as nn
- from mindspore.communication.management import init, get_rank
- from mindspore import save_checkpoint, context, load_checkpoint, load_param_into_net
- from mindspore.context import ParallelMode
- from src.datasets.dataset import RBPNDataset, create_train_dataset
- from src.util.config import get_args
- from src.loss.generatorloss import GeneratorLoss
- from src.model.rbpn import Net as RBPN
- from mindspore.nn.dynamic_lr import piecewise_constant_lr
- from trainonestepgen import TrainOnestepGen
- # from trainonestep import TrainOnestepGen
- from mindspore.ops import functional as F
- from src.util.utils import save_losses, init_weights , PSNR
- import numpy as np
-
# Parse CLI/config arguments once at module level; fix the RNG seed so
# training runs are reproducible.
args = get_args()
mindspore.set_seed(args.seed)

# Module-level training state shared with train():
epoch_loss = []       # per-epoch mean generator loss (saved to disk each epoch)
best_avgpsnr = 0      # best average PSNR seen so far during evaluation
eval_mean_psnr = []   # per-evaluation mean PSNR history

# Directory for evaluation outputs: <Results>/<valDataset>/<model_type>.
# exist_ok=True avoids the check-then-create race of the exists()/makedirs pair.
save_eval_path = os.path.join(args.Results, args.valDataset, args.model_type)
os.makedirs(save_eval_path, exist_ok=True)

# Directory where per-epoch generator loss curves are written.
save_loss_path = 'results/genloss/'
os.makedirs(save_loss_path, exist_ok=True)
-
def train(trainoneStep, trainds, eval_flag=False):
    """Run the full training loop over ``trainds`` with ``trainoneStep``.

    For each epoch: accumulate the mean loss over all batches, save a
    checkpoint of the underlying network, and append/save the loss history
    via ``save_losses``.

    Args:
        trainoneStep: one-step training cell (forward + backward + update);
            its ``network`` attribute is what gets checkpointed.
        trainds: MindSpore dataset whose dict iterator yields
            'input_image', 'target_image', 'neigbor_image' and 'flow_image'
            tensors (the unused 'bicubic_image' key is ignored).
        eval_flag: reserved for in-training evaluation; currently unused.
    """
    trainoneStep.set_train()
    trainoneStep.set_grad()
    steps = trainds.get_dataset_size()

    for epoch in range(args.start_iter, args.nEpochs + 1):
        e_loss = 0
        t0 = time.time()
        for iteration, batch in enumerate(trainds.create_dict_iterator(), 1):
            # 'input' renamed to avoid shadowing the builtin.
            lr_image = batch['input_image']
            target = batch['target_image']
            neigbor_tensor = batch['neigbor_image']
            flow_tensor = batch['flow_image']

            loss = trainoneStep(target, lr_image, neigbor_tensor, flow_tensor)

            e_loss += loss.asnumpy()
            print('Epoch[{}]({}/{}): loss: {:.4f}'.format(epoch, iteration, steps, loss.asnumpy()))

        t1 = time.time()
        mean = e_loss / steps
        epoch_loss.append(mean)
        print("Epoch {} Complete: Avg. Loss: {:.4f}|| Time: {} min {}s.".format(epoch, mean, int((t1 - t0) / 60),
                                                                                int(int(t1 - t0) % 60)))
        # One checkpoint per epoch: <save_folder>/<epoch>_<model_type>.ckpt
        save_ckpt = os.path.join(args.save_folder, '{}_{}.ckpt'.format(epoch, args.model_type))
        save_checkpoint(trainoneStep.network, save_ckpt)
        name = os.path.join(save_loss_path, args.valDataset + '_' + args.model_type)
        save_losses(epoch_loss, None, name)
-
-
if __name__ == '__main__':
    # Training entry point: configure the MindSpore context, build the
    # dataset / model / optimizer, then launch the training loop.

    # distribute
    # parallel environment setting
    # context.set_context(mode=context.PYNATIVE_MODE, device_target2=args.device_target)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if args.run_distribute:
        # Data-parallel multi-device run: the device id comes from the
        # DEVICE_ID env var set by the distributed launcher. NOTE(review):
        # int(os.getenv(...)) raises TypeError if DEVICE_ID is unset.
        print("distribute")
        device_id = int(os.getenv("DEVICE_ID"))
        device_num = args.device_num
        context.set_context(device_id=device_id)
        # init() must run before configuring the auto-parallel context.
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        rank = get_rank()
    else:
        # Single-device run: device id taken from the CLI arguments.
        device_id = args.device_id
        # device_id = int(os.getenv("DEVICE_ID"))
        context.set_context(device_id=device_id)
    # get dataset
    train_dataset = RBPNDataset(args.data_dir, args.nFrames, args.upscale_factor, args.data_augmentation, args.file_list, args.other_dataset, args.patch_size, args.future_frame)
    train_ds = create_train_dataset(train_dataset, args)
    train_loader = train_ds.create_dict_iterator()  # NOTE(review): unused — train() builds its own iterator
    train_steps = train_ds.get_dataset_size()

    print('===>Building model ', args.model_type)
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3, n_resblock=5, nFrames=args.nFrames,
                 scale_factor=args.upscale_factor)
    # init_weights(model, 'KaimingNormal', 0.02)
    init_weights(model, 'normal', 0.02)
    print('====>start training')

    if args.load_pretrained:
        # Resume from a previously saved generator checkpoint.
        ckpt = os.path.join(args.save_folder, args.pretrained_sr)
        print('=====> load params into generator')
        params = load_checkpoint(ckpt)
        load_param_into_net(model, params)
        print('=====> finish load generator')

    lossNetwork = GeneratorLoss(model)

    # Step decay schedule: full lr for the first half of training (in
    # steps), lr/10 for the second half.
    milestone = [int(args.nEpochs / 2) * train_steps, args.nEpochs * train_steps]
    learning_rates = [args.lr, args.lr / 10.0]
    lr = piecewise_constant_lr(milestone, learning_rates)
    optimizer = nn.Adam(model.trainable_params(), lr, loss_scale=args.sens)

    trainonestepNet = TrainOnestepGen(lossNetwork, optimizer, sens=args.sens)


    train(trainonestepNet, train_ds , args.eval_flag)
    print(train_dataset)
-
-
-
-
-
-
-
-
-
-
|