Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#450 Feature/sg 321 ddp sampler handling for external dataloaders

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-321_ddp_sampler_handling_for_external_dataloaders
@@ -95,12 +95,15 @@ def _process_dataset_params(cfg, dataset_params, train):
 
 
 def _process_dataloader_params(cfg, dataloader_params, dataset, train):
-    default_dataloader_params = (
-        cfg.dataset_params.train_dataloader_params if train else cfg.dataset_params.val_dataloader_params
-    )
+    default_dataloader_params = cfg.dataset_params.train_dataloader_params if train else cfg.dataset_params.val_dataloader_params
     default_dataloader_params = hydra.utils.instantiate(default_dataloader_params)
-    is_dist = super_gradients.is_distributed()
+    dataloader_params = _process_sampler_params(dataloader_params, dataset, default_dataloader_params)
+
+    return dataloader_params
+
 
+def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
+    is_dist = super_gradients.is_distributed()
     if get_param(dataloader_params, "sampler") is not None:
         dataloader_params = _instantiate_sampler(dataset, dataloader_params)
     elif get_param(default_dataloader_params, "sampler") is not None:
@@ -108,13 +111,11 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
     elif is_dist:
         default_dataloader_params["sampler"] = {"DistributedSampler": {}}
         default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
-
     dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
     if get_param(dataloader_params, "batch_sampler"):
         sampler = dataloader_params.pop("sampler")
         batch_size = dataloader_params.pop("batch_size")
         dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
-
     return dataloader_params
 
 
 
 
@@ -641,18 +642,26 @@ ALL_DATALOADERS = {
 }
 
 
-def get(name: str, dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader:
     """
     Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
 
     :param name: dataset name in ALL_DATALOADERS.
     :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
     :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
+    :param dataset: torch.utils.data.Dataset to be used instead of passing "name" (i.e for external dataset objects).
     :return: initialized DataLoader.
     """
 
-    if name not in ALL_DATALOADERS.keys():
+    if dataset is not None:
+        if name or dataset_params:
+            raise ValueError("'name' and 'dataset_params' cannot be passed with initialized dataset.")
+        dataloader_params = _process_sampler_params(dataloader_params, dataset, {})
+        dataloader = DataLoader(dataset=dataset, **dataloader_params)
+    elif name not in ALL_DATALOADERS.keys():
         raise ValueError("Unsupported dataloader: " + str(name))
+    else:
+        dataloader_cls = ALL_DATALOADERS[name]
+        dataloader = dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)
 
-    dataloader_cls = ALL_DATALOADERS[name]
-    return dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)
+    return dataloader
Discard
@@ -40,6 +40,7 @@ from super_gradients.training.dataloaders.dataloaders import (
     supervisely_persons_val,
     pascal_voc_detection_train,
     pascal_voc_detection_val,
+    get,
 )
 from super_gradients.training.datasets import (
     COCODetectionDataset,
@@ -136,9 +137,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
 
     def test_imagenet_resnet50_kd_train_creation(self):
         # Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
-        dl = imagenet_resnet50_kd_train(
-            dataloader_params={"sampler": {"InfiniteSampler": {}}}
-        )
+        dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
 
@@ -242,6 +241,11 @@ class DataLoaderFactoryTest(unittest.TestCase):
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
 
+    def test_get_with_external_dataset_creation(self):
+        dataset = Cifar10(root="./data/cifar10", train=False, download=True)
+        dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
+        self.assertTrue(isinstance(dl, DataLoader))
+
 
 if __name__ == "__main__":
     unittest.main()
Discard