|
@@ -40,6 +40,7 @@ from super_gradients.training.dataloaders.dataloaders import (
|
|
supervisely_persons_val,
|
|
supervisely_persons_val,
|
|
pascal_voc_detection_train,
|
|
pascal_voc_detection_train,
|
|
pascal_voc_detection_val,
|
|
pascal_voc_detection_val,
|
|
|
|
+ get,
|
|
)
|
|
)
|
|
from super_gradients.training.datasets import (
|
|
from super_gradients.training.datasets import (
|
|
COCODetectionDataset,
|
|
COCODetectionDataset,
|
|
@@ -136,9 +137,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
|
|
|
def test_imagenet_resnet50_kd_train_creation(self):
|
|
def test_imagenet_resnet50_kd_train_creation(self):
|
|
# Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
|
|
# Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
|
|
- dl = imagenet_resnet50_kd_train(
|
|
|
|
- dataloader_params={"sampler": {"InfiniteSampler": {}}}
|
|
|
|
- )
|
|
|
|
|
|
+ dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
|
|
@@ -242,6 +241,11 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
|
|
self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
|
|
|
|
|
|
|
|
+ def test_get_with_external_dataset_creation(self):
|
|
|
|
+ dataset = Cifar10(root="./data/cifar10", train=False, download=True)
|
|
|
|
+ dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
|
|
|
|
+ self.assertTrue(isinstance(dl, DataLoader))
|
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
unittest.main()
|
|
unittest.main()
|