上QQ阅读APP看书,第一时间看更新
3.1.4 训练网络
一旦网络结构编译完成,就可以进行模型训练了。这意味着在训练数据集上调整模型权重。训练网络需要指定训练数据,包括输入矩阵X和匹配输出数组y。使用反向传播算法训练网络,并根据编译模型时指定的优化算法和损失函数进行优化。
使用fit()函数进行训练,并将训练过程存储在history变量中。执行训练的程序代码如下:
> history <- model %>% fit( + x = x_train_scale, + y = y_train, + validation_split = 0.1, + epochs = 10, + batch_size = 32, + verbose = 2)
其参数描述如下。
- x:输入数据,如果模型只有一个输入,那么x的类型是数组;如果模型有多个输入,那么x的类型应当为list,list的元素对应用于各个输入的数组。
- y:标签。
- validation_split:0~1之间的浮点数,用来指定训练集中作为验证集的数据比例。验证集不参与训练,并在每个训练周期结束后测试模型的指标,如损失函数、准确率等。本例中该参数的值为0.1,而训练集的全部数据是1048,所以1048×0.9=943项作为训练数据,1048×0.1=105项作为验证数据。
- epochs:整数,训练周期数,每个训练周期会把训练集轮一遍,这里执行10个训练周期。
- batch_size:整数,指定进行梯度下降时每个批次包含的样本数。训练1个批次的样本会被计算1次梯度下降,使目标函数优化一步。此处设置为32。
- verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个训练周期输出一行记录。此处设置为2,可以减少每个训练周期显示的信息量。
以上代码执行后的结果如图3-4所示。
图3-4 训练网络每个训练周期显示的信息量
从以上执行结果可知,训练样本数量为943,验证样本数量为105,每个训练周期会返回训练数据和验证数据的计算误差与准确率。这里共执行了10个训练周期,并且误差越来越小,准确率越来越高。
之前的训练步骤会将每一个训练周期的误差与准确率记录在history变量中。我们使用plot()函数以图形显示训练过程,如图3-5所示。
> plot(history)
图3-5 绘制每个训练周期的评估结果