深度学习实战:基于TensorFlow 2和Keras(原书第2版)
上QQ阅读APP看书,第一时间看更新

3.5 分类任务和决策边界

在上一节中,我们了解了回归或预测任务。在本节中,我们将讨论另一个重要任务:分类任务。首先介绍回归(有时也称为预测)与分类之间的区别:

  • 在分类中,数据被分组为类别/类目,而在回归中,目的是获得给定数据的连续性数值。
  • 例如,识别手写数字是一项分类任务。所有手写数字都属于0~9之间的十个数字之一。而根据不同的输入变量来预测房屋价格的任务是回归任务。
  • 在分类任务中,模型用来找到某一类与另一类分开的决策边界。在回归任务中,模型用来近似拟合输入输出关系的函数。
  • 分类是回归的子集。此处,我们正在预测类别,但回归更为普遍。

下图显示了分类任务和回归任务两者之间的区别。在分类中,我们需要找到一条线(或多维空间中的平面或超平面)以分隔各类。

在回归中,目的是找到适合给定输入点的线(或平面或超平面)。

097-02

下面我们将说明logistic回归是一种非常普遍且有用的分类技术。

3.5.1 logistic回归

logistic回归用于确定事件的概率。按照惯例,事件表示为类别因变量。事件的概率使用sigmoid(或logit)函数表示:

097-03

现在的目标是估计权重W={w1w2,… wn}和偏置项b。在logistic回归中,使用最大似然估计器或随机梯度下降法估计系数。如果p是输入数据点的总数,则损失通常定义为由以下公式得出的交叉熵项:

098-01

logistic回归用于分类问题。例如,当查看医学数据时,我们可以使用logistic回归来对一个人是否患有癌症进行分类。如果输出分类变量具有两个或多个层级,则可以使用多项logistic回归。另一种常用于两个或多个输出变量的技术是“一对多”(one versus all)。

对于多类logistic回归,将交叉熵损失函数修改为:

098-02

其中K是类别总数。你可以在https://en.wikipedia.org/wiki/Logistic_regres-sion上了解有关logistic回归的更多信息。

现在,你对logistic回归有了一些了解,让我们看看如何将其应用于任意数据集。

3.5.2 MNIST数据集上的logistic回归

接下来,使用TensorFlow估计器中可用的Estimator分类器对手写数字进行分类。我们将用到MNIST(Modified National Institute of Standards and Technology)数据集。对于从事深度学习领域的人员来说,MNIST并不是什么新鲜事物,它就像机器学习的ABC。它包含了手写数字的图像和每个图像的标签(图像里的数字)。标签包含一个介于0到9之间的值,具体取决于手写数字。

分类器估计器采用特征和标签。它将它们转换为独热编码向量,也就是说,我们有10位表示输出。每个位的值可以为0或1,对于独热意味着每个标签为Y的图像,在这10位中只有1位的值为1,其余为0。在下图中,你可以看到手写数字5的图像及其One-Hot编码的值[0 0 0 0 0 0 0 0 0 1 0]。

估计器输出对数概率、10个类的softmax概率以及相应的标签。

098-03

让我们建立模型。

1)导入所需的模块:

099-01

2)从tensorflow.keras数据集中获取MNIST的输入数据:

099-02

3)预处理数据:

099-03

4)使用TensorFlow的feature_column模块定义大小为28×28的数字特征:

099-04

5)创建logistic回归估计器。我们使用一个简单的LinearClassifier。我们也建议你尝试DNNClassifier

099-05

6)构建一个input_function作为估计器输入:

099-06

7)训练分类器:

100-01

8)为验证数据创建输入函数:

100-02

9)在验证数据集上评估训练好的线性分类器:

100-03

10)经过130个时间步长,我们的准确度达到89.4%。还不错。注意,由于我们已经指定了时间步长,因此模型会针对指定的步长进行训练,并在10个时间步长(指定的步数)后记录值。现在,如果我们再次运行train,那么它将从第十步的状态开始。该时间步长会随着上述提到的步骤数的增加而增加。

上述模型的图如图3-3所示。

100-04

图3-3 生成模型的TensorBoard图

通过TensorBoard,我们还可以直观地看到准确度和平均损失的变化,这是线性分类器以十步为单位学习的,如图3-4所示。

101-01

图3-4 准确度和平均损失的可视化表示

人们还可以使用TensorBoard来查看网络训练时模型权重和偏置的更新。在图3-5中可以看到,随着时间的推移,偏置会发生变化。可以看出,随着模型的学习(x轴为时间),偏置从初始值0开始扩展。

101-02

图3-5 偏置的更新