Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#450 Feature/sg 321 ddp sampler handling for external dataloaders

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-321_ddp_sampler_handling_for_external_dataloaders
1 changed files with 7 additions and 3 deletions
  1. 7
    3
      tests/unit_tests/datalaoder_factory_test.py
@@ -40,6 +40,7 @@ from super_gradients.training.dataloaders.dataloaders import (
     supervisely_persons_val,
     supervisely_persons_val,
     pascal_voc_detection_train,
     pascal_voc_detection_train,
     pascal_voc_detection_val,
     pascal_voc_detection_val,
+    get,
 )
 )
 from super_gradients.training.datasets import (
 from super_gradients.training.datasets import (
     COCODetectionDataset,
     COCODetectionDataset,
@@ -136,9 +137,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
 
 
     def test_imagenet_resnet50_kd_train_creation(self):
     def test_imagenet_resnet50_kd_train_creation(self):
         # Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
         # Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
-        dl = imagenet_resnet50_kd_train(
-            dataloader_params={"sampler": {"InfiniteSampler": {}}}
-        )
+        dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
 
 
@@ -242,6 +241,11 @@ class DataLoaderFactoryTest(unittest.TestCase):
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
         self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
 
 
+    def test_get_with_external_dataset_creation(self):
+        dataset = Cifar10(root="./data/cifar10", train=False, download=True)
+        dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
+        self.assertTrue(isinstance(dl, DataLoader))
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard