Dataloader 是 PyTorch 提供的根据需求载入数据的接口,但许多地方看文档会觉得不清不楚,阅读 PyTorch 源码,有助于我们了解框架背后具体做了什么,更灵活地写符合需求的代码。
第一次看学长的 PyTorch 训练 pipeline 时看到 dataloader 部分可以说是一脸懵逼,无数个问号划过我的脑海,主要是关于 collate_fn
这个传给 Dataloader 的参数:啥是 collate_fn? 为啥要传这个参数? 它接收的参数是啥? Dataloader 又是怎么工作的,怎么给主程序 load 数据的? 不过当时这些疑问过也就过去了,代码跑的没问题就不管它。后来终于遇到了自己希望不同类型的数据进行不同的 collate 方式需求,于是需要看 Dataloader 传给 collate_fn 的究竟是啥,也就干脆把 Dataloader 的源码都看了一遍,感觉还不错,记录下来,以备日后查阅。本篇先介绍一下单进程 (single process) 下 Dataloader 的工作流程。
Dataset
分析 Dataloader 当然要从 Dataset 开始谈起,因为 Dataset 是数据的来源。Dataset 非常简单,只是个需要实现 __getitem__
和 __len__
的类:
1 | class Dataset(object): |
文档中写得很清楚:
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite
__getitem__
, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__
.
这就是一个需要自己实现的 map-style 的根据 key 取 data 的类,__len__
可以实现也可以不实现。这里假设我们实验中的 Dataset 是这么定义的:
1 | class dataset(torch.utils.data.Dataset): |
h5file
是存特征的 hdf5 文件,label_df
是处理过的记录标签的 pandas.DataFrame
,colname
是 label_df
中记录 hdf5 文件的索引和标签的那两列列名。
以这个实现为例,我们看看 Dataloader
读取数据经过了怎样的一个过程。
Dataloader 初始化
Dataloader 的参数很多,有些是互斥的,这里还是以最简单的情况为例分析数据读取流程:
1 | ds = dataset(h5file, label_df) |
那么其他参数都是 Dataloader 的默认参数了,需要注意的是 num_workers 此时默认为 0,也就是单进程读取。
1 | class DataLoader(object): |
首先检查参数并将初始化参数设置一下。
1 | if isinstance(dataset, IterableDataset): |
这里根据 dataset 的种类来检查其他的参数。_DatasetKind
中定义了Map
和Iterable
两种dataset,很好理解:Map
就是我们前面定义的那样,按 map 风格用 key 取 value 的 dataset;Iterable
就是只能迭代取值无法随机读的 dataset,适合 stream data。所以这里_dataset_kind
就设置成了Map
,Iterable
的分支代码在此就省略,让思路更清晰。
1 | class _DatasetKind(object): |
接下来要确定sampler
或者batch_sampler
,这是用来生成要读取数据的 index 的采样器,既可以传入自己实现的按自己需求取数据的sampler
,也可以传入直接返回 index batch 的batch_sampler
。
1 | if sampler is not None and shuffle: |
默认的 sample
和 batch_sampler
是 None
,batch_size
是 1,shuffle
是 False
所以 sampler
设置成了 SequentialSampler
,这个类的详细实现见源码,非常简单,就是一个顺序生成 index 的 Iterable
;如果 shuffle
是 True
,sampler
就是 RandomSampler
,也是一个很简单的实现,只是将全体 index 先打乱,再顺序生成。
接着 batch_sampler
就是默认给定的 BatchSampler
,它的实现也很简单:
1 | def __iter__(self): |
迭代过程就是不断往 batch
中添加 sampler
生成的 index 直到达到 batch_size
,返回这个 index 组成的 batch。
1 | if collate_fn is None: |
最后是初始化 collate_fn
,这个参数表示将根据 index 取出的 samples 整合成需要的输入的函数,比如对于变长输入就需要 pad 到长度最长的 sample。这里是根据 _auto_collation
这个参数确定的。
1 |
|
那么很明显,batch_sampler
不是 None
,于是 collate_fn
就是 default_collate
。
数据读取
循环读取数据就是对 iter() 得到的 Iterable
结果不断 next 取值,于是:
1 | def __iter__(self): |
由于是单线程,返回的 Iterable
就是 _SingleProcessDataLoaderIter
。
1 | class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): |
我们来看 __next__
的过程:总体思路很简单,先生成 index,然后根据 index 来 fectch data,如果有 pin_memory 的要求再进行相应的操作,这里我们设定数据只读到内存的情况,不需要 pin_memory,所以就直接返回 data。
生成 index
生成 index 是通过 _next_index
完成的, 定义在基类 _BaseDataLoaderIter
中:
1 | class _BaseDataLoaderIter(object): |
可以看到,指向的是 _sampler_iter
,这又是 iter(self._index_sampler)
,即 loader._index_sampler
,定义在 Dataloader
的参数中:
1 |
|
前面已分析过,_auto_collation
是 True
,所以 index_sampler
就是 self.batch_sampler
,于是生成的 index 就是前面分析过的迭代过程:不断往batch
中添加sampler
生成的 index 直到达到batch_size
,返回这个 index 组成的 batch。
fetch data
fetch data 的操作就是这一句:
1 | data = self._dataset_fetcher.fetch(index) |
那么 _dataset_fetcher
是怎么定义的呢?
1 | self._dataset_fetcher = _DatasetKind.create_fetcher( |
是 _DatasetKind
中定义的:
1 |
|
显然这里用的是 _MapDatasetFetcher
:
1 | class _BaseDatasetFetcher(object): |
所以这里取数据的过程就很清楚了:auto_collation
就是 _SingleProcessDataLoaderIter
的 _auto_collation
,值为 True
,所以 fetch
函数先将 index batch 中每个 index 对应 dataset 中的数据取出来放到一个 list 中,然后返回经过 collate_fn
整合处理之后的结果。
这里的 collate_fn
是 default_collate
,就不分析代码了,过程也很简单,就是将各个单独的数据整合成一个总的 Tensor
,多出第一维即 batch_size。
总结一下,Dataset
是 Map
风格的或者 Iterable
的生成数据的数据源;Dataloader
接受各种参数,迭代过程中不断生成数据 batch;每次迭代过程 (Single Process) 是:由 batch_sampler
迭代取出下一批数据的 index -> fetch
函数将数据 batch 放到一个 list 中 -> 将这个 list 送给 collate_fn
(可以自己定义) 得到整合之后的 data batch,一般第一维是 batch_size。