0%

[DL] PyTorch 折桂 5:PyTorch 模块总览 & torch.utils.data

1. PyTorch 模块总览

前面用了四篇文章详细讲解了 tensor 的性质,本篇开始进入功能的介绍。相比 TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如 torchtext,以保持主体的轻便。

纵观 PyTorch 的 API,其核心大概如下:

  1. torch.nn & torch.nn.functional:构建神经网络
  2. torch.nn.init:初始化权重
  3. torch.optim:优化器
  4. torch.utils.data:载入数据

可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围事物则有如下的模块支持:

  1. torch.cuda:管理 GPU 资源
  2. torch.distributed:分布式训练
  3. torch.jit:构建静态图提升性能
  4. torch.tensorboard:神经网络的可视化

如果额外掌握了上面的四个的模块,PyTorch 就只剩下一些边边角角的特殊需求了。

下面我们来了解第一个功能包:torch.utils.data。这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。

2. torch.utils.data 综述

PyTorch 数据读取的核心是 torch.utils.data.DataLoader 类。它是一个 数据迭代读取器,支持

  • 映射方式和迭代方式读取数据;
  • 自定义数据读取顺序;
  • 自动批;
  • 单线程或多线程数据读取;
  • 自动内存定位。

所有上述功能都可以在 torch.utils.data.DataLoader 的变量中定义:

1
2
3
4
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)

最重要的变量为 dataset,它指明了数据的来源。DataLoader 支持两种数据类型:

  • 映射风格的数据封装(map-style datasets):
    这种数据结构拥有自定义的 __getitem__()__len__() 属性,可以以“索引/值”的方式读取数据,对应 torch.utils.data.Dataset 类;
  • 迭代风格的数据封装(iterable-style datasets):
    这种数据结构拥有自定义的 __iter__() 属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应 torch.utils.data.IterableDataset 类。

下面我们从顶层的 torch.utils.data.DataLoader 开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:

1
2
>>> samples = torch.arange(100)
>>> labels = torch.cat([torch.zeros(50), torch.ones(50)], dim=0)

3. torch.utils.data.DataLoader 数据加载器

我们看一下常用的变量:

  • dataset:数据源;
  • batch_size:一个整数,定义每一批读取的元素个数;
  • shuffle:一个布尔值,定义是否随机读取;
  • sampler:定义获取数据的策略,必须与 shuffle 互斥;
  • num_workers:一个整数,读取数据使用的线程数;
  • collate_fn:一个将读取的数据处理、聚合成一个一个 batch 的自定义函数;
  • drop_last:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否保留这个 batch。

datasetsamplercollate_fn 是自定义的类或功能,我们从后往前看。

4. 数据集的分割

在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。

  • torch.utils.data.Subset(dataset, indices)

这个函数可以根据索引将数据集分割。

1
2
3
4
>>> even = [i for i in range(100) if i % 2 == 0]
>>> new1 = torch.utils.data.Subset(samples, even)
>>> print(new1[:5])
tensor([0, 2, 4, 6, 8])
  • torch.utils.data.random_split(dataset, lengths)

先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。

1
2
3
>>> train, test = torch.utils.data.random_split(samples, [90, 10])
>>> print(torch.tensor(test))
tensor([79, 60, 98, 74, 31, 43, 21, 69, 55, 76])

5. collate_fn 核对函数

这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理、打包。比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数将它们的长度归一化。比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
>>> a = [[1,2,3],[4,5],[6,7,8,9]]
>>> def collate_fn(data):
... '''
... padding data, so they have same length.
... '''
... max_len = max([len(feature) for feature in data])
... new = torch.zeros(len(data), max_len)

... for i in range(len(data)):
... tmp = torch.as_tensor(data[i])
... j = len(tmp)
... new[i][:j] = tmp

... return new

>>> collate_fn(a)
tensor([[1., 2., 3., 0.],
[4., 5., 0., 0.],
[6., 7., 8., 9.]])

将这个函数赋值给 collate_fn,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。

6. sampler 采样器

sampler 变量决定了数据读取的顺序。注意,sampler 只对 iterable-style datasets 有效。除了可以自定义采样器,Python 内置了几种不同的采样器:

  • torch.utils.data.SequentialSampler(data_source)

默认的采样器。

  • torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)

随机选择数据。可以指定一次读取 num_samples 个数据。replacementTrue 的话可以指定 num_samples(我并不理解为什么)。

1
2
3
>>> batch = torch.utils.data.RandomSampler(samples, replacement=True, num_samples=5) # 生成一个迭代器
>>> print(list(batch))
[85, 70, 5, 63, 79]

我个人的理解是这个采样器仅对一个 batch 内的数据进行 shuffle。

还有三个采样器无法独立使用,必须先实例化,然后放进 DataLoader

  • torch.utils.data.SubsetRandomSampler(indices):先按照索引选取数据,然后随机排列。
  • torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True):字面意思是按照概率选择不同类别的元素,不过暂时没有搞明白怎么用,先挖个坑。
  • torch.utils.data.BatchSampler(sampler, batch_size, drop_last):在一个 batch 中应用另外一个采样器。

    7. dataset 数据集生成器

  • torch.utils.data.IterableDataset

生成一个 iterable-style 的数据封装,可以实现多线程读取数据。不过官方文档是这么说,我暂时没有弄明白怎么用这个类。

  • torch.utils.data.Dataset

这个类需要覆写 __getitem____len__ 属性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> class MyData(torch.utils.data.Dataset):
... def __init__(self, data):
... super(MyData, self).__init__()
... self.data = data

... def __len__(self, data):
... return len(self.data)

... def __getitem__(self, index):
... return self.data[index]

>>> mydata = MyData(samples)
>>> mydata[0]
tensor(0)
>>> mydata[10:15]
tensor([10, 11, 12, 13, 14])

8. 总结

选择让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:

>>> train = MyData(samples)
>>> ds = torch.utils.data.DataLoader(train[:], batch_size=10, shuffle=True)

>>> for _ in range(5):
...     print(next(iter(ds)))
tensor([22, 44, 56, 38, 86, 47, 14, 63, 88, 64])
tensor([32, 38,  6, 64, 67, 91, 54,  3, 80, 22])
tensor([77, 98, 61,  7, 17, 97, 83, 50, 26, 42])
tensor([67, 13, 10, 83, 54, 11, 31, 78, 15, 36])
tensor([ 2, 55, 87, 39, 61, 92,  0, 79, 69, 84])

欢迎关注我的其它发布渠道