4.2.2 多元分类
当超越简单的二元分类时,可能会经常处理关于多个类别的问题,如或。这在一定程度上限制了我们衡量错误或成功的方式。
考察图4.6中关于多元分类的混淆矩阵。
图4.6 多元分类的混淆矩阵
从图中可以看出,正类或负类的概念已经消失了,因为不再只有正类和负类,而是有限类的集合:
单个的类可以是字符串或数字,只要它们遵循集合的规则。也就是说,类别集合C必须是有限的和唯一的。
为了测量ACC,我们将计算混淆矩阵中主对角线上的所有元素,并将其除以样本总数:
式中表示混淆矩阵,为迹运算,也就是说计算方阵中所有主对角线元素之和。因此,总误差为1-ACC。但是,在样本数据点的类别分布不平衡的情况下,误差度量指标或简单的准确度指标可能具有欺骗性。为此,我们必须使用BER度量指标,对于多个类别的情形,可以将其定义为:
在新的BER公式中,表示混淆矩阵中第j行第i列的元素。
一些机器学习学派使用混淆矩阵的行表示真实标签,用列表示预测标签。它们背后的指标分析理论和对指标的解释都是一样的。不要担心sklearn使用的是翻转过来的方法,这是两码事,在接下来的讨论中你应该不会有这方面的问题。
作为示例,考虑图4.1所示的数据集。如果我们运行一个五层的神经网络分类器,那么可以得到图4.7中的判定边界。
图4.7 使用五层神经网络对二维数据集的分类效果
显然,数据集不能被非线性超平面完全准确分类,每个类别都有一些跨越边界的数据点。在前图中,可以看到只有Summer类没有基于分类边界被错误分类的点。
如果进行实际计算并显示混淆矩阵,就会更加明显,如图4.8所示。
图4.8 由二维数据集样本上的训练误差得到的混淆矩阵
在这种情况下,准确度的算式为,由此可以得到ACC的值为0.94,看起来似乎不错,错误率仅为1-ACC=0.06。这个特定的例子有轻微的样本类别分布不平衡。下面是每个类别的样本数:
·Summer:25
·Fall:25
·Winter:24
·Spring:26
冬季组的样本数比其他组的要少,春季组的样本数比其他组的要多。虽然这是一个非常小的类别不平衡,但是已经足以产生具有欺骗性的低错误率。我们现在必须计算平均错误率BER。
BER的计算方法如下:
在这里,BER和错误率之间的差值不足0.01%。然而,对于高度不平衡的类别,差距可能会更大,我们有责任仔细测量并报告适当的误差度量值BER。
另一个关于BER的有趣事实是,它在直观上是平衡准确度的对应,这就意味着,如果我们去掉BER方程中的1项,将得到平衡准确度。更进一步,如果我们检查分子各项,就可以发现其中作为加项的每个分数分别表示的某个特定类别的准确度;例如,第一类Summer准确率的为100%,第二类Fall的准确率为92%,以此类推。
在Python中,sklearn库有一个类,它可以根据真实标签和预测标签自动确定混淆矩阵。这个类名为confusion_matrix,它属于metrics超类,可以这样使用:
如果y包含真标签且y_pred包含预测标签,那么上述指令将输出如下:
可以这样简单地计算BER:
输出如下:
另外,sklearn有一个内置函数来计算与混淆矩阵具有相同超类中的平衡准确度。这个类名为balanced_accuracy_score,可以通过以下操作计算BER:
输出如下:
现在,我们来讨论回归分析矩阵。