#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp


##
# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version; you can also try LabelSmoothSoftmaxCEV2,
    which uses a hand-derived backward pass.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        Same usage as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV1()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        # mask out ignored labels and build the smoothed target distribution
        logits = logits.float() # use fp32 to avoid nan
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label.eq(self.lb_ignore)
            n_valid = ignore.eq(0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
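            # Illustrative note (not in the original code): with lb_smooth=0.1
            # and num_classes=19, lb_pos=0.9 and lb_neg=0.1/19~0.0053. Because
            # scatter_ overwrites (rather than adds to) the true-class slot,
            # each target row sums to lb_pos + (num_classes - 1) * lb_neg
            # = 1 - lb_smooth / num_classes, i.e. slightly less than 1.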

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * lb_one_hot, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss


##
# version 2: use hand-derived gradient computation
class LSRCrossEntropyFunctionV2(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, label, lb_smooth, lb_ignore):
        # prepare the smoothed label distribution
        num_classes = logits.size(1)
        lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
        label = label.clone().detach()
        ignore = label.eq(lb_ignore)
        n_valid = ignore.eq(0).sum()
        label[ignore] = 0
        lb_one_hot = torch.empty_like(logits).fill_(
            lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        # advanced-indexing mask that selects, for each ignored position, the
        # whole class dimension; zero those target rows so they add no loss
        ignore = ignore.nonzero(as_tuple=False)
        _, M = ignore.size()
        a, *b = ignore.chunk(M, dim=1)
        mask = (a, torch.arange(logits.size(1), device=logits.device), *b)
        lb_one_hot[mask] = 0
        coeff = (num_classes - 1) * lb_neg + lb_pos

        ctx.variables = coeff, mask, logits, lb_one_hot

        loss = torch.log_softmax(logits, dim=1).neg_().mul_(lb_one_hot).sum(dim=1)
        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        coeff, mask, logits, lb_one_hot = ctx.variables
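        # Derivation note (added for clarity): with target row y and
        # p = softmax(logits), the per-position loss is -sum_j y_j * log p_j,
        # so d(loss)/d(logit_i) = (sum_j y_j) * p_i - y_i. For valid positions
        # sum_j y_j = lb_pos + (num_classes - 1) * lb_neg = coeff; for ignored
        # positions the true sum is 0, which grad[mask] = 0 below restores.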

        scores = torch.softmax(logits, dim=1).mul_(coeff)
        grad = scores.sub_(lb_one_hot).mul_(grad_output.unsqueeze(1))
        grad[mask] = 0
        return grad, None, None, None


class LabelSmoothSoftmaxCEV2(nn.Module):

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV2, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, labels):
        '''
        Same usage as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV2()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        losses = LSRCrossEntropyFunctionV2.apply(
            logits, labels, self.lb_smooth, self.lb_ignore)
        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            n_valid = (labels != self.lb_ignore).sum()
            losses = losses.sum() / n_valid
        return losses
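

# An illustrative sanity check (the helper name _check_v1_v2_agree is not part
# of the original API): V1 and V2 implement the same loss and only differ in
# how the backward pass is obtained, so their forward values and input
# gradients should agree up to float32 round-off.
def _check_v1_v2_agree():
    torch.manual_seed(0)
    logits = torch.randn(2, 19, 8, 8)
    labels = torch.randint(0, 19, (2, 8, 8))
    labels[0, 0, 0] = 255  # exercise the ignore_index path
    logits1 = logits.clone().requires_grad_(True)
    logits2 = logits.clone().requires_grad_(True)
    crit1 = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, ignore_index=255)
    crit2 = LabelSmoothSoftmaxCEV2(lb_smooth=0.1, ignore_index=255)
    loss1, loss2 = crit1(logits1, labels), crit2(logits2, labels)
    assert torch.allclose(loss1, loss2, atol=1e-5)
    loss1.backward()
    loss2.backward()
    assert torch.allclose(logits1.grad, logits2.grad, atol=1e-5)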


##
# version 3: implement with cpp/cuda to save memory and accelerate
import lsr_cpp  # requires the compiled lsr_cpp extension to be available
class LSRCrossEntropyFunctionV3(torch.autograd.Function):
    '''
    use cpp/cuda to accelerate and shrink memory usage
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, labels, lb_smooth, lb_ignore):
        losses = lsr_cpp.lsr_forward(logits, labels, lb_ignore, lb_smooth)

        ctx.variables = logits, labels, lb_ignore, lb_smooth
        return losses

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        logits, labels, lb_ignore, lb_smooth = ctx.variables

        grad = lsr_cpp.lsr_backward(logits, labels, lb_ignore, lb_smooth)
        grad.mul_(grad_output.unsqueeze(1))
        return grad, None, None, None


class LabelSmoothSoftmaxCEV3(nn.Module):

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV3, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, labels):
        '''
        Same usage as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV3()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        losses = LSRCrossEntropyFunctionV3.apply(
            logits, labels, self.lb_smooth, self.lb_ignore)
        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            n_valid = (labels != self.lb_ignore).sum()
            losses = losses.sum() / n_valid
        return losses


if __name__ == '__main__':
    import torchvision
    import torch
    import numpy as np
    import random
    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            net = torchvision.models.resnet18(pretrained=False)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            self.fc = nn.Conv2d(512, 19, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.fc(feat)
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())
    red = 'mean'
    # compare the hand-derived-backward version (V2) against the autograd version (V1)
    criteria1 = LabelSmoothSoftmaxCEV2(lb_smooth=0.1, ignore_index=255, reduction=red)
    criteria2 = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, ignore_index=255, reduction=red)
    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 64
    for it in range(300):
        inten = torch.randn(bs, 3, 224, 224).cuda()
        lbs = torch.randint(0, 19, (bs, 224, 224)).cuda()
        # scatter some ignored labels into the batch
        lbs[1, 1, 1] = 255
        lbs[30, 3, 2:200] = 255
        lbs[18, 4:7, 8:200] = 255
        logits = net1(inten)
        loss1 = criteria1(logits, lbs)
        optim1.zero_grad()
        loss1.backward()
        optim1.step()
        logits = net2(inten)
        loss2 = criteria2(logits, lbs)
        optim2.zero_grad()
        loss2.backward()
        optim2.step()
        with torch.no_grad():
            if (it+1) % 50 == 0:
                print('iter: {}, ================='.format(it+1))
                print('fc weight: ', torch.mean(torch.abs(net1.fc.weight - net2.fc.weight)).item())
                print('conv1 weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())
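
    # If the two implementations agree, the printed weight differences and the
    # loss gap should stay at float32 round-off level throughout training.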