2.1.4 训练模型
构建模型(假设为model)后,接下来就是训练模型。PyTorch训练模型主要包括加载和预处理数据集、定义损失函数、定义优化算法、循环训练模型、循环测试或验证模型、可视化结果等步骤。
(1)加载和预处理数据集
加载和预处理数据集可以使用PyTorch的数据处理工具,如torch.utils和torchvision等,这些工具将在第4章中详细介绍。
(2)定义损失函数
定义损失函数可以通过自定义方法或使用PyTorch内置的损失函数,如回归使用的nn.MSELoss()、分类使用的nn.BCELoss等损失函数,更多内容可参考2.9节。
(3)定义优化算法
PyTorch常用的优化算法都封装在torch.optim中,其设计灵活,可以扩展为自定义的优化算法。所有的优化算法都是继承了基类optim.Optimizer,并实现了自己的优化步骤。
最常用的优化算法就是梯度下降法及其变种,具体将在2.10节详细介绍,这些优化算法大多使用梯度更新参数。
如使用SGD优化器时,可设置为optimizer=torch.optim.SGD(params,lr=0.001)。
(4)循环训练模型
1)设置为训练模式:model.train()。调用model.train()会把所有的module设置为训练模式。
2)梯度清零:optimizer. zero_grad()。在默认情况下,梯度是累加的,需要手工把梯度初始化或清零,调用optimizer.zero_grad()即可。
3)求损失值:y_prev=model(x),loss=loss_fun(y_prev,y_true)。
4)自动求导,实现梯度的反向传播:loss.backward()。
5)更新参数:optimizer.step()。
(5)循环测试或验证模型
1)设置为测试或验证模式:model.eval()。调用model.eval()会把所有的training属性设置为False。
2)在不跟踪梯度模式下计算损失值、预测值等:with.torch.no_grad()。
(6)可视化结果
下面通过实例来说明如何使用nn来构建网络模型、训练模型。
说明:如果模型中有BN(Batch Normalization,批归一化)层和dropout层,需要在训练时添加model.train(),在测试时添加model.eval()。其中,model.train()用于确保BN层使用每一批数据的均值和方差进行训练,而model.eval()用于确保BN使用全部训练数据的均值和方差进行评估;而对于dropout层,model.train()用于随机取一部分网络连接来训练更新参数,而model.eval()则利用到了所有网络连接进行评估。