|
- from torch import nn
- import torch
- import torch.nn.functional as F
- from torch.autograd import Variable
-
def cross_entropy_3D(input, target, weight=None, size_average=True):
    """Cross-entropy loss for 5-D (volumetric) logits.

    Args:
        input: Logits of shape [N, C, H, W, S]; class scores live on dim 1.
        target: Integer class indices with N * H * W * S elements in total.
            It is flattened, so any shape with that element count works.
        weight: Optional per-class rescaling weights, forwarded to
            ``F.nll_loss``.
        size_average: If True, divide the summed loss by the number of
            target elements; otherwise return the plain sum.

    Returns:
        A scalar tensor holding the loss.
    """
    # Unpacking five sizes also rejects inputs that are not 5-D.
    _, c, _, _, _ = input.size()
    log_p = F.log_softmax(input, dim=1)
    # Move the class axis last, then flatten every voxel into a row: [N*H*W*S, C].
    log_p = log_p.permute(0, 2, 3, 4, 1).reshape(-1, c)
    # reshape (unlike view) also accepts non-contiguous targets.
    target = target.reshape(-1)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        # Divide by the raw element count (not by the summed class weights).
        loss = loss / float(target.numel())
    return loss
-
-
class Binary_Loss(nn.Module):
    """Binary cross-entropy computed on raw logits.

    Thin wrapper around ``nn.BCEWithLogitsLoss`` with its default settings.
    """

    def __init__(self):
        super(Binary_Loss, self).__init__()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, model_output, targets):
        """Return the BCE-with-logits loss.

        Args:
            model_output: Raw (pre-sigmoid) logits.
            targets: Binary labels with the same shape as ``model_output``.
                Integer or boolean labels are accepted.

        Returns:
            The loss tensor produced by ``nn.BCEWithLogitsLoss``.
        """
        # BCEWithLogitsLoss needs targets in the same floating dtype as the
        # logits; casting here lets callers pass integer/bool label masks.
        targets = targets.to(dtype=model_output.dtype)
        return self.criterion(model_output, targets)
-
-
-
-
-
def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: A tensor of class indices of shape [N, 1, *].
        num_classes: An int, the number of classes.

    Returns:
        A tensor of shape [N, num_classes, *] created by ``torch.zeros``
        (so it uses torch's default dtype) with ones scattered along dim 1
        at the positions given by ``input``. The indices are moved to the
        CPU before scattering.
    """
    # Build the output shape without numpy, which this module does not import.
    shape = list(input.shape)
    shape[1] = num_classes
    result = torch.zeros(shape)
    # scatter_ requires an int64 index tensor.
    index = input.cpu().long()
    return result.scatter_(1, index, 1)
-
class BinaryDiceLoss(nn.Module):
    r"""Dice loss for a single (binary) class.

    For each sample, with prediction x and target y flattened to vectors::

        loss = 1 - (2 * sum(x * y) + smooth) / (sum(x^p) + sum(y^p) + smooth)

    The factor 2 in the numerator is what makes a perfect prediction of a
    binary mask score a loss of exactly 0.

    Args:
        smooth: A float added to numerator and denominator to smooth the
            loss and avoid division by zero, default: 1
        p: Exponent used in the denominator, \sum{x^p} + \sum{y^p},
            default: 2
        reduction: Reduction method to apply: 'mean' returns the mean over
            the batch, 'sum' returns the sum, 'none' returns a tensor of
            shape [N,]

    Forward args:
        predict: A tensor of shape [N, *]
        target: A tensor with the same shape as predict

    Returns:
        Loss tensor according to arg reduction

    Raises:
        Exception if reduction is not one of 'mean', 'sum', 'none'
    """

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten everything except the batch axis: [N, -1].
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        # Dice coefficient = 2|X n Y| / (|X| + |Y|); the 2 was previously
        # missing, which capped the best achievable loss near 0.5.
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise Exception('Unexpected reduction {}'.format(self.reduction))
-
class DiceLoss(nn.Module):
    """Multi-class Dice loss averaged over classes.

    Args:
        n_classes: Number of classes; class indices run from 0 to
            n_classes - 1.
    """

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """One-hot encode class indices by concatenating per-class masks on dim 1.

        For the result to have n_classes channels, ``input_tensor`` is
        expected to have a singleton dim 1, i.e. shape [N, 1, *].
        """
        masks = [input_tensor == i for i in range(self.n_classes)]
        return torch.cat(masks, dim=1).float()

    def _dice_loss(self, score, target):
        """Return 1 - Dice coefficient between ``score`` and ``target``.

        Sums run over every element (batch included), so the result is a
        scalar tensor. A small constant keeps the ratio defined when both
        inputs are all zeros.
        """
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted Dice loss averaged over classes.

        Args:
            inputs: Per-class scores with classes on dim 1. Pass raw logits
                together with ``softmax=True``, or probabilities otherwise.
            target: Class-index tensor; after one-hot encoding it must have
                exactly the same shape as ``inputs``.
            weight: Optional sequence of per-class weights; defaults to 1
                for every class.
            softmax: If True, apply softmax over dim 1 to ``inputs`` first.

        Returns:
            Sum of the weighted per-class Dice losses divided by n_classes.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        loss = 0.0
        for i in range(self.n_classes):
            # No per-class .item() bookkeeping: it was never used and forced
            # a host synchronisation on every class.
            loss += self._dice_loss(inputs[:, i], target[:, i]) * weight[i]
        return loss / self.n_classes
|