什么是ROC曲线?可以参见(https://blog.csdn.net/hesongzefairy/article/details/104295431

现在我们知道ROC曲线上的一组组(FPR,TPR)值是通过改变阈值得到,那么具体在程序中是如何实现的?

首先我们需要了解sklearn.metrics中的roc_curve方法(metrics是度量、指标,curve是曲线)

roc_curve(y_true, y_score, pos_label=None, sample_weight=None, drop_intermediate=None)

参数含义:

y_true:简单来说就是label,范围在(0,1)或(-1,1)的二进制标签,若非二进制则需提供pos_label。

y_score:模型预测的类别概率值。

pos_label:label中被认定为正样本的标签。若label=[1,2,2,1]且pos_label=2,则2为positive,其他为negative。

sample_weight:采样权重,可选择一部分来计算。

drop_intermediate:可以去掉一些对ROC曲线不好的阈值,使得曲线展现的性能更好。

返回值:(tpr,fpr,thershold)

tpr:根据不同阈值得到一组tpr值。

fpr:根据不同阈值的到一组fpr值,与tpr一一对应。(这两个值就是绘制ROC曲线的关键)

thresholds:选择的不同阈值,按照降序排列。

有了以上参数和返回值的基础,举个例子来验证一下:

from sklearn.metrics import roc_curve

y_label = ([1, 1, 1, 2, 2, 2])  # 非二进制需要pos_label
y_pre = ([0.3, 0.5, 0.9, 0.8, 0.4, 0.6])
fpr, tpr, thersholds = roc_curve(y_label, y_pre, pos_label=2)

for i, value in enumerate(thersholds):
    print("%f %f %f" % (fpr[i], tpr[i], value))

输出结果:

分析:

代码中故意将label设置为非二进制,展示了pos_label这个参数的用处,如果pos_label设置错误程序会报错,大家可以自己尝试,y_label是标签,y_pre和y_labe相对应是模型预测为label的概率值,当pos_label=2则标签为2的看作正样本,标签为1的看做负样本。

我们都知道ROC曲线是需要通过改变阈值来获取一组组(fprp, tpr),那么roc_curve方法中是如何选取阈值

从输出结果可以看到,第三列代表返回值thersholds记录的就是roc_curve所选取的阈值,其阈值就是将y_pre降序排列并依次选取,如果选取的阈值对fpr和tpr值无影响则忽略,输出结果中没有记录阈值为0.8时情况。

需要注意的是,输出结果第一行therholds=1.9,这个值很奇怪,乍一看不知道为什么会出现这个值。这里我们阅读sklearn官网的原文对therholds这个参数的解释:

Decreasing thresholds on the decision function used to compute fpr and tpr. thresholds[0] represents no instances being predicted and is arbitrarily set to max(y_score) + 1.

thresholds[0] 表示没有任何预测的实例 并且被设置为max(y_score) + 1,这样就知道1.9其实是y_pre中最大的0.9 +1,再简单点来说,第一行其实就是ROC曲线的起点(0,0)。

到这里roc_curve方法的用法应该已经非常清楚了,画ROC曲线之前还有一个评估模型优劣重要的值AUC需要得到。

算AUC的方法很简单,使用auc方法即可。

from sklearn.metrics import auc

roc_auc = auc(fpr, tpr)

最后就是画出ROC曲线了,完整代码如下:

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

y_label = ([1, 1, 1, 2, 2, 2])  # 非二进制需要pos_label
y_pre = ([0.3, 0.5, 0.9, 0.8, 0.4, 0.6])
fpr, tpr, thersholds = roc_curve(y_label, y_pre, pos_label=2)

for i, value in enumerate(thersholds):
    print("%f %f %f" % (fpr[i], tpr[i], value))

roc_auc = auc(fpr, tpr)

plt.plot(fpr, tpr, 'k--', label='ROC (area = {0:.2f})'.format(roc_auc), lw=2)

plt.xlim([-0.05, 1.05])  # 设置x、y轴的上下限,以免和边缘重合,更好的观察图像的整体
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')  # 可以使用中文,但需要导入一些库即字体
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

输出结果:

 

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐