深度强化学习实践(原书第2版)
上QQ阅读APP看书,第一时间看更新

3.3 NN构建块

torch.nn包中有大量预定义的类,可以提供基本的功能。这些类在设计时就考虑了实用性(例如,它们支持mini-batch处理,设置了合理的默认值,并且权重也经过了合理的初始化)。所有模块都遵循callable的约定,这意味着任何类的实例在应用于其参数时都可以充当函数。例如,Linear类实现了带有可选偏差的前馈层:

061-01

上述代码创建了一个随机初始化的前馈层,包含两个输入和五个输出,并将其应用于浮点张量。torch.nn包中的所有类均继承自nn.Module基类,可以通过该基类构建更高级别的NN模块。下一节将介绍如何自己构建,但是现在,我们先看一下所有nn.Module子类提供的方法。如下:

  • parameters():此函数返回所有需要进行梯度计算的变量的迭代器(即模块权重)。
  • zero_grad():此函数将所有参数的梯度初始化为零。
  • to(device):此函数将所有模块参数移至给定的设备(CPU或GPU)。
  • state_dict():此函数返回一个包含所有模块参数的字典,对于模型序列化很有用。
  • load_state_dict():此函数使用状态字典来初始化模块。

所有的类都可在文档(http://pytorch.org/docs)中找到。

现在,我将要提到一个非常方便的类,即Sequential,它可以将不同的层串起来。演示Sequential的最佳方法是通过一个示例:

062-01

上面的代码定义了一个三层的NN,输出层是softmax,softmax应用于第一维度(第零维度是批样本),还包括整流线性函数(Rectified Linear Unit,ReLU)非线性层和dropout。我们给这个模型输入一些数据:

062-02

mini-batch就是一个成功地遍历了网络的例子。