Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_interface_test.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. import unittest
  2. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import PascalVOCUnifiedDetectionDatasetInterface
  3. from super_gradients.training.transforms.transforms import DetectionPaddedRescale, DetectionTargetsFormatTransform, DetectionMosaic, DetectionRandomAffine,\
  4. DetectionHSV
  5. from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
  6. from super_gradients.training.utils.detection_utils import DetectionCollateFN
  7. from super_gradients.training.utils import sg_model_utils
  8. from super_gradients.training import utils as core_utils
  9. class TestDatasetInterface(unittest.TestCase):
  10. def setUp(self) -> None:
  11. self.root_dir = "/home/data/"
  12. self.train_batch_size, self.val_batch_size = 16, 32
  13. self.train_image_size, self.val_image_size = 640, 640
  14. self.train_input_dim = (self.train_image_size, self.train_image_size)
  15. self.val_input_dim = (self.val_image_size, self.val_image_size)
  16. self.train_max_num_samples = 100
  17. self.val_max_num_samples = 90
  18. def setup_pascal_voc_interface(self):
  19. """setup PascalVOCUnifiedDetectionDataSetInterfaceV2 and return dataloaders"""
  20. dataset_params = {
  21. "data_dir": self.root_dir + "pascal_unified_coco_format/",
  22. "cache_dir": self.root_dir + "pascal_unified_coco_format/",
  23. "batch_size": self.train_batch_size,
  24. "val_batch_size": self.val_batch_size,
  25. "train_image_size": self.train_image_size,
  26. "val_image_size": self.val_image_size,
  27. "train_max_num_samples": self.train_max_num_samples,
  28. "val_max_num_samples": self.val_max_num_samples,
  29. "train_transforms": [
  30. DetectionMosaic(input_dim=self.train_input_dim, prob=1),
  31. DetectionRandomAffine(degrees=0.373, translate=0.245, scales=0.898, shear=0.602, target_size=self.train_input_dim),
  32. DetectionHSV(prob=1, hgain=0.0138, sgain=0.664, vgain=0.464),
  33. DetectionPaddedRescale(input_dim=self.train_input_dim, max_targets=100),
  34. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
  35. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
  36. "val_transforms": [
  37. DetectionPaddedRescale(input_dim=self.val_input_dim),
  38. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
  39. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
  40. "train_collate_fn": DetectionCollateFN(),
  41. "val_collate_fn": DetectionCollateFN(),
  42. "download": False,
  43. "cache_train_images": False,
  44. "cache_val_images": False,
  45. "class_inclusion_list": ["person"]
  46. }
  47. dataset_interface = PascalVOCUnifiedDetectionDatasetInterface(dataset_params=dataset_params)
  48. train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
  49. return train_loader, valid_loader
  50. def test_pascal_voc(self):
  51. """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
  52. train_loader, valid_loader = self.setup_pascal_voc_interface()
  53. for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
  54. (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
  55. # The dataset is at most of length max_num_samples, but can be smaller if not enough samples
  56. self.assertGreaterEqual(max_num_samples, len(loader.dataset))
  57. batch_items = next(iter(loader))
  58. batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
  59. inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
  60. self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))
  61. if __name__ == '__main__':
  62. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...