美文网首页
AUC:使用Java计算AUC

AUC:使用Java计算AUC

作者: xiaogp | 来源:发表于2023-06-28 10:04 被阅读0次

关键词:AUCJAVA

业务背景

最近需要在Java业务代码上实现计算AUC的逻辑,网上没找到,因此手动实现一下

代码实战

参考这篇文章如何计算AUC的方法2,公式如下

公式

Java代码如下

public static double getAUCValue(List<Integer> labels, List<Double> predictions) throws Exception {
        if (labels.size() != predictions.size()) {
            throw new Exception("labels和predictions长度必须一致");
        }
        Map<Double, List<Integer>> map = new HashMap<>();
        int totalPositiveNum = 0;
        int totalNegativeNum = 0;
        for (int i = 0; i < labels.size(); i++) {
            int oneLabel = labels.get(i);
            double onePred = predictions.get(i);
            if (oneLabel == 1) {
                totalPositiveNum += 1;
            } else {
                totalNegativeNum += 1;
            }
            map.putIfAbsent(onePred, new ArrayList<>());
            map.get(onePred).add(oneLabel);
        }
        List<Double> sortPred = map.keySet().stream().sorted().collect(Collectors.toList());
        int startRank = 1;
        double pairAll = 0.0;
        for (double pred : sortPred) {
            List<Integer> list = map.get(pred);
            int positiveNum = list.stream().mapToInt(Integer::intValue).sum();
            int endRank = startRank + list.size() - 1;
            double avgRank = 1.0 * (startRank + endRank) / 2;
            startRank = endRank + 1;
            // 分子左边
            pairAll += positiveNum * avgRank;
        }
        double pairPositive = 1.0 * totalPositiveNum * (totalPositiveNum + 1) / 2;
        return (pairAll - pairPositive) / (totalPositiveNum * totalNegativeNum);
    }

测试结果0.6111111111111112

public static void main(String[] args) throws Exception {
        List<Integer> label = Arrays.asList(1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1);
        List<Double> probs = Arrays.asList(0.4, 0.8, 0.2, 0.4, 0.5, 0.7, 0.9, 0.6, 0.3, 0.2, 0.1, 0.1, 0.2, 0.3, 0.5, 0.8);
        double res = getAUCValue(label, probs);
        System.out.println(res);
    }

用Python的sklearn验证一下完全一致

>>> from sklearn.metrics import roc_auc_score
>>> a = [1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1]
>>> b = [0.4, 0.8, 0.2, 0.4, 0.5, 0.7, 0.9, 0.6, 0.3, 0.2, 0.1, 0.1, 0.2, 0.3, 0.5, 0.8]
>>> roc_auc_score(a, b)
0.6111111111111112

相关文章

网友评论

      本文标题:AUC:使用Java计算AUC

      本文链接:https://www.haomeiwen.com/subject/tbdoydtx.html