3.2 PyTorch基础
本章使用的机器学习库是基于PyTorch的。PyTorch是由Facebook开源的基于Python的机器学习库[229]。本节我们简要介绍PyTorch的相关基础知识,包括Tensor的创建、操作、以及自动求导。如果读者想更深入了解PyTorch的使用,请参考PyTorch官方文档(链接3-4)。
3.2.1 创建Tensor
Tensor是PyTorch的基础数据结构,是一个高维的数组,可以在跨设备(CPU、GPU等)中存储,其作用类似于Numpy中的ndarray。PyTorch中内置了多种创建Tensor的方式,我们首先导入torch模块。
• 仅指定形状大小:可以仅仅通过指定形状大小,自动生成没有初始化的任意值,包括empty、IntTensor、FloatTensor等接口。
• 通过随机化函数(PyTorch内置了很多随机化函数)创建具有某种初始分布的值,比如服从标准正态分布的randn、服从均匀分布的rand、服从高斯分布的normal等,一般我们只需要指定输出tensor值的形状。
• 通过填充特定的元素值来创建,比如通过ones函数构建一个全1矩阵,通过zeros函数构建全0矩阵,通过full函数指定其他特征值。
更多Tensor的创建方式,读者也可以参考PyTorch的官方文档,这里不再详述。
3.2.2 Tensor与Python数据结构的转换
除了上一小节提到的创建方式,PyTorch还可以将已有Python数据结构(如list,numpy.ndarray等)转换为Tensor的接口。PyTorch的运算都以Tensor为单位进行,在运算时都需要将非Tensor的数据格式转化为Tensor,主要的转换函数包括tensor、as_tensor、from_numpy。用户只需要将list或者ndarray数值作为参数传入,即可自动转换为PyTorch的Tensor数据结构。
需要注意的是,as_tensor和from_numpy会复用原数据的内存空间,也就是说,原数据或者Tensor的任意一方改变,都会导致另一方的数据改变。
3.2.3 数据操作
Tensor支持多种数据运算,例如四则运算、数学运算(如指数运算、对数运算等)等。并且,对于每一种数据的操作,PyTorch提供了多种不同的方式来完成。我们以加法运算为例,PyTorch有三种实现加法运算的方式。
• 方式一:直接使用符号“+”来完成。
• 方式二:使用add函数。
• 方式三:PyTorch对数据的操作还提供了一种独特的inplace模式,即运算后的结果直接替换原来的值,而不需要额外的临时空间。这种inplace版本一般在操作函数后面都有后缀“_”。
对于其他的张量四则运算操作,也可以仿照上面的三种方法来完成。Tensor的另一种常见操作是改变形状。PyTorch使用view()来改变Tensor中的形状,如下所示。
Tensor的创建默认是存储在CPU上的。如果设备中有GPU,为了提高数据操作的速度,我们可以将数据放置在GPU中。PyTorch提供了方便的接口将数据在两者之间切换。
如果想将数据重新放置在CPU中,只需要执行下面的操作即可。
3.2.4 自动求导
自动求导功能是PyTorch进行模型训练的核心模块,文献[228]对PyTorch的自动求导功能进行了深入的讲解和原理剖析。当前,PyTorch的自动求导功能通过autograd包实现。autograd包求导时,首先要求Tensor将requires_grad属性设置为True;随后,PyTorch将自动跟踪该Tensor的所有操作;当调用backward()进行反向计算时,将自动计算梯度值并保存在grad属性中。下面我们可以通过一个例子来查看自动求导的过程,计算过程如下。
这是一个比较简单的数学运算求解,上面的代码块所要求解的计算公式可以表示为
PyTorch采用的是动态图机制,也就是说,在训练模型时候,每迭代一次都会构建一个新的计算图。计算图代表的是程序中变量之间的相互关系,因此,我们可以将式(3.1),表示为如图3-4所示的计算图。
图3-4 对应上面代码示例的计算图
当对out变量执行backward操作后,系统将自动求取所有叶子变量对应的梯度,这里的叶子节点,就是我们的输入变量x:
但应该注意的是,PyTorch在设计时为了节省内存,没有保留中间节点的梯度值,因此,如果用户需要使用中间节点的梯度,或者自定义反向传播算法(比如Guided Backpropagation[260],GBP),就需要用到PyTorch的Hooks机制,包括register_hook和register_backward_hook。这个技巧在卷积神经网络可视化中经常使用[31,308]。Hooks机制是PyTorch的高级技巧,鉴于本书的写作目的和篇幅,我们不在此详述,读者可以查阅相关的资料[153]。
通过对式(3.1)进行求导,得到out变量关于x的导数结果如下: