Source code for accel_hydra.utils.data

from copy import deepcopy

import hydra


def _build_sampler(sampler_cfg, dataset):
    """Build a sampler (or batch sampler) for ``dataset`` from its config node."""
    # ``_target_`` is popped first so that only constructor arguments remain
    # in ``sampler_cfg`` when it is unpacked below. The caller passes a node
    # taken from a deep copy, so this mutation never reaches the user's config.
    sampler_cls = hydra.utils.get_class(sampler_cfg.pop("_target_"))
    return sampler_cls(data_source=dataset, **sampler_cfg)


def init_dataloader_from_config(config: dict):
    """
    A helper function to initialize a dataloader from a config.

    The ``dataset`` and ``_target_`` keys are required. ``sampler``,
    ``batch_sampler`` and ``collate_fn`` are optional; every remaining key is
    forwarded unchanged to the dataloader class as a keyword argument.

    Args:
        config: A dictionary or DictConfig containing the dataloader
            configuration. It is deep-copied first, so the caller's object
            is left untouched.

    Returns:
        instantiated dataloader object.

    Example:

    .. code-block:: python

        config = '''
            train_dataloader:
                _target_: torch.utils.data.DataLoader
                dataset:
                    _target_: data.train_dataset
                    data_root: /path/to/data
                # sampler:
                #     _target_: torch.utils.data.Sampler
                #     ...
                # batch_sampler:
                #     _target_: torch.utils.data.BatchSampler
                #     ...
        '''
        config = OmegaConf.create(config)
        train_dataloader = init_dataloader_from_config(config["train_dataloader"])
    """
    config = deepcopy(config)
    kwargs = {}

    # NOTE: The class and its instantiating arguments are taken manually here
    # to avoid the dataset being converted between an instance and an
    # OmegaConf.DictConfig, especially when there is randomness in the
    # initialization process of the dataset instance.
    dataset = hydra.utils.instantiate(config.pop("dataset"), _convert_="all")
    kwargs["dataset"] = dataset

    # A key that is present but set to null (e.g. the override `sampler=null`)
    # is treated like a missing key instead of crashing on `None.pop(...)`.
    sampler_cfg = config.pop("sampler", None)
    if sampler_cfg is not None:
        # `batch_sampler` is deliberately left in `config` in this branch, so
        # it is still forwarded to the dataloader class exactly as before.
        kwargs["sampler"] = _build_sampler(sampler_cfg, dataset)
    else:
        batch_sampler_cfg = config.pop("batch_sampler", None)
        if batch_sampler_cfg is not None:
            kwargs["batch_sampler"] = _build_sampler(batch_sampler_cfg, dataset)

    # `collate_fn` is always passed explicitly; it is None when not configured.
    if "collate_fn" in config:
        kwargs["collate_fn"] = hydra.utils.instantiate(
            config.pop("collate_fn"), _convert_="all"
        )
    else:
        kwargs["collate_fn"] = None

    # Whatever is left in `config` after removing `_target_` goes straight to
    # the dataloader constructor.
    dataloader_cls = hydra.utils.get_class(config.pop("_target_"))
    return dataloader_cls(**kwargs, **config)