PyTorch 的 Dataloader (Single Process)
Published:
Dataloader 是 PyTorch 提供的根据需求载入数据的接口,但许多地方看文档会觉得不清不楚,阅读 PyTorch 源码,有助于我们了解框架背后具体做了什么,更灵活地写符合需求的代码。
前言
第一次看学长的 PyTorch 训练 pipeline 时看到 dataloader 部分可以说是一脸懵逼,无数个问号划过我的脑海,主要是关于 collate_fn 这个传给 Dataloader 的参数:啥是 collate_fn? 为啥要传这个参数? 它接收的参数是啥? Dataloader 又是怎么工作的,怎么给主程序 load 数据的? 不过当时这些疑问过也就过去了,代码跑的没问题就不管它。后来终于遇到了自己希望不同类型的数据进行不同的 collate 方式需求,于是需要看 Dataloader 传给 collate_fn 的究竟是啥,也就干脆把 Dataloader 的源码都看了一遍,感觉还不错,记录下来,以备日后查阅。本篇先介绍一下单进程 (single process) 下 Dataloader 的工作流程。
Dataset
分析 Dataloader 当然要从 Dataset 开始谈起,因为 Dataset 是数据的来源。Dataset 非常简单,只是个需要实现 __getitem__ 和 __len__ 的类:
class Dataset(object):
r"""
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
文档中写得很清楚:
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite
__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__.
这就是一个需要自己实现的 map-style 的根据 key 取 data 的类,__len__ 可以实现也可以不实现。这里假设我们实验中的 Dataset 是这么定义的:
class dataset(torch.utils.data.Dataset):
def __init__(self, h5file, label_df, colname=("h5index", "label")):
self._h5file = h5file
self._label_df = label_df
self._colname = colname
def __getitem__(self, index):
if self._dataset is None:
self._dataset = h5py.File(self._h5file, "r")
h5index, label = self._label_df.iloc[index].reindex(self._colname).values
data = self._dataset[h5index][()]
return torch.as_tensor(data), torch.as_tensor(label)
def __len__(self):
return len(self._label_df)
h5file 是存特征的 hdf5 文件,label_df 是处理过的记录标签的 pandas.DataFrame,colname 是 label_df中记录 hdf5 文件的索引和标签的那两列列名。
以这个实现为例,我们看看 Dataloader 读取数据经过了怎样的一个过程。
Dataloader 初始化
Dataloader 的参数很多,有些是互斥的,这里还是以最简单的情况为例分析数据读取流程:
ds = dataset(h5file, label_df)
dl = torch.utils.data.Dataloader(ds)
那么其他参数都是 Dataloader 的默认参数了,需要注意的是 num_workers 此时默认为 0,也就是单进程读取。
class DataLoader(object):
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
torch._C._log_api_usage_once("python.data_loader")
if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
'use num_workers=0 to disable multiprocessing.')
if timeout < 0:
raise ValueError('timeout option should be non-negative')
self.dataset = dataset
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
首先检查参数并将初始化参数设置一下。
if isinstance(dataset, IterableDataset):
# ...
pass
else:
self._dataset_kind = _DatasetKind.Map
这里根据 dataset 的种类来检查其他的参数。_DatasetKind中定义了Map和Iterable两种dataset,很好理解:Map就是我们前面定义的那样,按 map 风格用 key 取 value 的 dataset;Iterable就是只能迭代取值无法随机读的 dataset,适合 stream data。所以这里_dataset_kind就设置成了Map,Iterable的分支代码在此就省略,让思路更清晰。
class _DatasetKind(object):
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
# ...
pass
接下来要确定sampler或者batch_sampler,这是用来生成要读取数据的 index 的采样器,既可以传入自己实现的按自己需求取数据的sampler,也可以传入直接返回 index batch 的batch_sampler。
if sampler is not None and shuffle:
# ...
pass
if batch_sampler is not None:
# ...
pass
elif batch_size is None:
# ...
pass
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
默认的 sample 和 batch_sampler 是 None,batch_size是 1,shuffle 是 False 所以 sampler 设置成了 SequentialSampler,这个类的详细实现见源码,非常简单,就是一个顺序生成 index 的 Iterable;如果 shuffle 是 True,sampler 就是 RandomSampler,也是一个很简单的实现,只是将全体 index 先打乱,再顺序生成。
接着 batch_sampler 就是默认给定的 BatchSampler,它的实现也很简单:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
迭代过程就是不断往 batch 中添加 sampler 生成的 index 直到达到 batch_size,返回这个 index 组成的 batch。
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
self.__initialized = True
最后是初始化 collate_fn,这个参数表示将根据 index 取出的 samples 整合成需要的输入的函数,比如对于变长输入就需要 pad 到长度最长的 sample。这里是根据 _auto_collation 这个参数确定的。
@property
def _auto_collation(self):
return self.batch_sampler is not None
那么很明显,batch_sampler 不是 None,于是 collate_fn 就是 default_collate。
数据读取
循环读取数据就是对 iter() 得到的 Iterable 结果不断 next 取值,于是:
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
由于是单线程,返回的 Iterable 就是 _SingleProcessDataLoaderIter。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
next = __next__ # Python 2 compatibility
我们来看 __next__ 的过程:总体思路很简单,先生成 index,然后根据 index 来 fectch data,如果有 pin_memory 的要求再进行相应的操作,这里我们设定数据只读到内存的情况,不需要 pin_memory,所以就直接返回 data。
生成 index
生成 index 是通过 _next_index 完成的, 定义在基类 _BaseDataLoaderIter 中:
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
def __iter__(self):
return self
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
可以看到,指向的是 _sampler_iter,这又是 iter(self._index_sampler),即 loader._index_sampler,定义在 Dataloader 的参数中:
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
前面已分析过,_auto_collation 是 True,所以 index_sampler 就是 self.batch_sampler,于是生成的 index 就是前面分析过的迭代过程:不断往batch中添加sampler生成的 index 直到达到batch_size,返回这个 index 组成的 batch。
fetch data
fetch data 的操作就是这一句:
data = self._dataset_fetcher.fetch(index)
那么 _dataset_fetcher 是怎么定义的呢?
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
是 _DatasetKind 中定义的:
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
显然这里用的是 _MapDatasetFetcher:
class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
所以这里取数据的过程就很清楚了:auto_collation 就是 _SingleProcessDataLoaderIter 的 _auto_collation,值为 True,所以 fetch 函数先将 index batch 中每个 index 对应 dataset 中的数据取出来放到一个 list 中,然后返回经过 collate_fn 整合处理之后的结果。
这里的 collate_fn 是 default_collate,就不分析代码了,过程也很简单,就是将各个单独的数据整合成一个总的 Tensor,多出第一维即 batch_size。
总结一下,Dataset 是 Map 风格的或者 Iterable 的生成数据的数据源;Dataloader 接受各种参数,迭代过程中不断生成数据 batch;每次迭代过程 (Single Process) 是:由 batch_sampler 迭代取出下一批数据的 index -> fetch 函数将数据 batch 放到一个 list 中 -> 将这个 list 送给 collate_fn (可以自己定义) 得到整合之后的 data batch,一般第一维是 batch_size。
