Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

coco_parsing_test.py 6.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. import os.path
  2. import unittest
  3. import numpy as np
  4. from pycocotools.coco import COCO
  5. from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy
  6. from super_gradients.training.datasets.detection_datasets.coco_format_detection import parse_coco_into_detection_annotations
  7. from super_gradients.training.datasets.pose_estimation_datasets.coco_utils import parse_coco_into_keypoints_annotations, segmentation2mask
  8. class COCOParsingTest(unittest.TestCase):
  9. """
  10. Unit test for checking whether our implementation of COCO parsing produce the same results as the original pycoctools implementation.
  11. """
  12. def setUp(self) -> None:
  13. self.data_dir = os.environ.get("SUPER_GRADIENTS_COCO_DATASET_DIR", "/data/coco")
  14. self.keypoint_annotations = [
  15. "annotations/person_keypoints_val2017.json",
  16. ]
  17. self.detection_annotations = [
  18. "annotations/person_keypoints_val2017.json",
  19. ]
  20. def test_detection_parsing(self):
  21. for annotation_file in self.detection_annotations:
  22. annotation_file = os.path.join(self.data_dir, annotation_file)
  23. with self.subTest(annotation_file=annotation_file):
  24. coco = COCO(annotation_file)
  25. category_id_to_name = {category["id"]: category["name"] for category in coco.loadCats(coco.getCatIds())}
  26. all_class_names, annotations = parse_coco_into_detection_annotations(
  27. annotation_file,
  28. exclude_classes=None,
  29. include_classes=None,
  30. image_path_prefix="",
  31. )
  32. self.assertEquals(len(annotations), len(coco.getImgIds()))
  33. for annotation in annotations:
  34. img_id = annotation.image_id
  35. ann_ids = coco.getAnnIds(imgIds=[img_id])
  36. anns = coco.loadAnns(ann_ids)
  37. coco_boxes = np.array([ann["bbox"] for ann in anns], dtype=np.float32).reshape(-1, 4)
  38. coco_boxes_xyxy = xywh_to_xyxy(coco_boxes, image_shape=None)
  39. coco_classes = [ann["category_id"] for ann in anns]
  40. coco_class_names = [category_id_to_name[category_id] for category_id in coco_classes]
  41. coco_is_crowd = np.array([ann["iscrowd"] for ann in anns], dtype=bool).reshape(-1)
  42. ann_class_names = [all_class_names[category_id] for category_id in annotation.ann_labels]
  43. self.assertArrayEqual(coco_class_names, ann_class_names)
  44. self.assertArrayEqual(coco_is_crowd, annotation.ann_is_crowd)
  45. self.assertArrayAlmostEqual(coco_boxes_xyxy, annotation.ann_boxes_xyxy, rtol=1e-5, atol=1)
  46. def test_keypoints_segmentation_masks(self):
  47. for annotation_file in self.keypoint_annotations:
  48. annotation_file = os.path.join(self.data_dir, annotation_file)
  49. with self.subTest(annotation_file=annotation_file):
  50. coco = COCO(annotation_file)
  51. global_intersection = 0.0
  52. global_cardinality = 0.0
  53. _, keypoints, annotations = parse_coco_into_keypoints_annotations(annotation_file, image_path_prefix=self.data_dir)
  54. num_keypoints = len(keypoints)
  55. self.assertEquals(len(annotations), len(coco.getImgIds()))
  56. for annotation in annotations:
  57. img_id = annotation.image_id
  58. img_metadata = coco.loadImgs([img_id])[0]
  59. ann_ids = coco.getAnnIds(imgIds=[img_id])
  60. anns = coco.loadAnns(ann_ids)
  61. coco_areas = [ann["area"] for ann in anns]
  62. coco_keypoints = np.array([np.array(ann["keypoints"], dtype=np.float32).reshape(-1, 3) for ann in anns]).reshape(-1, num_keypoints, 3)
  63. self.assertArrayAlmostEqual(coco_areas, annotation.ann_areas, rtol=1e-5, atol=1)
  64. self.assertArrayAlmostEqual(coco_keypoints, annotation.ann_keypoints, rtol=1e-5, atol=1)
  65. for ann_index in range(len(anns)):
  66. ann = anns[ann_index]
  67. expected_mask = coco.annToMask(ann)
  68. expected_mask[expected_mask > 0] = 1
  69. actual_mask = segmentation2mask(annotation.ann_segmentations[ann_index], image_shape=(img_metadata["height"], img_metadata["width"]))
  70. actual_mask[actual_mask > 0] = 1
  71. global_intersection += np.sum(expected_mask * actual_mask, dtype=np.float64)
  72. global_cardinality += np.sum(expected_mask + actual_mask, dtype=np.float64)
  73. iou = np.sum(expected_mask * actual_mask) / (np.sum(expected_mask + actual_mask) - np.sum(expected_mask * actual_mask))
  74. # Uncomment this to visualize the differences for low IoU scores (if it happens)
  75. # if iou < 0.2:
  76. # cv2.imshow("expected", expected_mask * 255)
  77. # cv2.imshow("actual", actual_mask * 255)
  78. # cv2.imshow("diff", cv2.absdiff(expected_mask * 255, actual_mask * 255))
  79. # cv2.waitKey(0)
  80. # print(f"iou: {iou}")
  81. self.assertGreater(iou, 0.2, msg=f"iou: {iou} for img_id: {img_id} ann_index: {ann_index}")
  82. global_iou = global_intersection / (global_cardinality - global_intersection)
  83. print(global_iou, annotation_file)
  84. # The polygon rasterization implementation in pycocotools is slightly different from the one we use (OpenCV)
  85. # To evaluate how well the masks are parsed, we calculate the global IoU between the all the masks instances
  86. # This is done intentionally to avoid the influece of the low IoU scores for extremely small masks
  87. self.assertGreater(global_iou, 0.98)
  88. def assertArrayAlmostEqual(self, first, second, rtol, atol):
  89. self.assertTrue(np.allclose(first, second, rtol=rtol, atol=atol))
  90. def assertArrayEqual(self, first, second):
  91. self.assertTrue(np.array_equal(first, second))
  92. if __name__ == "__main__":
  93. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...