|
- import math
- import argparse
- import options.options as options
- import mindspore
- import mindspore.nn as nn
- import mindspore.ops as ops
- import numpy as np
- from models.modules.RRDBNet_arch import RRDBNet
- from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
- import models.modules.thops as thops
- import models.modules.flow as flow
- from uutils.util import opt_get
- from mindspore import context
-
-
class SRFlowNet(nn.Cell):
    """SRFlow super-resolution network (MindSpore).

    An RRDB network encodes the low-resolution (LR) input into multi-scale
    features which condition a normalizing flow (FlowUpsamplerNet).  In the
    normal direction the ground-truth (GT) image is mapped to latent space and
    a negative log-likelihood is returned; in the reverse direction an SR
    image is decoded from a latent ``z`` conditioned on the LR input.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
        """
        Args:
            in_nc: input channel count for the RRDB encoder.
            out_nc: output channel count for the RRDB encoder.
            nf: base feature width of the RRDB encoder.
            nb: number of RRDB blocks.
            gc: growth channels inside each RRDB block.
            scale: super-resolution factor.
            K: number of flow steps (forwarded to FlowUpsamplerNet).
            opt: parsed option dict; several sub-keys are read via opt_get.
            step: unused; kept for interface compatibility with callers.
        """
        super(SRFlowNet, self).__init__()

        self.opt = opt
        # Quantization level of input images; default 255 (8-bit).
        quant = opt_get(opt, ['datasets', 'train', 'quant'])
        self.quant = 255 if quant is None else quant
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) or 64
        self.RRDB_training = True  # RRDB weights trainable by default

        # NOTE(review): train_RRDB_delay is read but never used, and
        # set_RRDB_to_train is hard-coded False, so this branch is dead.
        # Kept unchanged to preserve the original training-schedule hook.
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        set_RRDB_to_train = False
        if set_RRDB_to_train:
            self.set_rrdb_training(True)

        self.flowUpsamplerNet = \
            FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
        self.i = 0
        self.logp = flow.logp()
        self.noiseQuant = opt_get(opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
        self.block_idxs = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        self.RRDB_concat = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'concat'])
        self.log = ops.Log()
        self.cat = ops.Concat(axis=1)
        self.zeros_like = ops.ZerosLike()
        self.rand = ops.UniformReal()
        # BUG FIX: `Tensor` was referenced without being imported; use the
        # fully-qualified name. ln(2) converts the objective from nats to bits.
        self.loge_2 = mindspore.Tensor(np.log(2.))

        self.print = ops.Print()

    def set_rrdb_training(self, trainable):
        """Toggle requires_grad on all RRDB parameters.

        Returns:
            True if the flag was changed, False if it was already `trainable`.
        """
        if self.RRDB_training != trainable:
            for p in self.RRDB.parameters():
                p.requires_grad = trainable
            self.RRDB_training = trainable
            return True
        return False

    def construct(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
                  lr_enc=None,
                  add_gt_noise=False, step=None, y_label=None):
        """Dispatch to the normal (GT -> latent, NLL) or reverse
        (latent -> SR image) flow direction depending on `reverse`."""
        if not reverse:
            return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                    y_onehot=y_label)
        else:
            return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                     add_gt_noise=add_gt_noise)

    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        """Encode `gt` through the flow conditioned on `lr`.

        Returns:
            (z, nll, logdet): latent code, per-sample negative log-likelihood
            in bits per pixel, and the accumulated log-determinant.
        """
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        logdet = self.zeros_like(gt[:, 0, 0, 0])
        pixels = gt.shape[2] * gt.shape[3]

        z = gt

        if add_gt_noise:
            # Dequantization: add uniform noise in [-0.5/quant, 0.5/quant)
            # and account for the quantization volume in the log-determinant.
            if self.noiseQuant:
                z = z + ((self.rand(z.shape) - 0.5) / self.quant)
            # BUG FIX: self.quant is a Python scalar, not a Tensor, so
            # ops.Log cannot be applied to it; use math.log (as in the
            # reference PyTorch SRFlow implementation).
            logdet = logdet + float(-math.log(self.quant) * pixels)

        # Encode through the flow network.
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        objective = logdet

        z = epses

        # Add the log-probability of z under the flow prior.
        objective = objective + self.logp(None, None, z)

        # Negative log-likelihood in bits per pixel.
        nll = (-objective) / (self.loge_2 * pixels)

        return z, nll, logdet

    def rrdbPreprocessing(self, lr):
        """Run the RRDB encoder on `lr` and optionally stack/concat
        intermediate block features onto each multi-scale feature map.

        Returns:
            dict of named feature maps consumed by FlowUpsamplerNet.
        """
        rrdbResults = self.RRDB(lr, get_steps=True)

        if len(self.block_idxs) > 0:
            # NOTE(review): blocks 0-3 are concatenated unconditionally here,
            # regardless of which indices self.block_idxs lists — confirm.
            concat = self.cat((rrdbResults["block_0"], rrdbResults['block_1'],
                               rrdbResults["block_2"], rrdbResults['block_3']))

            if self.RRDB_concat:
                keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
                if 'fea_up0' in rrdbResults.keys():
                    keys.append('fea_up0')
                if 'fea_up-1' in rrdbResults.keys():
                    keys.append('fea_up-1')
                if self.opt['scale'] >= 8:
                    keys.append('fea_up8')
                if self.opt['scale'] == 16:
                    keys.append('fea_up16')
                for k in keys:
                    h = rrdbResults[k].shape[2]
                    w = rrdbResults[k].shape[3]
                    # Resize the stacked block features to each scale before
                    # concatenating along channels.
                    resize_nearest = ops.ResizeNearestNeighbor([h, w])
                    rrdbResults[k] = self.cat((rrdbResults[k], resize_nearest(concat)))
        return rrdbResults

    def get_score(self, disc_loss_sigma, z):
        """Gaussian-based score of latent `z` under width `disc_loss_sigma`."""
        score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
                     z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
        return -score_real

    def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
        """Decode an SR image from latent `z` conditioned on `lr`.

        Returns:
            (x, logdet): decoded image and accumulated log-determinant.
        """
        logdet = self.zeros_like(lr[:, 0, 0, 0])
        pixels = lr.shape[2] * lr.shape[3] * self.opt['scale'] ** 2

        if add_gt_noise:
            # Undo the dequantization volume term added in normal_flow.
            # BUG FIX: self.quant is a Python scalar; use math.log, not
            # ops.Log (which requires a Tensor input).
            logdet = logdet - float(-math.log(self.quant) * pixels)

        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                          logdet=logdet)

        return x, logdet
-
-
- if __name__ == "__main__":
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
- context.set_context(device_id=1)
- # parser = argparse.ArgumentParser()
- # parser.add_argument('--config', default='/root/xidian_wks/jiang/SRFlow/data/SRFlow_CelebA_8X.yml')
- # args = parser.parse_args()
- # opt = options.parse(args.config, is_train=True)
- # net = SRFlowNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=8, K=16, opt=opt)
- # print(net)
- # gt = mindspore.Tensor(np.random.rand(2, 3, 160, 160).astype(np.float32))
- # lr = mindspore.Tensor(np.random.rand(2, 3, 20, 20).astype(np.float32))
- # out = net(gt, lr)
- # print(out)
- # print(len(out))
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--config', default='/root/xidian_wks/jiang/SRFlow/data/SRFlow_DF2K_4X.yml')
- args = parser.parse_args()
- opt = options.parse(args.config, is_train=True)
- net = SRFlowNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4, K=16, opt=opt)
- print(net)
- gt = mindspore.Tensor(np.random.rand(12, 3, 160, 160).astype(np.float32))
- lr = mindspore.Tensor(np.random.rand(12, 3, 40, 40).astype(np.float32))
- out = net(gt, lr)
- print(out)
- print(len(out))
|