深度强化学习实践(原书第2版)
上QQ阅读APP看书,第一时间看更新

3.5 最终黏合剂:损失函数和优化器

将输入数据转换为输出的网络并不是训练唯一需要的东西。我们还需要定义学习目标,即要有一个接受两个参数(网络输出和预期输出)的函数。它的责任是返回一个表示网络预测结果与预期结果之间的差距的数字。此函数称为损失函数,其输出为损失值。使用损失值,可以计算网络参数的梯度,并对其进行调整以减小损失值,以便优化模型的结果。损失函数和通过梯度调整网络参数的方法非常普遍,并且以多种形式存在,以至于它们构成了PyTorch库的重要组成部分。我们从损失函数开始介绍。

3.5.1 损失函数

损失函数在nn包中,并实现为nn.Module的子类。通常,它们接受两个参数:网络输出(预测)和预期输出(真实数据,也称为数据样本的标签)。在撰写本书时,PyTorch 1.3.0包含20个不同的损失函数,当然,你也可以显式地自定义要优化的函数。

最常用的标准损失函数是:

  • nn.MSELoss:参数之间的均方误差,是回归问题的标准损失。
  • nn.BCELossnn.BCEWithLogits:二分类交叉熵损失。前者期望输入是一个概率值(通常是Sigmoid层的输出),而后者则假定原始分数为输入并应用Sigmoid本身。第二种方法通常在数值上更稳定、更有效。这些损失(顾名思义)经常用于分类问题。
  • nn.CrossEntropyLossnn.NLLLoss:著名的“最大似然”标准,用于多类分类问题。前者期望的输入是每个类的原始分数,并在内部应用LogSoftmax,而后者期望将对数概率作为输入。

还有一些其他的损失函数可供使用,当然你也可以自己写Module子类来比较输出值和目标值。现在,来看下关于优化过程的部分。

3.5.2 优化器

基本优化器的职责是获取模型参数的梯度,并更改这些参数来降低损失值。通过降低损失值,使模型向期望的输出靠拢,使得模型性能越来越好。更改参数听起来很简单,但是有很多细节要处理,优化器仍是一个热门的研究主题。在torch.optim包中,PyTorch提供了许多流行的优化器实现,其中最广为人知的是:

  • SGD:具有可选动量的普通随机梯度下降算法。
  • RMSprop:Geoffrey Hinton提出的优化器。
  • Adagrad:自适应梯度优化器。
  • Adam:一种非常成功且流行的优化器,是RMSpropAdagrad的组合。

所有优化器都公开了统一的接口,因而可以轻松地尝试使用不同的优化方法(有时,优化方法可以在动态收敛和最终结果上表现优秀)。在构造时,需要传递可迭代的张量,该张量在优化过程中会被修改。通常的做法是传递上层nn.Module实例的params()调用的结果,结果将返回所有具有梯度的可迭代叶张量。

现在,我们来讨论训练循环的常见蓝图。

066-01

通常,需要一遍又一遍地遍历数据(所有数据运行一个迭代称为一个epoch)。数据通常太大而无法立即放入CPU或GPU内存中,因此将其分成大小相同的批次进行处理。每一批数据都包含数据样本和目标标签,并且它们都必须是张量(第2行和第3行代码)。

将数据样本传递给网络(第4行),并将其输出值和目标标签提供给损失函数(第5行),损失函数的结果显示了网络结果和目标标签的差距。网络的输入和网络的权重都是张量,所以网络的所有转换只不过是中间张量实例的操作图。损失函数也是如此——它的结果也是一个只有一个损失值的张量。

计算图中的每一个张量都记得其来源,因此要对整个网络计算梯度,只需要在损失函数的返回结果上调用backward()函数(第6行)即可。调用结果是展开已执行计算的图和计算requires_grad = True的叶张量的梯度。通常,这些张量是模型的参数,比如前馈网络的权重和偏差,以及卷积滤波器。每次计算梯度时,都会在tensor.grad字段中累加梯度,所以一个张量可以参与多次转换,梯度会相加。例如,循环神经网络(Recurrent Neural Network,RNN)的一个单元可以应用于多个输入项。

在调用loss.backwards()后,我们已经累加了梯度,现在是优化器执行其任务的时候了——它获取传递给它的参数的所有梯度并应用它们。所有这些都是使用step()完成的(第7行)。

训练循环最后且重要的部分是对参数梯度置零的处理。可以在网络上调用zero_grad()来实现,但是为了方便,优化器还公开了这样一个调用(第8行)。有时候zero_grad()被放在训练循环的开头,但这并没有什么影响。

上述方案是一种非常灵活的优化方法,即使在复杂的研究中也可以满足要求。例如,可以用两个优化器在同一份数据上调整不同模型的选项(这是一个来自生成对抗网络(Generative Adversarial Network,GAN)训练的真实场景)。

我们已经介绍完了训练NN所需的PyTorch的基本功能。本章以一个实际的场景结束,演示涵盖的所有概念,但在开始之前,我们需要讨论一个重要的主题——监控学习过程——这对NN从业人员来说是必不可少的。