# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import binarize

import numpy as np

__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"]

def multi_hot_encode(logits, threshold=0.5):
    """
    Encode logits to a multi-hot matrix by elementwise thresholding.

    Args:
        logits: array-like of prediction scores, typically of shape
            (num_samples, num_classes). 1-D inputs are also accepted.
        threshold: scores strictly greater than this value become 1,
            all others become 0. Default 0.5.

    Returns:
        np.ndarray of the same shape and dtype as the input, containing
        only 0s and 1s.
    """
    logits = np.asarray(logits)
    # Strict ">" matches sklearn.preprocessing.binarize, which this numpy
    # form replaces; dropping binarize also lifts its 2-D-only restriction.
    return (logits > threshold).astype(logits.dtype)


def hamming_distance(output, target):
    """
    Hamming loss between predicted and ground-truth label matrices.

    A soft, per-label metric for multilabel classification: the fraction
    of individual label decisions that are wrong.

    Returns:
        float: the smaller the value, the better the model
        (0.0 is a perfect prediction).
    """
    distance = hamming_loss(target, output)
    return distance


def accuracy_score(output, target, base="sample"):
    """
    Hard (exact-match) accuracy for multilabel classification.

    Args:
        output: predicted multi-hot labels.
        target: ground-truth multi-hot labels.
        base: ["sample", "label"], default="sample".
            "sample" scores whole samples (every label of a sample must
            match); "label" pools individual label decisions over all
            classes.

    Returns:
        accuracy: float score under the chosen base.
    """
    assert base in ["sample", "label"], 'must be one of ["sample", "label"]'

    if base == "sample":
        return accuracy_metric(target, output)

    # base == "label": pool the per-class confusion matrices and compute
    # micro-averaged accuracy (TP + TN) / (TP + TN + FP + FN).
    mcm = multilabel_confusion_matrix(target, output)
    true_negatives = mcm[:, 0, 0].sum()
    false_negatives = mcm[:, 1, 0].sum()
    true_positives = mcm[:, 1, 1].sum()
    false_positives = mcm[:, 0, 1].sum()

    correct = true_positives + true_negatives
    total = correct + false_negatives + false_positives
    return correct / total


def precision_recall_fscore(output, target):
    """
    Per-label precision, recall and F-score for multilabel classification.

    Returns:
        tuple of three arrays, each with one entry per label:
            precisions, recalls, fscores
    """
    # precision_recall_fscore_support also returns per-label support
    # counts as its fourth element; they are not part of this contract.
    scores = precision_recall_fscore_support(target, output)
    precisions, recalls, fscores = scores[0], scores[1], scores[2]
    return precisions, recalls, fscores


def mean_average_precision(logits, target):
    """
    Mean of the per-class average precision (mAP).

    Args:
        logits: probability from network before sigmoid or softmax,
            shape (num_samples, num_classes).
        target: ground truth, 0 or 1, same shape as ``logits``.

    Returns:
        float: average precision averaged over all classes.

    Raises:
        TypeError: if either argument is not an ``np.ndarray``.
    """
    if not (isinstance(logits, np.ndarray) and isinstance(target, np.ndarray)):
        raise TypeError("logits and target should be np.ndarray.")

    # One AP score per class column, then averaged.
    per_class_aps = [
        average_precision_score(target[:, col], logits[:, col])
        for col in range(target.shape[1])
    ]
    return np.mean(per_class_aps)