0%

[DL] PyTorch 折桂 7:torch.nn 总览 & nn.Linear & 常用激活函数

1 torch.nn 总览

PyTorch 把与深度学习模型搭建相关的全部类全部在 torch.nn 这个子模块中。根据类的功能分类,常用的有如下十几个部分:

  • Containers:容器类,如 torch.nn.Module
  • Convolution Layers:卷积层,如 torch.nn.Conv2d
  • Pooling Layers:池化层,如 torch.nn.MaxPool2d
  • Non-linear activations:非线性激活层,如 torch.nn.ReLU
  • Normalization layers:归一化层,如 torch.nn.BatchNorm2d
  • Recurrent layers:循环神经层,如 torch.nn.LSTM
  • Transformer layers:transformer 层,如 torch.nn.TransformerEncoder
  • Linear layers:线性连接层,如 torch.nn.Linear
  • Dropout layers:dropout 层,如 torch.nn.Dropout
  • Sparse layers:稀疏层,如 torch.nn.Embedding
  • Vision layers:vision 层,如 torch.nn.Upsample
  • DataParallel layers:平行计算层,如 torch.nn.DataParallel
  • Utilities:其它功能,如 torch.nn.utils.clip_grad_value_ 而在 torch.nn 下面还有一个子模块 torch.nn.functional,基本上是 torch.nn 里对应类的函数,比如 torch.nn.ReLU 的对应函数是 torch.nn.functional.relu。为什么要这么做呢?

    你可能会疑惑为什么需要这两个功能如此相近的模块,其实这么设计是有其原因的。如果我们只保留 nn.functional 下的函数的话,在训练或者使用时,我们就要手动去维护 weight,bias,stride 这些中间量的值,这显然是给用户带来了不便。而如果我们只保留 nn 下的类的话,其实就牺牲了一部分灵活性,因为做一些简单的计算都需要创造一个类,这也与 PyTorch 的风格不符。(知乎回答

torch.nn 可以被 nn.Module 识别,并成为网络组成的一部分;torch.nn.functional 则不行。比较以下两个模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
>>> class Simple(nn.Module):
... def __init__(self):
... super(Simple, self).__init__()
... self.fc = nn.Linear(10, 1)
... self.dropout = nn.Dropout(0.5) # 使用 nn.Dropout 类

... def forward(self, x):
... x = self.fc(x)
... x = self.dropout(x)
... return x
>>> simple = Simple()
>>> print(simple)
Simple(
(fc): Linear(in_features=10, out_features=1, bias=True)
(dropout): Dropout(p=0.5, inplace=False) #可以被识别成一层
)

>>> class Simple2(nn.Module):
... def __init__(self):
... super(Simple2, self).__init__()
... self.fc = nn.Linear(10, 1)

... def forward(self, x):
... x = F.dropout(self.fc(x)) # 使用 nn.functional.dropout,不能被识别
... return x
>>> simple2 = Simple2()
>>> print(simple2)
Simple2(
(fc): Linear(in_features=10, out_features=1, bias=True)
)

什么时候调用 torch.nn,什么时候调用 torch.nn.functional 呢?个人的经验是:不需要存储权重的时候使用 torch.nn.functional,需要存储权重的时候使用 torch.nn

  • 层、dropout 使用 torch.nn
  • 激活函数使用 torch.nn.functional

这里要额外说一下 dropout 层。理论上 dropout 没有权重,可以使用 torch.nn.functional.dropout,然而 dropout 有traineval 模式,使用 torch.nn.Dropout 可以方便地对模式进行控制,而函数就不行。所以为了方便,推荐使用 torch.nn.Dropout

以后若没有特殊说明,均在引入模块时省略 torch 模块名称。

2. nn.Linear

线性连接层又叫做全连接层(fully connected layer),指的是通过矩阵乘法将前一层的矩阵变换为下一层的矩阵:
$$layer1*W+b=layer2$$
在这里插入图片描述
W 被称为全连接层的 weights,b 被称为全连接层的 bias。通常为了演示方便,我们忽略 bias。
layer1 如果是一个 $m*n$ 的矩阵,$W$ 是一个 $n*k$ 的矩阵,那么下一层 layer2 就是一个 $m*k$ 的矩阵。n 称为输入特征数(input size),k 称为输出特征数(output size),那么这个线性连接层可以被这样初始化:

1
fc = nn.Linear(input_size, output_size)

multilayer perception(多层感知机,MLP)就是通过若干个全连接层组合而成的。但是事实证明 MLP 的性能并不好,为什么呢?假设一个 MLP 由三个全连接层组成,三层分别为
$$x_3=x_2*W_2$$
$$x_2=x_1*W_1$$
我们把第二个式子中的 $x_2$ 代入第一个式子,可得:
$$X_3=(x_1*W_1)*W_2=x_1*(W_1*W_2)$$
可见若干层全连接层相连,最终可以化简为一个全连接层。为了解决这个问题,激活函数(activation function)出现了。

3. 激活函数

激活函数就是非线性连接层,通过非线性函数将一层变为另一层。常用的激活函数有 sigmoidtanhrelu 及其变种。虽然 torch.nn 有激活函数层,因为激活函数比较轻量级,使用 torch.nn.functional 里的函数功能就足够了。通常我们将 torch.nn.functional 写成 F

1
import torch.nn.functional as F
  • F.sigmoid
    在这里插入图片描述
    sigmoid 又叫做 logistic,通常写作 $\sigma$,公式为
    $$sigmoid(x)=\sigma(x)=\frac{1}{1+e^{-x}}$$
    sigmoid 的值域为 $(0,1)$,所以通常用于二分类问题:大于 $0.5$ 为一类,小于 $0.5$ 为另一类。sigmoid 的导数公式为
    $$\sigma’(x)=\sigma(x)(1-\sigma(x))$$
    导数的值域为 $(0,0.25)$。sigmoid 函数的特点为:
  1. 函数的值在 $(0,1)$ 之间,符合概率分布;
  2. 导数的值域为 $(0,0.25)$,容易造成梯度消失;
  3. 输出为非对称正值,破坏数据分布。
  • F.tanh
    在这里插入图片描述
    tanh 是正切函数,公式为
    $$tanh(x)=\frac{sin(x)}{cos(x)}=\frac{e^x+e^{-x}}{e^x+e^{-x}}$$
    tanh 的值域为 $(0,1)$,对称分布。它的导数公式为
    $$tanh’(x)=1-tanh^2(x)$$
    导数的值域为 $(0,1)$。tanh 的特点为:
  1. 函数值域为 $(0,1)$,对称分布;
  2. 导数值域为 $(0,1)$,容易造成梯度消失。
  • F.relu
    在这里插入图片描述
    为了解决上述两个激活函数容易产生梯度消失的问题,Rectified Linear Unit(relu) 横空出世了。它实际上是一个分段函数:
    $$relu(x)=
    \begin{cases}
    0,\ x<0\
    x,\ x>0
    \end{cases}$$
    relu 的优点在于求导非常方便,而且非常稳定:
    $$relu’(x)=
    \begin{cases}
    0,\ x<0\
    \text{unidentified},\ x=0\
    1,\ x>0
    \end{cases}$$
    缺点在于
  1. 当 $x<0$ 时导数为 0,神经元“死亡”,即不再更新;
  2. 虽然没有梯度消失的问题,但有梯度爆炸的问题。
  • F.leakyrelu
    在这里插入图片描述
    为了解决 relu 的问题,对其稍加改动成为了 leakyrelu
    $$relu(x)=
    \begin{cases}
    0,\ x<0\
    \alpha x,\ x>0
    \end{cases}$$
    $\alpha$ 是一个很小的数,通常是 0.01。这样它的导数就变成了
    $$relu(x)=
    \begin{cases}
    0,\ x<0\
    \alpha,\ x>0
    \end{cases}$$

欢迎关注我的其它发布渠道