|
-
- from __future__ import absolute_import
-
- import sys
- import torch
- import torch.nn as nn
- import torch.nn.init as init
- from torch.autograd import Variable
- import numpy as np
- from pdb import set_trace as st
- from skimage import color
- from IPython import embed
- from . import pretrained_networks as pn
-
- import lpips as util
-
def spatial_average(in_tens, keepdim=True):
    """Return the mean of ``in_tens`` taken over dims 2 and 3.

    ``keepdim`` is forwarded to the reduction: when True the two reduced
    dims are kept with size 1, otherwise they are dropped.
    """
    spatial_dims = [2, 3]
    return torch.mean(in_tens, spatial_dims, keepdim=keepdim)
-
def upsample(in_tens, out_H=64, out_W=None):
    """Bilinearly resize ``in_tens`` along dims 2 and 3.

    Args:
        in_tens: tensor whose dim 2 is treated as height and dim 3 as width.
        out_H: target size of dim 2.
        out_W: optional target size of dim 3. When omitted, the scale factor
            ``out_H / in_H`` is applied to both dims, i.e. the scale factor
            is assumed to be the same for H and W.

    Returns:
        The resized tensor (``mode='bilinear'``, ``align_corners=False``).
    """
    if out_W is not None:
        # Explicit target: resize to exactly (out_H, out_W) so non-square
        # inputs do not depend on a single shared scale factor.
        resize = nn.Upsample(size=(out_H, out_W), mode='bilinear', align_corners=False)
        return resize(in_tens)

    # Default path: one scale factor, derived from the height only.
    in_H = in_tens.shape[2]
    scale_factor = 1. * out_H / in_H
    resize = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
    return resize(in_tens)
-
- # Learned perceptual metric
class PNetLin(nn.Module):
    """Learned perceptual metric built on per-layer feature differences.

    Both inputs are passed through the same backbone (``pn.vgg16``,
    ``pn.alexnet`` or ``pn.squeezenet``). For every feature layer the two
    activations go through ``util.normalize_tensor``, are subtracted and
    squared. Each difference map is reduced to one channel, either by a
    ``NetLinLayer`` (``lpips=True``) or by a plain channel sum
    (``lpips=False``). The result is then averaged over dims 2 and 3, or,
    with ``spatial=True``, resized to the height and width of ``in0``.
    The per-layer results are summed.

    Args:
        pnet_type: 'vgg' / 'vgg16', 'alex' or 'squeeze'.
        pnet_rand: if True the backbone is built with ``pretrained=False``.
        pnet_tune: forwarded to the backbone as ``requires_grad``.
        use_dropout: forwarded to every ``NetLinLayer``.
        spatial: return a per-pixel map instead of a spatial average.
        version: '0.1' applies ``ScalingLayer`` to both inputs first; any
            other value feeds the inputs to the backbone unscaled.
        lpips: use the ``NetLinLayer`` reductions instead of a channel sum.

    Raises:
        ValueError: if ``pnet_type`` is not one of the supported names.
    """

    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]  # 7 layers for squeezenet
        else:
            # Fail with a clear message instead of an unbound ``net_type``.
            raise ValueError(
                "Unsupported pnet_type %r; expected 'vgg', 'vgg16', 'alex' or 'squeeze'"
                % (self.pnet_type,))
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # One 1x1-conv layer per backbone feature layer. Each is also
            # registered as attribute ``lin<k>`` so its parameters keep the
            # ``lin<k>.model...`` state_dict keys; ``self.lins`` stays a
            # plain list used only for indexed access.
            self.lins = []
            for kk, chn in enumerate(self.chns):
                lin = NetLinLayer(chn, use_dropout=use_dropout)
                setattr(self, 'lin%d' % kk, lin)
                self.lins.append(lin)

    def forward(self, in0, in1, retPerLayer=False):
        """Compute the distance between ``in0`` and ``in1``.

        Args:
            in0, in1: input batches; dims 2 and 3 of ``in0`` give the output
                size in spatial mode.
            retPerLayer: if True, also return the list of per-layer results.

        Returns:
            ``val`` (sum of the per-layer results), or ``(val, res)`` when
            ``retPerLayer`` is True, where ``res[k]`` is layer ``k``'s own
            contribution.
        """
        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        diffs = []
        for kk in range(self.L):
            feat0 = util.normalize_tensor(outs0[kk])
            feat1 = util.normalize_tensor(outs1[kk])
            diffs.append((feat0 - feat1) ** 2)

        # Reduce every difference map to a single channel.
        if self.lpips:
            maps = [self.lins[kk].model(diffs[kk]) for kk in range(self.L)]
        else:
            maps = [diff.sum(dim=1, keepdim=True) for diff in diffs]

        if self.spatial:
            # Resize to the exact (H, W) of the input. A height-only scale
            # factor can yield per-layer maps of different sizes (non-square
            # inputs, float rounding), which then cannot be summed.
            out_size = tuple(in0.shape[2:])
            res = [nn.functional.interpolate(m, size=out_size, mode='bilinear', align_corners=False)
                   for m in maps]
        else:
            res = [spatial_average(m, keepdim=True) for m in maps]

        # Accumulate out of place: an in-place ``+=`` on ``res[0]`` would
        # overwrite layer 0's entry in ``res`` with the running total.
        val = res[0]
        for layer_res in res[1:]:
            val = val + layer_res

        if retPerLayer:
            return (val, res)
        return val
-
class ScalingLayer(nn.Module):
    """Apply a fixed per-channel shift and scale: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are registered as buffers of shape (1, 3, 1, 1),
    so they broadcast over dim 1 of a 3-channel input.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-.030, -.088, -.188])
        scale = torch.Tensor([.458, .448, .450])
        self.register_buffer('shift', shift.view(1, 3, 1, 1))
        self.register_buffer('scale', scale.view(1, 3, 1, 1))

    def forward(self, inp):
        """Return ``inp`` shifted and scaled channel-wise."""
        centered = inp - self.shift
        return centered / self.scale
-
-
class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv

    The layer is a bias-free ``nn.Conv2d(chn_in, chn_out, 1)``, optionally
    preceded by ``nn.Dropout()``. It is stored as ``self.model``, so with
    dropout the conv sits at index 1 of the sequential, otherwise at index 0.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels.
        use_dropout: prepend an ``nn.Dropout()`` layer when True.
    '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        ''' Run ``x`` through ``self.model`` so the layer can be called directly. '''
        return self.model(x)
-
-
class Dist2LogitLayer(nn.Module):
    ''' Maps a pair of distances to one value through a stack of 1x1 convs.

    The stack is conv(5 -> chn_mid), LeakyReLU(0.2), conv(chn_mid -> chn_mid),
    LeakyReLU(0.2), conv(chn_mid -> 1), followed by a Sigmoid (output in
    [0,1]) only when use_sigmoid is True.
    '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        stages = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            stages.append(nn.Sigmoid())
        self.model = nn.Sequential(*stages)

    def forward(self, d0, d1, eps=0.1):
        ''' Feed [d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps)], concatenated on dim 1, to the stack. '''
        ratio01 = d0 / (d1 + eps)
        ratio10 = d1 / (d0 + eps)
        features = torch.cat((d0, d1, d0 - d1, ratio01, ratio10), dim=1)
        return self.model.forward(features)
-
class BCERankingLoss(nn.Module):
    """Binary cross-entropy loss on the output of a ``Dist2LogitLayer``.

    ``forward`` maps ``judge`` to ``(judge + 1) / 2`` and uses it as the BCE
    target (presumably ``judge`` lies in [-1, 1], so the target lies in
    [0, 1] -- confirm against callers). The network output of the latest
    call is kept on ``self.logit``.
    """

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        """Return the BCE loss between ``self.net(d0, d1)`` and the rescaled ``judge``."""
        target = (judge + 1.) / 2.
        prediction = self.net.forward(d0, d1)
        self.logit = prediction  # kept on the instance, as before
        return self.loss(prediction, target)
-
- # L2, DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base module that only records ``use_gpu`` and ``colorspace``.

    Subclasses implement ``forward`` and read these two attributes.
    """

    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
-
class L2(FakeNet):
    """Mean squared difference between two images, in RGB or Lab space."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance between ``in0`` and ``in1``.

        Args:
            in0, in1: 4-D tensors; the batch dimension must be 1.
            retPerLayer: accepted for signature compatibility; unused.

        Returns:
            A tensor of shape (N,) for 'RGB', or a one-element tensor for
            'Lab' (moved to CUDA when ``self.use_gpu`` is set).

        Raises:
            ValueError: if ``self.colorspace`` is neither 'RGB' nor 'Lab'
                (previously this case silently returned ``None``).
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            # Average the squared difference over dim 1, then dim 2, then
            # dim 3 -- the same reduction order as the former view/mean
            # chain, leaving shape (N,).
            sq_diff = (in0 - in1) ** 2
            return sq_diff.mean(dim=1).mean(dim=1).mean(dim=1)

        if self.colorspace == 'Lab':
            # Presumably converts both images to Lab numpy arrays before
            # comparing -- the util helpers are defined outside this file.
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var

        raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))
-
class DSSIM(FakeNet):
    """DSSIM distance between two images, computed via ``util.dssim``."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM value for ``in0`` and ``in1`` as a one-element tensor.

        Args:
            in0, in1: 4-D tensors; the batch dimension must be 1.
            retPerLayer: accepted for signature compatibility; unused.

        Returns:
            A one-element tensor, moved to CUDA when ``self.use_gpu`` is set.

        Raises:
            ValueError: if ``self.colorspace`` is neither 'RGB' nor 'Lab'
                (previously this surfaced as an unbound-local error).
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        # The util helpers are defined outside this file; presumably they
        # convert the tensors to image / Lab numpy arrays -- confirm there.
        if self.colorspace == 'RGB':
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif self.colorspace == 'Lab':
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        else:
            raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))

        ret_var = Variable( torch.Tensor((value,) ) )
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
-
def print_network(net):
    """Print ``net`` followed by the total element count of its parameters."""
    total = sum(p.numel() for p in net.parameters())
    print('Network', net)
    print('Total number of parameters: %d' % total)
|