1. PyTorch 模块总览
前面用了四篇文章详细讲解了 tensor 的性质,本篇开始进入功能的介绍。相比 TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如 torchtext
,以保持主体的轻便。
纵观 PyTorch 的 API,其核心大概如下:
torch.nn
&torch.nn.functional
:构建神经网络torch.nn.init
:初始化权重torch.optim
:优化器torch.utils.data
:载入数据
可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围事物则有如下的模块支持:
torch.cuda
:管理 GPU 资源torch.distributed
:分布式训练torch.jit
:构建静态图提升性能torch.tensorboard
:神经网络的可视化
如果额外掌握了上面的四个的模块,PyTorch 就只剩下一些边边角角的特殊需求了。
下面我们来了解第一个功能包:torch.utils.data
。这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。
2. torch.utils.data
综述
PyTorch 数据读取的核心是 torch.utils.data.DataLoader
类。它是一个 数据迭代读取器,支持
- 映射方式和迭代方式读取数据;
- 自定义数据读取顺序;
- 自动批;
- 单线程或多线程数据读取;
- 自动内存定位。
所有上述功能都可以在 torch.utils.data.DataLoader
的变量中定义:
1 | DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, |
最重要的变量为 dataset
,它指明了数据的来源。DataLoader
支持两种数据类型:
- 映射风格的数据封装(map-style datasets):
这种数据结构拥有自定义的__getitem__()
和__len__()
属性,可以以“索引/值”的方式读取数据,对应torch.utils.data.Dataset
类; - 迭代风格的数据封装(iterable-style datasets):
这种数据结构拥有自定义的__iter__()
属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应torch.utils.data.IterableDataset
类。
下面我们从顶层的 torch.utils.data.DataLoader
开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:
1 | 100) samples = torch.arange( |
3. torch.utils.data.DataLoader
数据加载器
我们看一下常用的变量:
dataset
:数据源;batch_size
:一个整数,定义每一批读取的元素个数;shuffle
:一个布尔值,定义是否随机读取;sampler
:定义获取数据的策略,必须与shuffle
互斥;num_workers
:一个整数,读取数据使用的线程数;collate_fn
:一个将读取的数据处理、聚合成一个一个 batch 的自定义函数;drop_last
:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否保留这个 batch。
dataset
, sampler
和 collate_fn
是自定义的类或功能,我们从后往前看。
4. 数据集的分割
在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。
torch.utils.data.Subset(dataset, indices)
这个函数可以根据索引将数据集分割。
1 | for i in range(100) if i % 2 == 0] even = [i |
torch.utils.data.random_split(dataset, lengths)
先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。
1 | 90, 10]) train, test = torch.utils.data.random_split(samples, [ |
5. collate_fn
核对函数
这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理、打包。比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数将它们的长度归一化。比如:
1 | 1,2,3],[4,5],[6,7,8,9]] a = [[ |
将这个函数赋值给 collate_fn
,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。
6. sampler
采样器
sampler
变量决定了数据读取的顺序。注意,sampler
只对 iterable-style datasets 有效。除了可以自定义采样器,Python 内置了几种不同的采样器:
torch.utils.data.SequentialSampler(data_source)
默认的采样器。
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
随机选择数据。可以指定一次读取 num_samples
个数据。replacement
为 True
的话可以指定 num_samples
(我并不理解为什么)。
1 | True, num_samples=5) # 生成一个迭代器 batch = torch.utils.data.RandomSampler(samples, replacement= |
我个人的理解是这个采样器仅对一个 batch 内的数据进行 shuffle。
还有三个采样器无法独立使用,必须先实例化,然后放进 DataLoader
:
torch.utils.data.SubsetRandomSampler(indices)
:先按照索引选取数据,然后随机排列。torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
:字面意思是按照概率选择不同类别的元素,不过暂时没有搞明白怎么用,先挖个坑。torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
:在一个 batch 中应用另外一个采样器。7.
dataset
数据集生成器torch.utils.data.IterableDataset
生成一个 iterable-style 的数据封装,可以实现多线程读取数据。不过官方文档是这么说,我暂时没有弄明白怎么用这个类。
torch.utils.data.Dataset
这个类需要覆写 __getitem__
和 __len__
属性。
1 | class MyData(torch.utils.data.Dataset): |
8. 总结
选择让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:
10, shuffle=True)
for _ in range(5):
print(next(iter(ds)))
tensor([22, 44, 56, 38, 86, 47, 14, 63, 88, 64])
tensor([32, 38, 6, 64, 67, 91, 54, 3, 80, 22])
tensor([77, 98, 61, 7, 17, 97, 83, 50, 26, 42])
tensor([67, 13, 10, 83, 54, 11, 31, 78, 15, 36])
tensor([ 2, 55, 87, 39, 61, 92, 0, 79, 69, 84])
train = MyData(samples)
ds = torch.utils.data.DataLoader(train[:], batch_size=