|
- from __future__ import absolute_import
- import matplotlib.pyplot as plt
- import math
- import numpy as np
- import matplotlib.pyplot as plt
- from random import randint
- from torch import nn
- from sklearn.metrics import matthews_corrcoef,roc_curve,auc#MCC
- from sklearn.metrics import f1_score,precision_score#F1
- from sklearn.metrics import precision_recall_curve,precision_score,recall_score
- from sklearn import metrics#ROC
- from sklearn.metrics import roc_auc_score#AUC
- from torch.nn import functional as F
- import torch
- #import cv2
-
# Public API of this metrics module. cal_fmeasure and F_measure are listed so
# `from <module> import *` exposes every metric helper defined in this file.
__all__ = ['accuracy', 'AverageMeter', 'cal_fmeasure', 'F_measure']
-
def accuracy(output, target, thr=0.5):
    """Element-wise classification accuracy against a dense ground truth.

    Takes the argmax of ``output`` along dim 1 (the class dimension) as the
    prediction and returns the fraction of elements where it matches
    ``target``.

    Args:
        output: score tensor of shape (N, C, ...); classes along dim 1.
        target: ground-truth class-index tensor whose shape matches the
            argmax of ``output`` (same numel).
        thr: unused; kept only for backward compatibility with callers.

    Returns:
        A 0-dim float tensor in [0, 1].
    """
    # NOTE(review): the original branched on output.dim() > 2 but both
    # branches were identical (`torch.max(output, 1)`), so the condition was
    # dead code — collapsed to a single argmax.
    pred = torch.argmax(output, dim=1)
    return torch.sum(target.long() == pred.long()).float() / target.numel()
-
-
class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recently observed value
        self.sum = 0    # weighted sum of all observed values
        self.count = 0  # total weight observed so far
        self.avg = 0    # running average = sum / count

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
-
def cal_fmeasure(precision, recall):
    """Return the maximum F-measure over paired precision/recall values.

    Args:
        precision: sequence of precision values.
        recall: sequence of recall values, same length as ``precision``.

    Returns:
        numpy scalar: max of 2*p*r / (p + r), with a small epsilon in the
        denominator so p == r == 0 yields 0 instead of dividing by zero.
    """
    fmeasure = np.array([(2 * p * r) / (p + r + 1e-10)
                         for p, r in zip(precision, recall)])
    # np.max over a flat array instead of builtin max(): the original built an
    # (n, 1) array of single-element rows and compared numpy arrays with
    # max(), which relied on fragile array-truthiness and returned a
    # 1-element array rather than a scalar.
    return np.max(fmeasure)
-
def F_measure(target_var, output_mask):
    """Per-image ROC-AUC and maximum F-measure for a batch of masks.

    Args:
        target_var: binary ground-truth tensor of shape (b, c, h, w);
            flattened to (b, h*w) — assumes c == 1, TODO confirm with callers.
        output_mask: raw score tensor of shape (b, c, h, w).

    Returns:
        (aucs, f1s): two lists of length b — the ROC-AUC score and the
        maximum F-measure (from ``cal_fmeasure``) for each image.
    """
    b, c, h, w = output_mask.shape

    # Flatten to (b, h*w) and normalise scores across pixels.
    # NOTE(review): softmax over the pixel dimension reproduces the original
    # behaviour (scores of one image sum to 1 over all its pixels); it does
    # not change the per-pixel ranking, so AUC/PR curves are unaffected.
    scores = F.softmax(output_mask.view(b, h * w), dim=1)
    targets = target_var.view(b, h * w)

    aucs = []
    f1s = []
    for i in range(b):
        # detach().cpu() so this also works for CUDA tensors and tensors that
        # require grad — the original bare .numpy() raises in both cases.
        y_scores = scores[i].detach().cpu().numpy()
        y_true = targets[i].detach().cpu().numpy().astype(np.int32)

        precision, recall, _ = precision_recall_curve(y_true, y_scores)
        aucs.append(roc_auc_score(y_true, y_scores))
        f1s.append(cal_fmeasure(precision, recall))

    return aucs, f1s
|