import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp


##
# version 1: use torch.autograd
class FocalLossV1(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        '''
        Usage is same as nn.BCEWithLogitsLoss:
            >>> criteria = FocalLossV1()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        probs = torch.sigmoid(logits)
        coeff = torch.abs(label - probs).pow(self.gamma).neg()
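        # numerically stable log-probabilities, using the identities
        #   log(sigmoid(x))     = -softplus(-x) = x - softplus(x)
        #   log(1 - sigmoid(x)) = -x - softplus(-x) = -softplus(x)
        # F.softplus(x, -1, 50) evaluates -softplus(-x); torch.where picks, for
        # each sign of x, the form whose exponential cannot overflow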
        log_probs = torch.where(logits >= 0,
                                F.softplus(logits, -1, 50),
                                logits - F.softplus(logits, 1, 50))
        log_1_probs = torch.where(logits >= 0,
                                  -logits + F.softplus(logits, -1, 50),
                                  -F.softplus(logits, 1, 50))
        loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs
        loss = loss * coeff

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
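

# For hard labels y in {0, 1}, the formulation above matches the textbook focal
# loss FL = -alpha_t * (1 - p_t)^gamma * log(p_t), since |y - p| = 1 - p_t.
# A minimal sanity-check sketch (the helper below is illustrative, not part of
# the original API), assuming logits moderate enough that the naive formula
# does not underflow:
def _check_focal_v1_against_naive(alpha=0.25, gamma=2):
    logits = torch.randn(8, 19)
    label = torch.randint(0, 2, (8, 19)).float()
    probs = torch.sigmoid(logits)
    pt = probs * label + (1. - probs) * (1. - label)  # p_t
    at = alpha * label + (1. - alpha) * (1. - label)  # alpha_t
    naive = -at * (1. - pt).pow(gamma) * torch.log(pt)
    stable = FocalLossV1(alpha, gamma, reduction='none')(logits, label)
    assert torch.allclose(naive, stable, atol=1e-6)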


##
# version 2: use a hand-derived gradient computation
class FocalSigmoidLossFuncV2(torch.autograd.Function):
    '''
    compute backward directly for better numeric stability
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, label, alpha, gamma):
        probs = torch.sigmoid(logits)
        coeff = (label - probs).abs_().pow_(gamma).neg_()
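        # same numerically stable log-probabilities as in FocalLossV1.forward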
        log_probs = torch.where(logits >= 0,
                                F.softplus(logits, -1, 50),
                                logits - F.softplus(logits, 1, 50))
        log_1_probs = torch.where(logits >= 0,
                                  -logits + F.softplus(logits, -1, 50),
                                  -F.softplus(logits, 1, 50))
        ce_term1 = log_probs.mul_(label).mul_(alpha)
        ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)
        ce = ce_term1.add_(ce_term2)
        loss = ce * coeff

        ctx.save_for_backward(coeff, probs, ce, label)
        ctx.gamma = gamma
        ctx.alpha = alpha

        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss
        '''
        coeff, probs, ce, label = ctx.saved_tensors
        gamma, alpha = ctx.gamma, ctx.alpha
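
        # loss = coeff * ce, so by the product rule
        #   d(loss)/d(logits) = d(coeff)/d(logits) * ce + coeff * d(ce)/d(logits)
        # with (y = label, p = probs)
        #   d(coeff)/d(logits) = -gamma * |y - p|**(gamma - 1) * sign(p - y) * p * (1 - p)
        #   d(ce)/d(logits)    = alpha * y - p * (2 * alpha * y + 1 - y - alpha)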
        d_coeff = (label - probs).abs_().pow_(gamma - 1.).mul_(gamma)
        d_coeff.mul_(probs).mul_(1. - probs)
        d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)
        term1 = d_coeff.mul_(ce)

        d_ce = label * alpha
        d_ce.sub_((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha).mul_(probs))
        term2 = d_ce.mul(coeff)

        grads = term1.add_(term2)
        grads.mul_(grad_output)

        return grads, None, None, None
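

# A minimal gradcheck sketch (an assumed helper, not part of the original file)
# verifying the hand-derived backward against autograd's numerical jacobian;
# gradcheck requires float64 inputs:
def _gradcheck_focal_v2(alpha=0.25, gamma=2.):
    logits = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
    label = torch.randint(0, 2, (4, 5)).double()
    assert torch.autograd.gradcheck(
        lambda x: FocalSigmoidLossFuncV2.apply(x, label, alpha, gamma), logits)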


class FocalLossV2(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super(FocalLossV2, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        '''
        Usage is same as nn.BCEWithLogitsLoss:
            >>> criteria = FocalLossV2()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
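

# V1 and V2 compute the same loss values and should differ only in how the
# backward pass is obtained; a small parity sketch (illustrative, not part of
# the original API):
def _compare_v1_v2():
    logits = torch.randn(2, 3, 64, 64)
    label = torch.randint(0, 2, (2, 3, 64, 64)).float()
    l1 = FocalLossV1()(logits, label)
    l2 = FocalLossV2()(logits, label)
    assert torch.allclose(l1, l2, atol=1e-6)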


##
# version 3: implement with cpp/cuda to save memory and accelerate
import focal_cpp  # compiled cpp/cuda extension; torch must be imported before it


class FocalSigmoidLossFuncV3(torch.autograd.Function):
    '''
    use cpp/cuda to accelerate and shrink memory usage
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, labels, alpha, gamma):
        loss = focal_cpp.focalloss_forward(logits, labels, gamma, alpha)
        ctx.save_for_backward(logits, labels)
        ctx.gamma = gamma
        ctx.alpha = alpha
        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss
        '''
        logits, labels = ctx.saved_tensors
        gamma, alpha = ctx.gamma, ctx.alpha
        grads = focal_cpp.focalloss_backward(grad_output, logits, labels, gamma, alpha)
        return grads, None, None, None


class FocalLossV3(nn.Module):
    '''
    This version uses the same hand-derived gradient formula as V2, which has
    better numeric stability, and implements it in cpp/cuda to shrink memory
    usage and accelerate computation.
    '''
    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super(FocalLossV3, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        '''
        Usage is same as nn.BCEWithLogitsLoss:
            >>> criteria = FocalLossV3()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        loss = FocalSigmoidLossFuncV3.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss


if __name__ == '__main__':
    import torchvision
    import numpy as np
    import random
    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            net = torchvision.models.resnet18(pretrained=False)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            self.out = nn.Conv2d(512, 3, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.out(feat)
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())

    criteria1 = FocalLossV2()
    criteria2 = FocalLossV3()
    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    net1.double()
    net2.double()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 16
    for it in range(300000):
        inten = torch.randn(bs, 3, 224, 224).cuda()
        # hard labels: lbs = torch.randint(0, 2, (bs, 3, 224, 224)).float().cuda()
        lbs = torch.randn(bs, 3, 224, 224).sigmoid().cuda()
        inten = inten.double()
        lbs = lbs.double()

        logits = net1(inten)
        loss1 = criteria1(logits, lbs)
        optim1.zero_grad()
        loss1.backward()
        optim1.step()

        logits = net2(inten)
        loss2 = criteria2(logits, lbs)
        optim2.zero_grad()
        loss2.backward()
        optim2.step()

        # V2 and V3 should stay numerically in step as training proceeds
        with torch.no_grad():
            if (it + 1) % 50 == 0:
                print('iter: {}, ================='.format(it + 1))
                print('out.weight: ', torch.mean(torch.abs(net1.out.weight - net2.out.weight)).item())
                print('conv1.weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())