Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_interface_test.py 8.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import unittest

import torch

from super_gradients.training.datasets.dataset_interfaces.dataset_interface import (
    PascalVOCUnifiedDetectionDatasetInterface,
    CoCoDetectionDatasetInterface,
)
from super_gradients.training.transforms.transforms import (
    DetectionPaddedRescale,
    DetectionTargetsFormatTransform,
    DetectionMosaic,
    DetectionRandomAffine,
    DetectionHSV,
)
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.training.utils.detection_utils import DetectionCollateFN
from super_gradients.training.utils import sg_trainer_utils
from super_gradients.training import utils as core_utils
  10. class TestDatasetInterface(unittest.TestCase):
  11. def setUp(self) -> None:
  12. self.root_dir = "/home/louis.dupont/data/"
  13. self.train_batch_size, self.val_batch_size = 16, 32
  14. self.train_image_size, self.val_image_size = 640, 640
  15. self.train_input_dim = (self.train_image_size, self.train_image_size)
  16. self.val_input_dim = (self.val_image_size, self.val_image_size)
  17. self.train_max_num_samples = 100
  18. self.val_max_num_samples = 90
  19. def setup_pascal_voc_interface(self):
  20. """setup PascalVOCUnifiedDetectionDatasetInterface and return dataloaders"""
  21. dataset_params = {
  22. "data_dir": self.root_dir + "pascal_unified_coco_format/",
  23. "cache_dir": self.root_dir + "pascal_unified_coco_format/",
  24. "batch_size": self.train_batch_size,
  25. "val_batch_size": self.val_batch_size,
  26. "train_image_size": self.train_image_size,
  27. "val_image_size": self.val_image_size,
  28. "train_max_num_samples": self.train_max_num_samples,
  29. "val_max_num_samples": self.val_max_num_samples,
  30. "train_transforms": [
  31. DetectionMosaic(input_dim=self.train_input_dim, prob=1),
  32. DetectionRandomAffine(degrees=0.373, translate=0.245, scales=0.898, shear=0.602, target_size=self.train_input_dim),
  33. DetectionHSV(prob=1, hgain=0.0138, sgain=0.664, vgain=0.464),
  34. DetectionPaddedRescale(input_dim=self.train_input_dim, max_targets=100),
  35. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
  36. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
  37. "val_transforms": [
  38. DetectionPaddedRescale(input_dim=self.val_input_dim),
  39. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
  40. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
  41. "train_collate_fn": DetectionCollateFN(),
  42. "val_collate_fn": DetectionCollateFN(),
  43. "download": False,
  44. "cache_train_images": False,
  45. "cache_val_images": False,
  46. "class_inclusion_list": ["person"]
  47. }
  48. dataset_interface = PascalVOCUnifiedDetectionDatasetInterface(dataset_params=dataset_params)
  49. train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
  50. return train_loader, valid_loader
  51. def setup_coco_detection_interface(self):
  52. """setup CoCoDetectionDatasetInterface and return dataloaders"""
  53. dataset_params = {
  54. "data_dir": "/data/coco",
  55. "train_subdir": "images/train2017", # sub directory path of data_dir containing the train data.
  56. "val_subdir": "images/val2017", # sub directory path of data_dir containing the validation data.
  57. "train_json_file": "instances_train2017.json", # path to coco train json file, data_dir/annotations/train_json_file.
  58. "val_json_file": "instances_val2017.json", # path to coco validation json file, data_dir/annotations/val_json_file.
  59. "batch_size": self.train_batch_size,
  60. "val_batch_size": self.val_batch_size,
  61. "train_image_size": self.train_image_size,
  62. "val_image_size": self.val_image_size,
  63. "train_max_num_samples": self.train_max_num_samples,
  64. "val_max_num_samples": self.val_max_num_samples,
  65. "mixup_prob": 1.0, # probability to apply per-sample mixup
  66. "degrees": 10., # rotation degrees, randomly sampled from [-degrees, degrees]
  67. "shear": 2.0, # shear degrees, randomly sampled from [-degrees, degrees]
  68. "flip_prob": 0.5, # probability to apply horizontal flip
  69. "hsv_prob": 1.0, # probability to apply HSV transform
  70. "hgain": 5, # HSV transform hue gain (randomly sampled from [-hgain, hgain])
  71. "sgain": 30, # HSV transform saturation gain (randomly sampled from [-sgain, sgain])
  72. "vgain": 30, # HSV transform value gain (randomly sampled from [-vgain, vgain])
  73. "mosaic_scale": [0.1, 2], # random rescale range (keeps size by padding/cropping) after mosaic transform.
  74. "mixup_scale": [0.5, 1.5], # random rescale range for the additional sample in mixup
  75. "mosaic_prob": 1., # probability to apply mosaic
  76. "translate": 0.1, # image translation fraction
  77. "filter_box_candidates": False, # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
  78. "wh_thr": 2, # edge size threshold when filter_box_candidates = True (pixels)
  79. "ar_thr": 20, # aspect ratio threshold when filter_box_candidates = True
  80. "area_thr": 0.1, # threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True
  81. "tight_box_rotation": False,
  82. "download": False,
  83. "train_collate_fn": DetectionCollateFN(),
  84. "val_collate_fn": DetectionCollateFN(),
  85. "cache_train_images": False,
  86. "cache_val_images": False,
  87. "cache_dir": "/home/data/cache", # Depends on the user
  88. "class_inclusion_list": None
  89. # "with_crowd": True
  90. }
  91. dataset_interface = CoCoDetectionDatasetInterface(dataset_params=dataset_params)
  92. train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
  93. return train_loader, valid_loader
  94. def test_coco_detection(self):
  95. """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
  96. train_loader, valid_loader = self.setup_coco_detection_interface()
  97. for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
  98. (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
  99. # The dataset is at most of length max_num_samples, but can be smaller if not enough samples
  100. self.assertGreaterEqual(max_num_samples, len(loader.dataset))
  101. batch_items = next(iter(loader))
  102. batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
  103. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
  104. self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))
  105. def test_pascal_voc(self):
  106. """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
  107. train_loader, valid_loader = self.setup_pascal_voc_interface()
  108. for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
  109. (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
  110. # The dataset is at most of length max_num_samples, but can be smaller if not enough samples
  111. self.assertGreaterEqual(max_num_samples, len(loader.dataset))
  112. batch_items = next(iter(loader))
  113. batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
  114. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
  115. self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))
  116. if __name__ == '__main__':
  117. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...