深度强化学习实践(原书第2版)
上QQ阅读APP看书,第一时间看更新

3.8 PyTorch Ignite

PyTorch是一个优雅而灵活的库,因此它成为成千上万的研究人员、DL爱好者、行业开发人员和其他人员的首选。但是灵活性有其自身的代价:需要写太多的代码来解决问题。有时,这是非常有益的,例如,实现一些尚未包含在标准库中的新优化方法或DL技巧时。只需使用Python实现公式,PyTorch将神奇地完成所有梯度计算和反向传播机制。另一个证明这种方法有益的场景是,当你必须关注底层原理时,比如调整梯度、了解优化器详细信息以及NN转换数据的方式。

但是,在完成日常任务(例如图像分类器的简单监督训练)时,并不需要这种灵活性。对于此类任务,标准PyTorch可能太过底层,所以你需要一遍又一遍地处理相同的代码。以下是DL训练过程中主要部分的详尽列表,但需要编写一些代码:

  • 数据准备和转换以及批次的生成。
  • 计算训练指标,例如损失值、精度和F1分数。
  • 在测试和验证数据集中对模型进行周期性测试。
  • 经过一定数量的迭代或达到新的最佳度量标准后的模型的检查点。
  • 将指标输入到TensorBoard等监控工具中。
  • 超参随着时间而变化,例如学习率的降低或增加。
  • 在控制台上输出有关训练进度的消息。

当然,它们都能使用PyTorch来实现,但是可能需要编写大量的代码。这些任务在任何DL项目中都存在,一遍又一遍地编写相同的代码很快变得麻烦。解决此问题的常规方法是编写函数,将其包装到库中,然后重复使用。如果该库是开源的且质量很高(易于使用,提供了一定程度的灵活性,可以正确编写等),那么随着越来越多的人在其项目中使用它,该库将变得流行。该过程不只发生在DL领域,它在软件行业中无处不在。

PyTorch有多个库可简化常见任务,如ptlearnfastaiignite等。“PyTorch生态系统项目”参见https://pytorch.org/ecosystem。

开始就使用这些高级库可能会很有吸引力,因为使用它们可以仅用几行代码即可解决常见问题,但是这里也存在一些危险。如果只知道如何使用高级库而不了解底层细节,那么可能会陷入无法仅由标准方法解决问题的困境。在ML的动态领域中,这种情况经常发生。

本书的重点是确保你理解RL方法、它的实现及其适用性。因此,我们将使用递进的方法。首先,仅使用PyTorch代码来实现,但是随着学习的推进,将使用高级库来实现示例。对于RL,将使用由我编写的小型库:PTAN(https://github.com/Shmuma/ptan/)。PTAN将在第7章进行介绍。

为了减少DL样板代码的数量,我们将使用一个称为PyTorch Ignite(https://pytorch.org/ignite/)的库。本节将简要介绍Ignite,然后使用Ignite重写Atari GAN示例,并对其进行检查。

Ignite概念

从高层次上讲,Ignite简化了PyTorch DL中训练循环的编写。在本章前面的“优化器”部分,可以看到最小的训练循环包括:

  • 采样一批训练数据。
  • 将NN应用于这批数据,计算损失函数(要最小化的单个值)。
  • 对损失进行反向传播,以获取与损失有关的网络参数梯度。
  • 使优化器将梯度应用于网络。
  • 重复,直到满意或不想再等待。

Ignite的核心部分是Engine类,该类遍历数据源,并将处理函数应用于数据批。除此之外,Ignite还提供了在训练循环的特定条件下,调用某函数的功能。这些特定条件称为Event,可能在以下位置:

  • 整个训练过程的开始或结束位置。
  • 训练epoch(使用数据进行迭代)的开始或结束位置。
  • 单个批处理的开始或结束位置。

除此之外,还存在自定义事件,并且允许指定每N个事件调用一次函数,例如,每100个批次或每隔一个epoch进行一次计算。

以下代码块显示了一个非常简单的Ignite示例:

076-01

该代码不可运行,因为它缺少很多内容,例如数据源、模型和优化器创建,但它展示了Ignite基本概念。Ignite的主要优势在于它能够利用现有功能扩展训练模型。你希望平滑损失值并且每100批次将其写入TensorBoard中吗?没问题!加两行代码即可完成。你想每10个epoch运行一次模型验证吗?写一个函数来运行测试,并将其加入engine中,然后它将被如期调用。

关于Ignite功能的完整描述不在本书的讨论范围,可以阅读官方网站(https://pytorch.org/ignite)的文档来查看。

为了演示Ignite,我们更改一下用GAN训练Atari图像的例子。完整的示例代码见Chapter03/04_atari_gan_ignite.py,以下代码段将仅显示有改动的部分。

076-02

首先,导入几个Ignite类:EngineEventsignite.metrics包含与训练过程的性能指标有关的类,例如混淆矩阵、精度和召回率。在本示例中,将使用RunningAverage类,该类提供一种平滑时间序列值的方法。在前面的示例中,我们通过对一系列损失值调用np.mean()来完成此操作,但是RunningAverage提供了一种更方便(并且在数学上更正确)的方法。此外,Ignite的contrib包中导入TensorBoard记录器(该功能由其他人贡献)。

077-01

下一步,我们需要定义处理函数,该函数将获取批数据,并用该批数据对判别器和生成器模型进行更新。此函数可以返回训练过程中要跟踪的任何数据,在本示例中为两个模型各自的损失值。这个函数还可以保存要在TensorBoard中显示的图像。

完成此操作后,我们要做的就是创建一个Engine实例,加上所需的处理程序,然后运行训练过程。

077-02

在前面的代码中,我们创建了engine,传递了处理函数,并为两个损失值附加了RunningAverage转换。每个RunningAverage都会产生一个所谓的“指标”,即在训练过程中保持的派生值。平滑指标avg_loss_gen表示来自生成器的平滑损失,avg_loss_dis表示来自判别器的平滑损失。这两个值在每次迭代后写入TensorBoard中。

078-01

最后一段代码附加了另一个事件处理程序,并且在每次迭代完成时由Engine调用。它会写一行日志,其索引是迭代数,值是平滑后的度量值。最后一行启动Engine,将已定义的函数作为数据源传入(函数iterate_batches是一个生成器,分批返回迭代器,因此,将其输出作为data参数传递是很好的)。

这就是Ignite的全部内容。如果运行示例Chapter03/04_atari_gan_ignite.py,它与前面示例的运行方式相同,这样的小例子可能并不会令人印象深刻,但是在实际项目中,Ignite的使用通常会使代码更简洁、更可扩展。