![TensorFlow知识图谱实战](https://wfqqreader-1252317822.image.myqcloud.com/cover/115/44510115/b_44510115.jpg)
2.1.2 使用Keras API实现鸢尾花分类的例子(顺序模型)
iris数据集是常用的分类实验数据集,由Fisher于1936年收集整理。iris也称鸢尾花卉数据集,是一类用于多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度、花萼宽度、花瓣长度、花瓣宽度4个属性预测鸢尾花卉(见图2.2)属于Setosa、Versicolour、Virginica这3个种类中的哪一类。
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P42_4353.jpg?sign=1739255994-Vyu5y0Ds2P85zsFbaufyiyWHvePyNNbW-0-54a1f0572a2f587fced152c4954fbc8a)
图2.2 鸢尾花
第一步:数据的准备
不需要读者下载这个数据集,一般常用的机器学习工具自带iris数据集,引入数据集的代码如下:
from sklearn.datasets import load_iris data = load_iris()
这里调用的是sklearn数据库中的iris数据集,直接载入即可。
而其中的数据又是以key-value值对应存放,key值如下:
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P42_4354.jpg?sign=1739255994-dBOjcCR2SgXBcJrD6rurvwUeH4AOQMsO-0-c53d6e8f5b04f49384cea4aa9272eb6b)
由于本例中需要iris的特征与分类目标,因此这里只需要获取data和target,代码如下:
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P42_8238.jpg?sign=1739255994-NTkYj5bAD4QzRcuzmSszNZ36LxjHs4lX-0-d6c51cc4023f580f3990fa5864445652)
数据打印结果如图2.3所示。
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P42_4355.jpg?sign=1739255994-HnTXDM0J3O0tzTFp2PfsjLrP4qpSbNA5-0-4fddb00aafc56fc7287708770c8927b4)
图2.3 数据打印结果
这里是分别打印了前5条数据。可以看到iris数据集中分成了4个不同特征进行数据记录,而每条特征又对应于一个分类表示。
第二步:数据的处理
下面就是数据处理部分,对特征的表示不需要变动。而对于分类表示的结果,全部打印结果如图2.4所示。
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P43_4367.jpg?sign=1739255994-OPfbQjdiIfGu3T2dA13ADYXeJLy79H2Y-0-f1aa7d875ab32511324fa49ad38e07ff)
图2.4 数据处理
这里按数字分成了3类,0、1和2分别代表3种类型。如果按直接计算的思路,可以将数据结果向固定的数字进行拟合,这是一个回归问题,即通过回归曲线去拟合出最终结果。但是本例实际上是一个分类任务,因此需要对其进行分类处理。
分类处理的一个非常简单的方法就是进行one-hot处理,即将一个序列化数据分到不同的数据领域空间进行表示,如图2.5所示。
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P43_4368.jpg?sign=1739255994-M02rAjaVVkP9gI0uEM5L3vrK63unjOtU-0-be003ddeaa09ab168fd9dabdc6c5916e)
图2.5 one-hot处理
具体在程序处理上,读者可以手动实现one-hot的编码表示,也可以使用Keras自带的分散工具对数据进行处理,代码如下:
iris_target = np.float32(tf.keras.utils.to_categorical(iris_target,num_classes=3))
这里的num_classes表示分成了3类,用一行三列对每个类别进行表示。
交叉熵函数与分散化表示的方法超出了本书的讲解范围,这里就不再过多介绍,读者只需要知道交叉熵函数需要和softmax配合,从分布上向离散空间靠拢即可。
iris_data = tf.data.Dataset.from_tensor_slices(iris_data).batch(50) iris_target = tf.data.Dataset.from_tensor_slices(iris_target).batch(50)
当生成的数据读取到内存中并准备以批量的形式打印,使用的是tf.data.Dataset.from_tensor_slices函数,并且可以根据具体情况对batch进行设置。关于tf.data.Dataset函数更多的细节和用法在后面章节中会专门介绍。
第三步:梯度更新函数的写法
梯度更新函数是根据误差的幅度对数据进行更新的方法,代码如下:
grads = tape.gradient(loss_value, model.trainable_variables) opt.apply_gradients(zip(grads, model.trainable_variables))
与前面线性回归例子的差别是,使用的模型直接获取参数的方式对数据自动进行更新而非人为指定,这一点请读者注意。至于人为的指定和排除某些参数的方法属于高级程序设计,在后面的章节会提到。
【程序2-1】
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P43_8250.jpg?sign=1739255994-GBcXzE6a7vOb1jwjdvugTV78OMZfvpO6-0-976f07e6c7de963218b5c14b8989da1b)
最终打印结果如图2.6所示。可以看到损失值在符合要求的条件下不停降低,达到了预期目标。
![](https://epubservercos.yuewen.com/281CEB/23721624209516806/epubprivate/OEBPS/Images/Figure-P44_4377.jpg?sign=1739255994-ea8Q5r40dy0ZKxKO29QoM3wzRJlDhRau-0-1a83f94c3019a57e41bd54b9020c0e52)
图2.6 打印结果