|
@@ -95,12 +95,15 @@ def _process_dataset_params(cfg, dataset_params, train):
|
|
|
|
|
|
|
|
|
|
def _process_dataloader_params(cfg, dataloader_params, dataset, train):
|
|
def _process_dataloader_params(cfg, dataloader_params, dataset, train):
|
|
- default_dataloader_params = (
|
|
|
|
- cfg.dataset_params.train_dataloader_params if train else cfg.dataset_params.val_dataloader_params
|
|
|
|
- )
|
|
|
|
|
|
+ default_dataloader_params = cfg.dataset_params.train_dataloader_params if train else cfg.dataset_params.val_dataloader_params
|
|
default_dataloader_params = hydra.utils.instantiate(default_dataloader_params)
|
|
default_dataloader_params = hydra.utils.instantiate(default_dataloader_params)
|
|
- is_dist = super_gradients.is_distributed()
|
|
|
|
|
|
+ dataloader_params = _process_sampler_params(dataloader_params, dataset, default_dataloader_params)
|
|
|
|
+
|
|
|
|
+ return dataloader_params
|
|
|
|
+
|
|
|
|
|
|
|
|
+def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
|
|
|
|
+ is_dist = super_gradients.is_distributed()
|
|
if get_param(dataloader_params, "sampler") is not None:
|
|
if get_param(dataloader_params, "sampler") is not None:
|
|
dataloader_params = _instantiate_sampler(dataset, dataloader_params)
|
|
dataloader_params = _instantiate_sampler(dataset, dataloader_params)
|
|
elif get_param(default_dataloader_params, "sampler") is not None:
|
|
elif get_param(default_dataloader_params, "sampler") is not None:
|
|
@@ -108,13 +111,11 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
|
|
elif is_dist:
|
|
elif is_dist:
|
|
default_dataloader_params["sampler"] = {"DistributedSampler": {}}
|
|
default_dataloader_params["sampler"] = {"DistributedSampler": {}}
|
|
default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
|
|
default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
|
|
-
|
|
|
|
dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
|
|
dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
|
|
if get_param(dataloader_params, "batch_sampler"):
|
|
if get_param(dataloader_params, "batch_sampler"):
|
|
sampler = dataloader_params.pop("sampler")
|
|
sampler = dataloader_params.pop("sampler")
|
|
batch_size = dataloader_params.pop("batch_size")
|
|
batch_size = dataloader_params.pop("batch_size")
|
|
dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
|
|
dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
|
|
-
|
|
|
|
return dataloader_params
|
|
return dataloader_params
|
|
|
|
|
|
|
|
|
|
@@ -641,18 +642,26 @@ ALL_DATALOADERS = {
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-def get(name: str, dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
|
|
|
|
|
|
+def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader:
|
|
"""
|
|
"""
|
|
Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
|
|
Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
|
|
|
|
|
|
:param name: dataset name in ALL_DATALOADERS.
|
|
:param name: dataset name in ALL_DATALOADERS.
|
|
:param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
|
|
:param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
|
|
:param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
|
|
:param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
|
|
|
|
+ :param dataset: torch.utils.data.Dataset to be used instead of passing "name" (i.e for external dataset objects).
|
|
:return: initialized DataLoader.
|
|
:return: initialized DataLoader.
|
|
"""
|
|
"""
|
|
|
|
|
|
- if name not in ALL_DATALOADERS.keys():
|
|
|
|
|
|
+ if dataset is not None:
|
|
|
|
+ if name or dataset_params:
|
|
|
|
+ raise ValueError("'name' and 'dataset_params' cannot be passed with initialized dataset.")
|
|
|
|
+ dataloader_params = _process_sampler_params(dataloader_params, dataset, {})
|
|
|
|
+ dataloader = DataLoader(dataset=dataset, **dataloader_params)
|
|
|
|
+ elif name not in ALL_DATALOADERS.keys():
|
|
raise ValueError("Unsupported dataloader: " + str(name))
|
|
raise ValueError("Unsupported dataloader: " + str(name))
|
|
|
|
+ else:
|
|
|
|
+ dataloader_cls = ALL_DATALOADERS[name]
|
|
|
|
+ dataloader = dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)
|
|
|
|
|
|
- dataloader_cls = ALL_DATALOADERS[name]
|
|
|
|
- return dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)
|
|
|
|
|
|
+ return dataloader
|