PyTorch 的 Dataloader (Single Process)

5 minute read

Published:

Dataloader 是 PyTorch 提供的根据需求载入数据的接口,但许多地方看文档会觉得不清不楚,阅读 PyTorch 源码,有助于我们了解框架背后具体做了什么,更灵活地写符合需求的代码。

前言

第一次看学长的 PyTorch 训练 pipeline 时看到 dataloader 部分可以说是一脸懵逼,无数个问号划过我的脑海,主要是关于 collate_fn 这个传给 Dataloader 的参数:啥是 collate_fn? 为啥要传这个参数? 它接收的参数是啥? Dataloader 又是怎么工作的,怎么给主程序 load 数据的? 不过当时这些疑问过也就过去了,代码跑的没问题就不管它。后来终于遇到了自己希望不同类型的数据进行不同的 collate 方式需求,于是需要看 Dataloader 传给 collate_fn 的究竟是啥,也就干脆把 Dataloader 的源码都看了一遍,感觉还不错,记录下来,以备日后查阅。本篇先介绍一下单进程 (single process) 下 Dataloader 的工作流程。

Dataset

分析 Dataloader 当然要从 Dataset 开始谈起,因为 Dataset 是数据的来源。Dataset 非常简单,只是个需要实现 __getitem____len__ 的类:

class Dataset(object):
      r"""
      All datasets that represent a map from keys to data samples should subclass
      it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
      data sample for a given key. Subclasses could also optionally overwrite
      :meth:`__len__`, which is expected to return the size of the dataset by many
      :class:`~torch.utils.data.Sampler` implementations and the default options
      of :class:`~torch.utils.data.DataLoader`.
      """
   
      def __getitem__(self, index):
          raise NotImplementedError
            
      def __add__(self, other):
          return ConcatDataset([self, other])

文档中写得很清楚:

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__.

这就是一个需要自己实现的 map-style 的根据 key 取 data 的类,__len__ 可以实现也可以不实现。这里假设我们实验中的 Dataset 是这么定义的:

class dataset(torch.utils.data.Dataset):
    
    def __init__(self, h5file, label_df, colname=("h5index", "label")):
        self._h5file = h5file
        self._label_df = label_df
        self._colname = colname
        
    def __getitem__(self, index):
        if self._dataset is None:
            self._dataset = h5py.File(self._h5file, "r")
        h5index, label = self._label_df.iloc[index].reindex(self._colname).values
        data = self._dataset[h5index][()]
        return torch.as_tensor(data), torch.as_tensor(label)
    
    def __len__(self):
        return len(self._label_df)

h5file 是存特征的 hdf5 文件,label_df 是处理过的记录标签的 pandas.DataFramecolnamelabel_df中记录 hdf5 文件的索引和标签的那两列列名。

以这个实现为例,我们看看 Dataloader 读取数据经过了怎样的一个过程。

Dataloader 初始化

Dataloader 的参数很多,有些是互斥的,这里还是以最简单的情况为例分析数据读取流程:

ds = dataset(h5file, label_df)
dl = torch.utils.data.Dataloader(ds)

那么其他参数都是 Dataloader 的默认参数了,需要注意的是 num_workers 此时默认为 0,也就是单进程读取。

class DataLoader(object):
    
    __initialized = False
    
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

            self.dataset = dataset
            self.num_workers = num_workers
            self.pin_memory = pin_memory
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
            self.multiprocessing_context = multiprocessing_context

首先检查参数并将初始化参数设置一下。

if isinstance(dataset, IterableDataset):
    # ...
    pass
else:                                                                             
    self._dataset_kind = _DatasetKind.Map

这里根据 dataset 的种类来检查其他的参数。_DatasetKind中定义了MapIterable两种dataset,很好理解:Map就是我们前面定义的那样,按 map 风格用 key 取 value 的 dataset;Iterable就是只能迭代取值无法随机读的 dataset,适合 stream data。所以这里_dataset_kind就设置成了MapIterable的分支代码在此就省略,让思路更清晰。

class _DatasetKind(object):
    Map = 0    
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        # ...
        pass

接下来要确定sampler或者batch_sampler,这是用来生成要读取数据的 index 的采样器,既可以传入自己实现的按自己需求取数据的sampler,也可以传入直接返回 index batch 的batch_sampler

if sampler is not None and shuffle:
    # ...
    pass

if batch_sampler is not None:
    # ...
    pass
elif batch_size is None:
    # ...
    pass

if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:   
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler

默认的 samplebatch_samplerNonebatch_size是 1,shuffleFalse 所以 sampler 设置成了 SequentialSampler,这个类的详细实现见源码,非常简单,就是一个顺序生成 index 的 Iterable;如果 shuffleTruesampler 就是 RandomSampler,也是一个很简单的实现,只是将全体 index 先打乱,再顺序生成。

接着 batch_sampler 就是默认给定的 BatchSampler,它的实现也很简单:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

迭代过程就是不断往 batch 中添加 sampler 生成的 index 直到达到 batch_size,返回这个 index 组成的 batch。

if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
           
self.collate_fn = collate_fn
self.__initialized = True

最后是初始化 collate_fn,这个参数表示将根据 index 取出的 samples 整合成需要的输入的函数,比如对于变长输入就需要 pad 到长度最长的 sample。这里是根据 _auto_collation 这个参数确定的。

@property                                                                             
def _auto_collation(self):
    return self.batch_sampler is not None

那么很明显,batch_sampler 不是 None,于是 collate_fn 就是 default_collate

数据读取

循环读取数据就是对 iter() 得到的 Iterable 结果不断 next 取值,于是:

def __iter__(self): 
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)

由于是单线程,返回的 Iterable 就是 _SingleProcessDataLoaderIter

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
      def __init__(self, loader):
          super(_SingleProcessDataLoaderIter, self).__init__(loader)
          assert self._timeout == 0
          assert self._num_workers == 0
        
          self._dataset_fetcher = _DatasetKind.create_fetcher(
              self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
       
      def __next__(self):
          index = self._next_index()  # may raise StopIteration
          data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
          if self._pin_memory:
              data = _utils.pin_memory.pin_memory(data)
          return data
       
      next = __next__  # Python 2 compatibility

我们来看 __next__ 的过程:总体思路很简单,先生成 index,然后根据 index 来 fectch data,如果有 pin_memory 的要求再进行相应的操作,这里我们设定数据只读到内存的情况,不需要 pin_memory,所以就直接返回 data。

生成 index

生成 index 是通过 _next_index 完成的, 定义在基类 _BaseDataLoaderIter 中:

class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

可以看到,指向的是 _sampler_iter,这又是 iter(self._index_sampler),即 loader._index_sampler,定义在 Dataloader 的参数中:

@property
def _index_sampler(self):
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

前面已分析过,_auto_collationTrue,所以 index_sampler 就是 self.batch_sampler,于是生成的 index 就是前面分析过的迭代过程:不断往batch中添加sampler生成的 index 直到达到batch_size,返回这个 index 组成的 batch。

fetch data

fetch data 的操作就是这一句:

data = self._dataset_fetcher.fetch(index)

那么 _dataset_fetcher 是怎么定义的呢?

self._dataset_fetcher = _DatasetKind.create_fetcher(
    self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

_DatasetKind 中定义的:

@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):  
    if kind == _DatasetKind.Map:
        return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
    else:  
        return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

显然这里用的是 _MapDatasetFetcher

class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

所以这里取数据的过程就很清楚了:auto_collation 就是 _SingleProcessDataLoaderIter_auto_collation,值为 True,所以 fetch 函数先将 index batch 中每个 index 对应 dataset 中的数据取出来放到一个 list 中,然后返回经过 collate_fn 整合处理之后的结果。

这里的 collate_fndefault_collate,就不分析代码了,过程也很简单,就是将各个单独的数据整合成一个总的 Tensor,多出第一维即 batch_size。

总结一下,DatasetMap 风格的或者 Iterable 的生成数据的数据源;Dataloader 接受各种参数,迭代过程中不断生成数据 batch;每次迭代过程 (Single Process) 是:由 batch_sampler 迭代取出下一批数据的 index -> fetch 函数将数据 batch 放到一个 list 中 -> 将这个 list 送给 collate_fn (可以自己定义) 得到整合之后的 data batch,一般第一维是 batch_size。