|
- """
- reference: https://github.com/jxgu1016/Total_Variation_Loss.pytorch
- The smaller the tv_loss, the smoother the image.
- """
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
-
-
class TVLoss(nn.Module):
    """Total variation (TV) loss for a batch of images.

    Penalizes squared differences between neighboring pixels along the
    height and width axes; the smaller the loss, the smoother the image.
    Reference: https://github.com/jxgu1016/Total_Variation_Loss.pytorch

    Args:
        TVLoss_weight: scalar multiplier applied to the loss (default 1).
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Compute the TV loss of ``x``, shape (N, C, H, W).

        Returns:
            A scalar tensor: ``weight * 2 * (h_tv/count_h + w_tv/count_w) / N``.
        """
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Clamp counts to >= 1: for H == 1 (or W == 1) the corresponding
        # difference tensor is empty, and the original 0/0 produced NaN
        # instead of contributing 0 to the loss.
        count_h = max(self._tensor_size(x[:, :, 1:, :]), 1)
        count_w = max(self._tensor_size(x[:, :, :, 1:]), 1)
        # Sum of squared differences between vertically / horizontally
        # adjacent pixels.
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        """Number of elements per batch item: C * H * W."""
        return t.size()[1] * t.size()[2] * t.size()[3]
-
-
def main():
    """Smoke-test TVLoss on an all-zero batch (expected loss: 0)."""
    # torch.autograd.Variable has been deprecated since PyTorch 0.4;
    # plain tensors accept requires_grad directly.
    x = torch.zeros((2, 3, 128, 128), requires_grad=True)
    tv_loss = TVLoss()
    z = tv_loss(x)
    print(z.data)


if __name__ == '__main__':
    main()
|