|
import os
import shutil
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional
from termcolor import colored
from torch import nn
-
-
def sigmoid(x):
    """Sigmoid clamped to [1e-4, 1 - 1e-4] so downstream log() terms stay finite."""
    return torch.clamp(torch.sigmoid(x), min=1e-4, max=1 - 1e-4)
-
-
- def _neg_loss(pred, gt):
- ''' Modified focal loss. Exactly the same as CornerNet.
- Runs faster and costs a little bit more memory
- Arguments:
- pred (batch x c x h x w)
- gt_regr (batch x c x h x w)
- '''
- pos_inds = gt.eq(1).float()
- neg_inds = gt.lt(1).float()
-
- neg_weights = torch.pow(1 - gt, 4)
-
- loss = 0
-
- pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
- neg_loss = torch.log(1 - pred) * torch.pow(pred,
- 2) * neg_weights * neg_inds
-
- num_pos = pos_inds.float().sum()
- pos_loss = pos_loss.sum()
- neg_loss = neg_loss.sum()
-
- if num_pos == 0:
- loss = loss - neg_loss
- else:
- loss = loss - (pos_loss + neg_loss) / num_pos
- return loss
-
-
class FocalLoss(nn.Module):
    """nn.Module wrapper around the CornerNet-style modified focal loss."""

    def __init__(self):
        super(FocalLoss, self).__init__()
        # kept as an attribute so the underlying function can be swapped out
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Compute the focal loss between heatmap `out` and ground truth `target`."""
        return self.neg_loss(out, target)
-
-
def smooth_l1_loss(vertex_pred,
                   vertex_targets,
                   vertex_weights,
                   sigma=1.0,
                   normalize=True,
                   reduce=True):
    """Weighted smooth-L1 loss between predicted and target vertex maps.

    :param vertex_pred: [b, vn*2, h, w]
    :param vertex_targets: [b, vn*2, h, w]
    :param vertex_weights: [b, 1, h, w]
    :param sigma: transition point between the quadratic and linear regimes
    :param normalize: divide each sample's loss by its total weight
    :param reduce: average over the batch
    :return: scalar tensor if reduce, else per-sample tensor (or raw map)
    """
    batch, ver_dim = vertex_pred.shape[:2]
    sigma_sq = sigma ** 2

    diff = vertex_weights * (vertex_pred - vertex_targets)
    abs_diff = diff.abs()
    # 1 where |diff| is inside the quadratic region, 0 in the linear region
    quadratic = (abs_diff < 1.0 / sigma_sq).detach().float()
    loss = (diff ** 2) * (sigma_sq / 2.0) * quadratic \
        + (abs_diff - 0.5 / sigma_sq) * (1.0 - quadratic)

    if normalize:
        loss = loss.view(batch, -1).sum(1) / (
            ver_dim * vertex_weights.view(batch, -1).sum(1) + 1e-3)

    if reduce:
        loss = loss.mean()

    return loss
-
-
class SmoothL1Loss(nn.Module):
    """nn.Module wrapper around the weighted `smooth_l1_loss` function."""

    def __init__(self):
        super(SmoothL1Loss, self).__init__()
        # stored so the implementation can be swapped for testing
        self.smooth_l1_loss = smooth_l1_loss

    def forward(self,
                preds,
                targets,
                weights,
                sigma=1.0,
                normalize=True,
                reduce=True):
        """Delegate to the stored smooth-L1 implementation."""
        return self.smooth_l1_loss(preds, targets, weights, sigma,
                                   normalize, reduce)
-
-
class AELoss(nn.Module):
    """Associative embedding loss: pulls the tags of parts belonging to the
    same object toward that object's mean tag, and pushes the mean tags of
    different objects apart (hinge with margin 1)."""

    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """
        ae: [b, 1, h, w]
        ind: [b, max_objs, max_parts]
        ind_mask: [b, max_objs, max_parts]
        obj_mask: [b, max_objs]

        Returns (pull, push) scalar losses.
        NOTE(review): `ind` is gathered into a flattened h*w map, so it is
        assumed to hold integer (long) spatial indices — confirm with callers.
        """
        # first index
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape
        # an object counts as present if any of its parts is valid
        obj_mask = torch.sum(ind_mask, dim=2) != 0

        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        # tag value at every (object, part) location
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)

        # compute the mean
        tag_mean = tag * ind_mask
        tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)

        # pull ae of the same object to their mean
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
        pull /= b

        # push away the mean of different objects
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        # hinge: pairs further than 1 apart contribute zero
        push_dist = 1 - push_dist
        push_dist = nn.functional.relu(push_dist, inplace=True)
        # keep only pairs where both objects are present
        obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
        push_dist = push_dist * obj_mask.float()
        # subtract obj_num to remove the diagonal (each mean vs itself is 1),
        # then normalize by the number of ordered pairs
        push = ((push_dist.sum(dim=(1, 2)) - obj_num) /
                (obj_num * (obj_num - 1) + 1e-4)).sum()
        push /= b
        return pull, push
-
-
class PolyMatchingLoss(nn.Module):
    """Rotation-invariant polygon matching loss.

    Compares a predicted polygon against every cyclic rotation of the ground
    truth and takes the minimum distance, so the loss does not depend on which
    vertex the ground-truth contour starts at.
    """

    def __init__(self, pnum):
        """pnum: number of polygon vertices."""
        super(PolyMatchingLoss, self).__init__()

        self.pnum = pnum
        batch_size = 1
        # pidxall[b, i] = indices of the ground-truth vertices rotated by i
        pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
        for b in range(batch_size):
            for i in range(pnum):
                pidxall[b, i] = (np.arange(pnum) + i) % pnum

        # BUGFIX: was created directly on `torch.device('cuda')`, which made
        # construction crash on CPU-only machines. Build on CPU and move to
        # the input's device in forward() instead.
        pidxall = torch.from_numpy(
            np.reshape(pidxall, newshape=(batch_size, -1)))

        # expand to gather both x and y coordinates per vertex index
        self.feature_id = pidxall.unsqueeze(2).long().expand(
            pidxall.size(0), pidxall.size(1), 2).detach()

    def forward(self, pred, gt, loss_type="L2"):
        """
        pred, gt: [b, pnum, 2] polygon vertex coordinates.
        loss_type: "L2" (per-vertex Euclidean) or "L1" (per-coordinate abs).
        Returns the mean over the batch of the best-rotation distance.
        """
        pnum = self.pnum
        batch_size = pred.size(0)
        feature_id = self.feature_id.expand(
            batch_size, self.feature_id.size(1), 2).to(gt.device)

        # gt_expand[b, i] is the ground truth rotated by i vertices
        gt_expand = torch.gather(gt, 1,
                                 feature_id).view(batch_size, pnum, pnum, 2)

        pred_expand = pred.unsqueeze(1)

        dis = pred_expand - gt_expand

        if loss_type == "L2":
            dis = (dis ** 2).sum(3).sqrt().sum(2)
        elif loss_type == "L1":
            dis = torch.abs(dis).sum(3).sum(2)

        # best alignment over all rotations
        min_dis, _ = torch.min(dis, dim=1, keepdim=True)

        return torch.mean(min_dis)
-
-
class AttentionLoss(nn.Module):
    """Class-balanced, beta-weighted binary cross entropy for edge maps."""

    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()

        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        """pred, gt: same-shape tensors; gt in {0, 1}, pred in (0, 1)."""
        num_pos = gt.sum()
        num_neg = (1 - gt).sum()
        # weight each class by the prevalence of the other class
        alpha = num_neg / (num_pos + num_neg)

        # confidence-dependent modulation: larger for less confident predictions
        edge_weight = self.beta ** ((1 - pred) ** self.gamma)
        bg_weight = self.beta ** (pred ** self.gamma)

        pos_term = -alpha * edge_weight * torch.log(pred) * gt
        neg_term = -(1 - alpha) * bg_weight * torch.log(1 - pred) * (1 - gt)
        return (pos_term + neg_term).mean()
-
-
- def _gather_feat(feat, ind, mask=None):
- dim = feat.size(2)
- ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
- feat = feat.gather(1, ind)
- if mask is not None:
- mask = mask.unsqueeze(2).expand_as(feat)
- feat = feat[mask]
- feat = feat.view(-1, dim)
- return feat
-
-
def _tranpose_and_gather_feat(feat, ind):
    """Move channels last, flatten the spatial dims, then gather at `ind`.

    feat: [b, c, h, w]; ind: [b, k] with indices into h*w. Returns [b, k, c].
    """
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    flat = channels_last.view(channels_last.size(0), -1, channels_last.size(3))
    return _gather_feat(flat, ind)
-
-
class Ind2dRegL1Loss(nn.Module):
    """Masked L1 / smooth-L1 regression loss on features gathered at
    per-object, per-part indices."""

    def __init__(self, type='l1'):
        super(Ind2dRegL1Loss, self).__init__()
        if type == 'l1':
            self.loss = torch.nn.functional.l1_loss
        elif type == 'smooth_l1':
            self.loss = torch.nn.functional.smooth_l1_loss

    def forward(self, output, target, ind, ind_mask):
        """ind: [b, max_objs, max_parts]"""
        b, max_objs, max_parts = ind.shape
        flat_ind = ind.view(b, max_objs * max_parts)
        pred = _tranpose_and_gather_feat(output, flat_ind)
        pred = pred.view(b, max_objs, max_parts, output.size(1))
        # zero out invalid parts on both sides before summing
        mask = ind_mask.unsqueeze(3).expand_as(pred)
        total = self.loss(pred * mask, target * mask, reduction='sum')
        return total / (mask.sum() + 1e-4)
-
-
class IndL1Loss1d(nn.Module):
    """Weighted L1 / smooth-L1 loss on features gathered at per-sample indices."""

    def __init__(self, type='l1'):
        super(IndL1Loss1d, self).__init__()
        if type == 'l1':
            self.loss = torch.nn.functional.l1_loss
        elif type == 'smooth_l1':
            self.loss = torch.nn.functional.smooth_l1_loss

    def forward(self, output, target, ind, weight):
        """ind: [b, n]"""
        gathered = _tranpose_and_gather_feat(output, ind)
        w = weight.unsqueeze(2)
        # weight both sides so zero-weight points contribute nothing
        total = self.loss(gathered * w, target * w, reduction='sum')
        return total / (w.sum() * gathered.size(2) + 1e-4)
-
-
class GeoCrossEntropyLoss(nn.Module):
    """Cross entropy over polygon-vertex proposals, softened by a
    geometry-aware Gaussian kernel centered on the target vertex, so
    vertices spatially close to the target are penalized less."""

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        """
        output: logits, softmaxed over dim 1
        target: index of the ground-truth vertex within each group
        poly: vertex coordinates, reshaped to [b, 4, poly.size(1)//4, 2]
          -- assumes poly.size(1) is divisible by 4; TODO confirm with callers
        """
        output = torch.nn.functional.softmax(output, dim=1)
        # clamp before log to avoid -inf on zero probabilities
        output = torch.log(torch.clamp(output, min=1e-4))
        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        target = target[..., None, None].expand(poly.size(0), poly.size(1), 1,
                                                poly.size(3))
        # coordinates of the target vertex in each group
        target_poly = torch.gather(poly, 2, target)
        # kernel bandwidth from the squared distance between the first two vertices
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        # Gaussian weight: vertices near the target still earn partial credit
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss
-
-
def load_model(net,
               optim,
               scheduler,
               recorder,
               model_dir,
               resume=True,
               epoch=-1):
    """Load a training checkpoint from `model_dir` and restore all state.

    Picks 'latest.pth' if present (or the highest-numbered epoch file) when
    `epoch` is -1, otherwise loads '{epoch}.pth'. Restores net, optimizer,
    scheduler and recorder in place.

    Returns the epoch to resume from (saved epoch + 1), or 0 when no
    checkpoint exists. When `resume` is False the checkpoint directory is
    wiped first.
    """
    if not resume:
        # shutil.rmtree instead of `os.system('rm -rf ...')`: no shell
        # involved, portable, and not vulnerable to metacharacters in the path
        shutil.rmtree(model_dir, ignore_errors=True)

    if not os.path.exists(model_dir):
        return 0

    pths = [
        int(pth.split('.')[0]) for pth in os.listdir(model_dir)
        if pth != 'latest.pth'
    ]
    if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
        return 0
    if epoch == -1:
        if 'latest.pth' in os.listdir(model_dir):
            pth = 'latest'
        else:
            pth = max(pths)
    else:
        pth = epoch
    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    # map to CPU so GPU-saved checkpoints load on CPU-only machines
    pretrained_model = torch.load(model_path, 'cpu')
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1
-
-
def save_model(net, optim, scheduler, recorder, model_dir, epoch, last=False):
    """Save a training checkpoint to `model_dir`.

    Writes 'latest.pth' when `last` is True, otherwise '{epoch}.pth', and
    prunes the oldest numbered checkpoint once more than 20 are kept.
    """
    # os.makedirs instead of `os.system('mkdir -p ...')`: no shell, portable
    os.makedirs(model_dir, exist_ok=True)
    model = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }
    if last:
        torch.save(model, os.path.join(model_dir, 'latest.pth'))
    else:
        torch.save(model, os.path.join(model_dir, '{}.pth'.format(epoch)))

    # remove previous pretrained model if the number of models is too big
    pths = [
        int(pth.split('.')[0]) for pth in os.listdir(model_dir)
        if pth != 'latest.pth'
    ]
    if len(pths) <= 20:
        return
    # os.remove instead of shelling out to `rm`
    os.remove(os.path.join(model_dir, '{}.pth'.format(min(pths))))
-
-
def load_network(net, model_dir, resume=True, epoch=-1, strict=True):
    """Load only network weights from a checkpoint file or directory.

    When `model_dir` is a directory, picks 'latest.pth' (or the highest
    numbered epoch file) for epoch == -1, otherwise '{epoch}.pth'. Returns
    the saved epoch + 1, or 0 when resuming is disabled or nothing exists.
    """
    if not resume:
        return 0

    if not os.path.exists(model_dir):
        print(colored('pretrained model does not exist', 'red'))
        return 0

    if os.path.isdir(model_dir):
        entries = os.listdir(model_dir)
        pths = [
            int(name.split('.')[0]) for name in entries
            if name != 'latest.pth'
        ]
        if not pths and 'latest.pth' not in entries:
            return 0
        if epoch != -1:
            pth = epoch
        elif 'latest.pth' in entries:
            pth = 'latest'
        else:
            pth = max(pths)
        model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    else:
        # a direct path to a checkpoint file was given
        model_path = model_dir

    print('load model: {}'.format(model_path))
    pretrained_model = torch.load(model_path)
    net.load_state_dict(pretrained_model['net'], strict=strict)
    return pretrained_model['epoch'] + 1
-
-
def remove_net_prefix(net, prefix):
    """Return a new OrderedDict with `prefix` stripped from matching keys."""
    stripped = OrderedDict()
    plen = len(prefix)
    for key, value in net.items():
        new_key = key[plen:] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
-
-
def add_net_prefix(net, prefix):
    """Return a new OrderedDict with `prefix` prepended to every key."""
    return OrderedDict((prefix + key, value) for key, value in net.items())
-
-
def replace_net_prefix(net, orig_prefix, prefix):
    """Return a new OrderedDict with `orig_prefix` swapped for `prefix` on
    matching keys; other keys are copied unchanged."""
    replaced = OrderedDict()
    cut = len(orig_prefix)
    for key, value in net.items():
        if key.startswith(orig_prefix):
            replaced[prefix + key[cut:]] = value
        else:
            replaced[key] = value
    return replaced
-
-
def remove_net_layer(net, layers):
    """Delete (in place) every entry whose key starts with any prefix in
    `layers`; returns the same mapping.

    BUGFIX: the original deleted inside a nested loop over `layers`, so a key
    matching two overlapping prefixes (e.g. layers=['a', 'ab'] and key
    'ab.w') was deleted twice and raised KeyError. Each key is now deleted
    at most once.
    """
    for key in list(net.keys()):
        if any(key.startswith(layer) for layer in layers):
            del net[key]
    return net
|