PyTorch 的 Dataloader (Single Process)

Dataloader 是 PyTorch 提供的根据需求载入数据的接口,但许多地方看文档会觉得不清不楚,阅读 PyTorch 源码,有助于我们了解框架背后具体做了什么,更灵活地写符合需求的代码。

第一次看学长的 PyTorch 训练 pipeline 时看到 dataloader 部分可以说是一脸懵逼,无数个问号划过我的脑海,主要是关于 collate_fn 这个传给 Dataloader 的参数:啥是 collate_fn? 为啥要传这个参数? 它接收的参数是啥? Dataloader 又是怎么工作的,怎么给主程序 load 数据的? 不过当时这些疑问过也就过去了,代码跑的没问题就不管它。后来终于遇到了自己希望不同类型的数据进行不同的 collate 方式需求,于是需要看 Dataloader 传给 collate_fn 的究竟是啥,也就干脆把 Dataloader 的源码都看了一遍,感觉还不错,记录下来,以备日后查阅。本篇先介绍一下单进程 (single process) 下 Dataloader 的工作流程。

Dataset

分析 Dataloader 当然要从 Dataset 开始谈起,因为 Dataset 是数据的来源。Dataset 非常简单,只是个需要实现 __getitem____len__ 的类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Dataset(object):
r"""
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
"""

def __getitem__(self, index):
raise NotImplementedError

def __add__(self, other):
return ConcatDataset([self, other])

文档中写得很清楚:

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__.

这就是一个需要自己实现的 map-style 的根据 key 取 data 的类,__len__ 可以实现也可以不实现。这里假设我们实验中的 Dataset 是这么定义的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class dataset(torch.utils.data.Dataset):

def __init__(self, h5file, label_df, colname=("h5index", "label")):
self._h5file = h5file
self._label_df = label_df
self._colname = colname

def __getitem__(self, index):
if self._dataset is None:
self._dataset = h5py.File(self._h5file, "r")
h5index, label = self._label_df.iloc[index].reindex(self._colname).values
data = self._dataset[h5index][()]
return torch.as_tensor(data), torch.as_tensor(label)

def __len__(self):
return len(self._label_df)

h5file 是存特征的 hdf5 文件,label_df 是处理过的记录标签的 pandas.DataFramecolnamelabel_df中记录 hdf5 文件的索引和标签的那两列列名。

以这个实现为例,我们看看 Dataloader 读取数据经过了怎样的一个过程。

Dataloader 初始化

Dataloader 的参数很多,有些是互斥的,这里还是以最简单的情况为例分析数据读取流程:

1
2
ds = dataset(h5file, label_df)
dl = torch.utils.data.Dataloader(ds)

那么其他参数都是 Dataloader 的默认参数了,需要注意的是 num_workers 此时默认为 0,也就是单进程读取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class DataLoader(object):

__initialized = False

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
torch._C._log_api_usage_once("python.data_loader")

if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
'use num_workers=0 to disable multiprocessing.')

if timeout < 0:
raise ValueError('timeout option should be non-negative')

self.dataset = dataset
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context

首先检查参数并将初始化参数设置一下。

1
2
3
4
5
if isinstance(dataset, IterableDataset):
# ...
pass
else:
self._dataset_kind = _DatasetKind.Map

这里根据 dataset 的种类来检查其他的参数。_DatasetKind中定义了MapIterable两种dataset,很好理解:Map就是我们前面定义的那样,按 map 风格用 key 取 value 的 dataset;Iterable就是只能迭代取值无法随机读的 dataset,适合 stream data。所以这里_dataset_kind就设置成了MapIterable的分支代码在此就省略,让思路更清晰。

1
2
3
4
5
6
7
8
class _DatasetKind(object):
Map = 0
Iterable = 1

@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
# ...
pass

接下来要确定sampler或者batch_sampler,这是用来生成要读取数据的 index 的采样器,既可以传入自己实现的按自己需求取数据的sampler,也可以传入直接返回 index batch 的batch_sampler

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
if sampler is not None and shuffle:
# ...
pass

if batch_sampler is not None:
# ...
pass
elif batch_size is None:
# ...
pass

if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler

默认的 samplebatch_samplerNonebatch_size是 1,shuffleFalse 所以 sampler 设置成了 SequentialSampler,这个类的详细实现见源码,非常简单,就是一个顺序生成 index 的 Iterable;如果 shuffleTruesampler 就是 RandomSampler,也是一个很简单的实现,只是将全体 index 先打乱,再顺序生成。

接着 batch_sampler 就是默认给定的 BatchSampler,它的实现也很简单:

1
2
3
4
5
6
7
8
9
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

迭代过程就是不断往 batch 中添加 sampler 生成的 index 直到达到 batch_size,返回这个 index 组成的 batch。

1
2
3
4
5
6
7
8
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert

self.collate_fn = collate_fn
self.__initialized = True

最后是初始化 collate_fn,这个参数表示将根据 index 取出的 samples 整合成需要的输入的函数,比如对于变长输入就需要 pad 到长度最长的 sample。这里是根据 _auto_collation 这个参数确定的。

1
2
3
@property                                                                             
def _auto_collation(self):
return self.batch_sampler is not None

那么很明显,batch_sampler 不是 None,于是 collate_fn 就是 default_collate

数据读取

循环读取数据就是对 iter() 得到的 Iterable 结果不断 next 取值,于是:

1
2
3
4
5
def __iter__(self): 
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)

由于是单线程,返回的 Iterable 就是 _SingleProcessDataLoaderIter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0

self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

def __next__(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data

next = __next__ # Python 2 compatibility

我们来看 __next__ 的过程:总体思路很简单,先生成 index,然后根据 index 来 fectch data,如果有 pin_memory 的要求再进行相应的操作,这里我们设定数据只读到内存的情况,不需要 pin_memory,所以就直接返回 data。

生成 index

生成 index 是通过 _next_index 完成的, 定义在基类 _BaseDataLoaderIter 中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()

def __iter__(self):
return self

def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration

可以看到,指向的是 _sampler_iter,这又是 iter(self._index_sampler),即 loader._index_sampler,定义在 Dataloader 的参数中:

1
2
3
4
5
6
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler

前面已分析过,_auto_collationTrue,所以 index_sampler 就是 self.batch_sampler,于是生成的 index 就是前面分析过的迭代过程:不断往batch中添加sampler生成的 index 直到达到batch_size,返回这个 index 组成的 batch。

fetch data

fetch data 的操作就是这一句:

1
data = self._dataset_fetcher.fetch(index)

那么 _dataset_fetcher 是怎么定义的呢?

1
2
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

_DatasetKind 中定义的:

1
2
3
4
5
6
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

显然这里用的是 _MapDatasetFetcher

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last

def fetch(self, possibly_batched_index):
raise NotImplementedError()

class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)

所以这里取数据的过程就很清楚了:auto_collation 就是 _SingleProcessDataLoaderIter_auto_collation,值为 True,所以 fetch 函数先将 index batch 中每个 index 对应 dataset 中的数据取出来放到一个 list 中,然后返回经过 collate_fn 整合处理之后的结果。

这里的 collate_fndefault_collate,就不分析代码了,过程也很简单,就是将各个单独的数据整合成一个总的 Tensor,多出第一维即 batch_size。

总结一下,DatasetMap 风格的或者 Iterable 的生成数据的数据源;Dataloader 接受各种参数,迭代过程中不断生成数据 batch;每次迭代过程 (Single Process) 是:由 batch_sampler 迭代取出下一批数据的 index -> fetch 函数将数据 batch 放到一个 list 中 -> 将这个 list 送给 collate_fn (可以自己定义) 得到整合之后的 data batch,一般第一维是 batch_size。