Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

coco_detection.py 9.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
  1. import copy
  2. import os
  3. import cv2
  4. import numpy as np
  5. from pycocotools.coco import COCO
  6. from super_gradients.common.abstractions.abstract_logger import get_logger
  7. from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
  8. from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
  9. from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
  10. from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
# Module-level logger, named after this module per the super_gradients logging convention.
logger = get_logger(__name__)
  12. class COCODetectionDataset(DetectionDataset):
  13. """Dataset for COCO object detection.
  14. To use this Dataset you need to:
  15. - Download coco dataset:
  16. annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip
  17. train2017: http://images.cocodataset.org/zips/train2017.zip
  18. val2017: http://images.cocodataset.org/zips/val2017.zip
  19. - Unzip and organize it as below:
  20. coco
  21. ├── annotations
  22. │ ├─ instances_train2017.json
  23. │ ├─ instances_val2017.json
  24. │ └─ ...
  25. └── images
  26. ├── train2017
  27. │ ├─ 000000000001.jpg
  28. │ └─ ...
  29. └── val2017
  30. └─ ...
  31. - Install CoCo API: https://github.com/pdollar/coco/tree/master/PythonAPI
  32. - Instantiate the dataset:
  33. >> train_set = COCODetectionDataset(data_dir='.../coco', subdir='images/train2017', json_file='instances_train2017.json', ...)
  34. >> valid_set = COCODetectionDataset(data_dir='.../coco', subdir='images/val2017', json_file='instances_val2017.json', ...)
  35. """
  36. def __init__(
  37. self,
  38. json_file: str = "instances_train2017.json",
  39. subdir: str = "images/train2017",
  40. tight_box_rotation: bool = False,
  41. with_crowd: bool = True,
  42. *args,
  43. **kwargs,
  44. ):
  45. """
  46. :param json_file: Name of the coco json file, that resides in data_dir/annotations/json_file.
  47. :param subdir: Sub directory of data_dir containing the data.
  48. :param tight_box_rotation: bool, whether to use of segmentation maps convex hull as target_seg
  49. (check get_sample docs).
  50. :param with_crowd: Add the crowd groundtruths to __getitem__
  51. kwargs:
  52. all_classes_list: all classes list, default is COCO_DETECTION_CLASSES_LIST.
  53. """
  54. self.subdir = subdir
  55. self.json_file = json_file
  56. self.tight_box_rotation = tight_box_rotation
  57. self.with_crowd = with_crowd
  58. target_fields = ["target", "crowd_target"] if self.with_crowd else ["target"]
  59. kwargs["target_fields"] = target_fields
  60. kwargs["output_fields"] = ["image", *target_fields]
  61. kwargs["original_target_format"] = DetectionTargetsFormat.XYXY_LABEL
  62. kwargs["all_classes_list"] = kwargs.get("all_classes_list") or COCO_DETECTION_CLASSES_LIST
  63. super().__init__(*args, **kwargs)
  64. if len(self.original_classes) != len(self.all_classes_list):
  65. if set(self.all_classes_list).issubset(set(self.original_classes)):
  66. raise ParameterMismatchException(
  67. "Parameter `all_classes_list` contains a subset of classes from dataset JSON. "
  68. "Please use `class_inclusion_list` to train with reduced number of classes",
  69. )
  70. else:
  71. raise DatasetValidationException(
  72. "Number of classes in dataset JSON do not match with number of classes in all_classes_list parameter. "
  73. "Most likely this indicates an error in your all_classes_list parameter"
  74. )
  75. def _setup_data_source(self) -> int:
  76. """Initialize img_and_target_path_list and warn if label file is missing
  77. :return: List of tuples made of (img_path,target_path)
  78. """
  79. self.coco = self._init_coco()
  80. self.class_ids = sorted(self.coco.getCatIds())
  81. self.original_classes = list([category["name"] for category in self.coco.loadCats(self.class_ids)])
  82. self.classes = copy.deepcopy(self.original_classes)
  83. self.sample_id_to_coco_id = self.coco.getImgIds()
  84. return len(self.sample_id_to_coco_id)
  85. def _init_coco(self) -> COCO:
  86. annotation_file_path = os.path.join(self.data_dir, "annotations", self.json_file)
  87. if not os.path.exists(annotation_file_path):
  88. raise ValueError("Could not find annotation file under " + str(annotation_file_path))
  89. coco = COCO(annotation_file_path)
  90. remove_useless_info(coco, self.tight_box_rotation)
  91. return coco
  92. def _load_annotation(self, sample_id: int) -> dict:
  93. """
  94. Load relevant information of a specific image.
  95. :param sample_id: Sample_id in the dataset
  96. :return target: Target Bboxes (detection) in XYXY_LABEL format
  97. :return crowd_target: Crowd target Bboxes (detection) in XYXY_LABEL format
  98. :return target_segmentation: Segmentation
  99. :return initial_img_shape: Image (height, width)
  100. :return resized_img_shape: Resides image (height, width)
  101. :return img_path: Path to the associated image
  102. """
  103. img_id = self.sample_id_to_coco_id[sample_id]
  104. img_metadata = self.coco.loadImgs(img_id)[0]
  105. width = img_metadata["width"]
  106. height = img_metadata["height"]
  107. img_annotation_ids = self.coco.getAnnIds(imgIds=[int(img_id)])
  108. img_annotations = self.coco.loadAnns(img_annotation_ids)
  109. cleaned_annotations = []
  110. for annotation in img_annotations:
  111. x1 = np.max((0, annotation["bbox"][0]))
  112. y1 = np.max((0, annotation["bbox"][1]))
  113. x2 = np.min((width, x1 + np.max((0, annotation["bbox"][2]))))
  114. y2 = np.min((height, y1 + np.max((0, annotation["bbox"][3]))))
  115. if annotation["area"] > 0 and x2 >= x1 and y2 >= y1:
  116. annotation["clean_bbox"] = [x1, y1, x2, y2]
  117. cleaned_annotations.append(annotation)
  118. non_crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 0]
  119. target = np.zeros((len(non_crowd_annotations), 5))
  120. num_seg_values = 98 if self.tight_box_rotation else 0
  121. target_segmentation = np.ones((len(non_crowd_annotations), num_seg_values))
  122. target_segmentation.fill(np.nan)
  123. for ix, annotation in enumerate(non_crowd_annotations):
  124. cls = self.class_ids.index(annotation["category_id"])
  125. target[ix, 0:4] = annotation["clean_bbox"]
  126. target[ix, 4] = cls
  127. if self.tight_box_rotation:
  128. seg_points = [j for i in annotation.get("segmentation", []) for j in i]
  129. if seg_points:
  130. seg_points_c = np.array(seg_points).reshape((-1, 2)).astype(np.int)
  131. seg_points_convex = cv2.convexHull(seg_points_c).ravel()
  132. else:
  133. seg_points_convex = []
  134. target_segmentation[ix, : len(seg_points_convex)] = seg_points_convex
  135. crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 1]
  136. crowd_target = np.zeros((len(crowd_annotations), 5))
  137. for ix, annotation in enumerate(crowd_annotations):
  138. cls = self.class_ids.index(annotation["category_id"])
  139. crowd_target[ix, 0:4] = annotation["clean_bbox"]
  140. crowd_target[ix, 4] = cls
  141. r = min(self.input_dim[0] / height, self.input_dim[1] / width)
  142. target[:, :4] *= r
  143. crowd_target[:, :4] *= r
  144. target_segmentation *= r
  145. initial_img_shape = (height, width)
  146. resized_img_shape = (int(height * r), int(width * r))
  147. file_name = img_metadata["file_name"] if "file_name" in img_metadata else "{:012}".format(img_id) + ".jpg"
  148. img_path = os.path.join(self.data_dir, self.subdir, file_name)
  149. img_id = self.sample_id_to_coco_id[sample_id]
  150. annotation = {
  151. "target": target,
  152. "crowd_target": crowd_target,
  153. "target_segmentation": target_segmentation,
  154. "initial_img_shape": initial_img_shape,
  155. "resized_img_shape": resized_img_shape,
  156. "img_path": img_path,
  157. "id": np.array([img_id]),
  158. }
  159. return annotation
  160. def remove_useless_info(coco, use_seg_info=False):
  161. """
  162. Remove useless info in coco dataset. COCO object is modified inplace.
  163. This function is mainly used for saving memory (save about 30% mem).
  164. """
  165. if isinstance(coco, COCO):
  166. dataset = coco.dataset
  167. dataset.pop("info", None)
  168. dataset.pop("licenses", None)
  169. for img in dataset["images"]:
  170. img.pop("license", None)
  171. img.pop("coco_url", None)
  172. img.pop("date_captured", None)
  173. img.pop("flickr_url", None)
  174. if "annotations" in coco.dataset and not use_seg_info:
  175. for anno in coco.dataset["annotations"]:
  176. anno.pop("segmentation", None)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...