|
- #!/usr/bin/python
- # -*- encoding: utf-8 -*-
-
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
-
class GeneralizedSoftDiceLoss(nn.Module):
    '''
    Generalized soft dice loss computed per sample.

    Per-class dice numerators/denominators are computed over the spatial
    dims, optionally class-weighted, then summed over classes before the
    ratio is formed (the "generalized dice" formulation). Probabilities
    come from a per-channel sigmoid, as in the original code.
    '''

    def __init__(self,
                 p=1,
                 smooth=1,
                 reduction='mean',
                 weight=None,
                 ignore_lb=255):
        '''
        args: p: exponent applied to probs and one-hot labels in the denominator
        args: smooth: additive smoothing term avoiding division by zero
        args: reduction: 'mean' averages over the batch; any other value
              returns the per-sample loss tensor of shape (N,)
        args: weight: optional per-class weights, sequence of length C
        args: ignore_lb: label value excluded from the loss
        '''
        super(GeneralizedSoftDiceLoss, self).__init__()
        self.p = p
        self.smooth = smooth
        self.reduction = reduction
        self.weight = None if weight is None else torch.tensor(weight)
        self.ignore_lb = ignore_lb

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C, H, W)
        args: label: tensor of shape (N, H, W), integer class ids
        '''
        # overcome ignored label
        logits = logits.float()
        # BUGFIX: build the ignore mask on label's own device; the original
        # moved it to cpu, which breaks boolean/advanced indexing of CUDA
        # tensors on current PyTorch versions. Also drops deprecated .data.
        ignore = label == self.ignore_lb
        label = label.clone()
        label[ignore] = 0
        lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
        # zero the one-hot vectors (all C classes) at every ignored position
        idx = ignore.nonzero()
        _, M = idx.size()
        a, *b = idx.chunk(M, dim=1)
        # (K,1) batch indices broadcast against the (C,) class range, so each
        # ignored pixel is cleared across the whole class dimension.
        all_cls = torch.arange(lb_one_hot.size(1), device=lb_one_hot.device).long()
        lb_one_hot[[a, all_cls, *b]] = 0
        lb_one_hot = lb_one_hot.detach()

        # compute loss
        probs = torch.sigmoid(logits)
        numer = torch.sum(probs * lb_one_hot, dim=(2, 3))
        denom = torch.sum(probs.pow(self.p) + lb_one_hot.pow(self.p), dim=(2, 3))
        if self.weight is not None:
            # BUGFIX: match the weight's device/dtype to the computed sums
            weight = self.weight.to(numer).view(1, -1)
            numer = numer * weight
            denom = denom * weight
        numer = torch.sum(numer, dim=1)
        denom = torch.sum(denom, dim=1)
        loss = 1 - (2 * numer + self.smooth) / (denom + self.smooth)

        if self.reduction == 'mean':
            loss = loss.mean()
        return loss
-
-
class BatchSoftDiceLoss(nn.Module):
    '''
    Soft dice loss with numerator/denominator summed over the WHOLE batch
    (and all classes) before the ratio is formed, yielding a single scalar.
    Probabilities come from a per-channel sigmoid, as in the original code.
    '''

    def __init__(self,
                 p=1,
                 smooth=1,
                 weight=None,
                 ignore_lb=255):
        '''
        args: p: exponent applied to probs and one-hot labels in the denominator
        args: smooth: additive smoothing term avoiding division by zero
        args: weight: optional per-class weights, sequence of length C
        args: ignore_lb: label value excluded from the loss
        '''
        super(BatchSoftDiceLoss, self).__init__()
        self.p = p
        self.smooth = smooth
        self.weight = None if weight is None else torch.tensor(weight)
        self.ignore_lb = ignore_lb

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C, H, W)
        args: label: tensor of shape (N, H, W), integer class ids
        '''
        # overcome ignored label
        logits = logits.float()
        # BUGFIX: build the ignore mask on label's own device; the original
        # moved it to cpu, which breaks boolean/advanced indexing of CUDA
        # tensors on current PyTorch versions. Also drops deprecated .data.
        ignore = label == self.ignore_lb
        label = label.clone()
        label[ignore] = 0
        lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
        # zero the one-hot vectors (all C classes) at every ignored position
        idx = ignore.nonzero()
        _, M = idx.size()
        a, *b = idx.chunk(M, dim=1)
        all_cls = torch.arange(lb_one_hot.size(1), device=lb_one_hot.device).long()
        lb_one_hot[[a, all_cls, *b]] = 0
        lb_one_hot = lb_one_hot.detach()

        # compute loss
        probs = torch.sigmoid(logits)
        numer = torch.sum(probs * lb_one_hot, dim=(2, 3))
        denom = torch.sum(probs.pow(self.p) + lb_one_hot.pow(self.p), dim=(2, 3))
        if self.weight is not None:
            # BUGFIX: match the weight's device/dtype to the computed sums
            weight = self.weight.to(numer).view(1, -1)
            numer = numer * weight
            denom = denom * weight
        # batch variant: collapse over samples AND classes to one scalar
        numer = torch.sum(numer)
        denom = torch.sum(denom)
        loss = 1 - (2 * numer + self.smooth) / (denom + self.smooth)
        return loss
-
-
if __name__ == '__main__':
    # Smoke test: train two identical nets with the two loss variants and
    # watch how far their weights drift apart.
    import torchvision
    import torch
    import numpy as np
    import random
    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        '''ResNet-18 backbone with a 2-class segmentation head.'''
        def __init__(self):
            super(Model, self).__init__()
            net = torchvision.models.resnet18(pretrained=False)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            # BUGFIX: 2 output channels so logits are (N, C=2, H, W) as the
            # dice losses above require (labels are drawn from {0, 1}).
            self.out = nn.Conv2d(512, 2, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.out(feat)
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())

    # BUGFIX: the original referenced undefined names SoftDiceLossV1 /
    # SoftDiceLossV3 (a NameError); use the classes defined in this file.
    criteria1 = GeneralizedSoftDiceLoss()
    criteria2 = BatchSoftDiceLoss()
    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 2
    for it in range(300000):
        # NOTE(review): 224 x 244 is asymmetric — presumably a typo for
        # 224 x 224, but it is internally consistent, so it is kept.
        inten = torch.randn(bs, 3, 224, 244).cuda()
        lbs = torch.randint(0, 2, (bs, 224, 244)).cuda()
        # BUGFIX: no squeeze — the losses consume (N, C, H, W) logits.
        logits = net1(inten)
        loss1 = criteria1(logits, lbs)
        optim1.zero_grad()
        loss1.backward()
        optim1.step()
        logits = net2(inten)
        loss2 = criteria2(logits, lbs)
        optim2.zero_grad()
        loss2.backward()
        optim2.step()

        with torch.no_grad():
            if (it + 1) % 50 == 0:
                print('iter: {}, ================='.format(it + 1))
                print('out.weight: ', torch.mean(torch.abs(net1.out.weight - net2.out.weight)).item())
                print('conv1.weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())
|