#717 Feature/sg 636 pose estimation metrics

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-636-pose-estimation-metrics
@@ -34,3 +34,4 @@ pygments>=2.7.4
 stringcase>=1.2.0
 numpy<=1.23
 rapidfuzz
+json-tricks==3.16.1
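json-tricks is pulled in because the new unit test dumps numpy-typed predictions to a JSON results file for COCOeval. A minimal sketch of what that buys over the standard json module (the dictionary below is illustrative, not from the PR):

# Not part of the PR: json_tricks serializes numpy arrays/scalars that plain `json` rejects.
import numpy as np
import json_tricks as json

prediction = {"image_id": 42, "score": np.float32(0.87), "keypoints": np.zeros((17, 3), dtype=np.float32)}
print(json.dumps(prediction, indent=4))  # numpy values are encoded instead of raising TypeError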
import numpy as np
from pycocotools.coco import COCO

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

__all__ = ["check_keypoints_outside_image", "check_for_duplicate_annotations", "make_keypoints_outside_image_invisible", "remove_duplicate_annotations"]


def check_keypoints_outside_image(coco: COCO) -> None:
    """
    Check if there are any keypoints outside the image.
    :param coco:
    :return: None
    """
    for ann in coco.anns.values():
        keypoints = np.array(ann["keypoints"]).reshape(-1, 3)
        image_rows = coco.imgs[ann["image_id"]]["height"]
        image_cols = coco.imgs[ann["image_id"]]["width"]
        visible_joints = keypoints[:, 2] > 0
        joints_outside_image = (keypoints[:, 0] < 0) | (keypoints[:, 0] >= image_cols) | (keypoints[:, 1] < 0) | (keypoints[:, 1] >= image_rows)
        visible_joints_outside_image = visible_joints & joints_outside_image
        if visible_joints_outside_image.any():
            logger.warning(
                f"Annotation {ann['id']} for image {ann['image_id']} (width={image_cols}, height={image_rows}) "
                f"contains keypoints outside image boundary {keypoints[joints_outside_image]}. "
            )


def check_for_duplicate_annotations(coco: COCO, max_distance_threshold=2) -> None:
    """
    Check if there are any duplicate (overlapping) object annotations.
    :param coco:
    :param max_distance_threshold: Maximum average distance between keypoints of two instances to be considered as duplicate.
    :return: None
    """
    image_ids = list(coco.imgs.keys())
    for image_id in image_ids:
        ann_ids = coco.getAnnIds(imgIds=image_id)
        annotations = coco.loadAnns(ann_ids)

        joints = []
        for ann in annotations:
            keypoints = np.array(ann["keypoints"]).reshape(-1, 3)
            joints.append(keypoints[:, :2])

        if len(joints) == 0:
            continue

        gt_joints1 = np.expand_dims(joints, axis=0)  # [1, Num_people, Num_joints, 2]
        gt_joints2 = np.expand_dims(joints, axis=1)  # [Num_people, 1, Num_joints, 2]
        diff = np.sqrt(np.sum((gt_joints1 - gt_joints2) ** 2, axis=-1))  # [Num_people, Num_people, Num_joints]
        diffmean = np.mean(diff, axis=-1)

        duplicate_mask = np.triu(diffmean < max_distance_threshold, k=1)
        duplicate_indexes_i, duplicate_indexes_j = np.nonzero(duplicate_mask)
        for i, j in zip(duplicate_indexes_i, duplicate_indexes_j):
            logger.warning(f"Duplicate annotations for image {image_id}: {annotations[i]['id']} and {annotations[j]['id']}")


def make_keypoints_outside_image_invisible(coco: COCO):
    for ann in coco.anns.values():
        keypoints = np.array(ann["keypoints"]).reshape(-1, 3)
        image_rows = coco.imgs[ann["image_id"]]["height"]
        image_cols = coco.imgs[ann["image_id"]]["width"]
        visible_joints = keypoints[:, 2] > 0
        joints_outside_image = (keypoints[:, 0] < 0) | (keypoints[:, 0] >= image_cols) | (keypoints[:, 1] < 0) | (keypoints[:, 1] >= image_rows)
        visible_joints_outside_image = visible_joints & joints_outside_image
        if visible_joints_outside_image.any():
            logger.debug(
                f"Detected GT joints outside image size (width={image_cols}, height={image_rows}). "
                f"{keypoints[joints_outside_image]} for obj_id {ann['id']} image_id {ann['image_id']}. "
                f"Changing visibility to invisible."
            )
            keypoints[visible_joints_outside_image, 2] = 0
            ann["keypoints"] = [float(x) for x in keypoints.reshape(-1)]
    return coco


def remove_duplicate_annotations(coco: COCO):
    ann_to_remove = []

    image_ids = list(coco.imgs.keys())
    for image_id in image_ids:
        ann_ids = coco.getAnnIds(imgIds=image_id)
        annotations = coco.loadAnns(ann_ids)

        joints = []
        for ann in annotations:
            keypoints = np.array(ann["keypoints"]).reshape(-1, 3)
            joints.append(keypoints[:, :2])

        if len(joints) == 0:
            continue

        gt_joints1 = np.expand_dims(joints, axis=0)  # [1, Num_people, Num_joints, 2]
        gt_joints2 = np.expand_dims(joints, axis=1)  # [Num_people, 1, Num_joints, 2]
        diff = np.sqrt(np.sum((gt_joints1 - gt_joints2) ** 2, axis=-1))  # [Num_people, Num_people, Num_joints]
        diffmean = np.mean(diff, axis=-1)

        duplicate_mask = np.triu(diffmean < 2, k=1)
        duplicate_indexes_i, duplicate_indexes_j = np.nonzero(duplicate_mask)
        for j in duplicate_indexes_j:
            ann_to_remove.append(ann_ids[j])

    if len(ann_to_remove) > 0:
        logger.debug(f"Removing {len(ann_to_remove)} duplicate annotations")
        len_before = len(coco.dataset["annotations"])
        coco.dataset["annotations"] = [v for v in coco.dataset["annotations"] if v["id"] not in ann_to_remove]
        len_after = len(coco.dataset["annotations"])
        logger.debug(f"Removed {len_before - len_after} duplicate annotations")
        coco.createIndex()
    return coco


def remove_crowd_annotations(coco: COCO):
    ann_to_remove = []

    image_ids = list(coco.imgs.keys())
    for image_id in image_ids:
        ann_ids = coco.getAnnIds(imgIds=image_id)
        annotations = coco.loadAnns(ann_ids)
        for ann in annotations:
            if bool(ann["iscrowd"]):
                ann_to_remove.append(ann["id"])

    if len(ann_to_remove) > 0:
        logger.debug(f"Removing {len(ann_to_remove)} crowd annotations")
        len_before = len(coco.dataset["annotations"])
        coco.dataset["annotations"] = [v for v in coco.dataset["annotations"] if v["id"] not in ann_to_remove]
        len_after = len(coco.dataset["annotations"])
        logger.debug(f"Removed {len_before - len_after} crowd annotations")
        coco.createIndex()
    return coco
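The helpers above sanitize a COCO ground-truth object before evaluation: the check_* functions only log suspicious annotations, while the make_*/remove_* functions return a cleaned COCO object. A hypothetical usage sketch (the annotation path below is a placeholder, not part of the PR):

# Hypothetical usage of the sanitization helpers; the JSON path is a placeholder.
from pycocotools.coco import COCO
from super_gradients.training.datasets.pose_estimation_datasets.coco_utils import (
    check_keypoints_outside_image,
    check_for_duplicate_annotations,
    make_keypoints_outside_image_invisible,
    remove_duplicate_annotations,
)

gt = COCO("annotations/person_keypoints_val2017.json")
check_keypoints_outside_image(gt)                 # log keypoints that fall outside image bounds
check_for_duplicate_annotations(gt)               # log near-identical instance annotations
gt = make_keypoints_outside_image_invisible(gt)   # set visibility of out-of-bounds joints to 0
gt = remove_duplicate_annotations(gt)             # drop overlapping duplicates and rebuild the COCO index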
@@ -4,7 +4,7 @@ from super_gradients.training.metrics.classification_metrics import accuracy, Ac
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics, DetectionMetrics_050, DetectionMetrics_075, DetectionMetrics_050_095
 from super_gradients.training.metrics.segmentation_metrics import PreprocessSegmentationMetricsArgs, PixelAccuracy, IoU, Dice, BinaryIOU, BinaryDice
 from super_gradients.training.metrics.all_metrics import METRICS, Metrics
-
+from super_gradients.training.metrics.pose_estimation_metrics import PoseEstimationMetrics

 __all__ = [
     "METRICS",
@@ -23,4 +23,5 @@ __all__ = [
     "DetectionMetrics_050",
     "DetectionMetrics_075",
     "DetectionMetrics_050_095",
+    "PoseEstimationMetrics",
 ]
import itertools
from typing import Dict, Union, List, Optional, Tuple, Callable, Iterable, Any

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric

import super_gradients
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.registry.registry import register_metric
from super_gradients.training.metrics.pose_estimation_utils import compute_img_keypoint_matching, compute_visible_bbox_xywh
from super_gradients.training.utils.detection_utils import compute_detection_metrics_per_cls

logger = get_logger(__name__)

__all__ = ["PoseEstimationMetrics"]


@register_metric("PoseEstimationMetrics")
class PoseEstimationMetrics(Metric):
    """
    Implementation of the COCO Keypoint evaluation metric.
    When instantiated with default parameters, it will default to COCO params.
    By default, only AP and AR metrics are computed:

    >>> from super_gradients.training.metrics import PoseEstimationMetrics
    >>> metric = PoseEstimationMetrics(...)
    >>> metric.update(...)
    >>> metrics = metric.compute() # {"AP": 0.123, "AR": 0.456 }

    If you wish to get AP/AR at specific thresholds, you can specify them using the `iou_thresholds_to_report` argument:

    >>> from super_gradients.training.metrics import PoseEstimationMetrics
    >>> metric = PoseEstimationMetrics(..., iou_thresholds_to_report=[0.5, 0.75])
    >>> metric.update(...)
    >>> metrics = metric.compute() # {"AP": 0.123, "AP_0.5": 0.222, "AP_0.75": 0.111, "AR": 0.456, "AR_0.5": 0.212, "AR_0.75": 0.443 }
    """

    def __init__(
        self,
        post_prediction_callback: Callable[[Any], Tuple[Tensor, Tensor]],
        num_joints: int,
        max_objects_per_image: int = 20,
        oks_sigmas: Optional[Iterable] = None,
        iou_thresholds: Optional[Iterable] = None,
        recall_thresholds: Optional[Iterable] = None,
        iou_thresholds_to_report: Optional[Iterable] = None,
    ):
        """
        Compute the AP & AR metrics for pose estimation. By default, this class returns only AP and AR values.
        If you need additional metrics (AP at a specific threshold), pass these thresholds via the `iou_thresholds_to_report` argument.

        :param post_prediction_callback: A callback to decode model predictions to poses. This should be a callable that takes input (model predictions)
                                         and returns a tuple of (poses, scores)
        :param num_joints:               Number of joints per pose
        :param max_objects_per_image:    Maximum number of predicted poses to include in evaluation (Top-K poses will be used).
        :param oks_sigmas:               OKS sigma factor for a custom keypoint detection dataset.
                                         If None, the metric will use the default OKS sigmas from COCO and expect num_joints to be equal to 17
        :param recall_thresholds:        List of recall thresholds to compute AP.
                                         If None, the default 101 recall thresholds from COCO in the range [0..1] will be used
        :param iou_thresholds:           List of IoU thresholds to use. If None, the COCO range of IoU thresholds will be used (0.5 ... 0.95)
        :param iou_thresholds_to_report: List of IoU thresholds to return in the metric. By default, only AP/AR metrics are returned, but one
                                         may also request AP_0.5, AP_0.75, AR_0.5, AR_0.75 by setting `iou_thresholds_to_report=[0.5, 0.75]`
        """
        super().__init__(dist_sync_on_step=False)
        self.num_joints = num_joints
        self.max_objects_per_image = max_objects_per_image
        self.stats_names = ["AP", "AR"]

        if recall_thresholds is None:
            recall_thresholds = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True, dtype=np.float32)
        self.recall_thresholds = torch.tensor(recall_thresholds, dtype=torch.float32)

        if iou_thresholds is None:
            iou_thresholds = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True, dtype=np.float32)
        self.iou_thresholds = torch.tensor(iou_thresholds, dtype=torch.float32)

        if iou_thresholds_to_report is not None:
            self.iou_thresholds_to_report = np.array([float(t) for t in iou_thresholds_to_report], dtype=np.float32)

            if not np.isin(self.iou_thresholds_to_report, self.iou_thresholds).all():
                missing = ~np.isin(self.iou_thresholds_to_report, self.iou_thresholds)
                raise RuntimeError(
                    f"One or many IoU thresholds to report are not present in IoU thresholds. Missing thresholds: {self.iou_thresholds_to_report[missing]}"
                )

            self.stats_names += [f"AP_{t:.2f}" for t in self.iou_thresholds_to_report]
            self.stats_names += [f"AR_{t:.2f}" for t in self.iou_thresholds_to_report]
        else:
            self.iou_thresholds_to_report = None

        self.greater_component_is_better = dict((k, True) for k in self.stats_names)

        if oks_sigmas is None:
            oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0

        if len(oks_sigmas) != num_joints:
            raise ValueError(f"Length of oks_sigmas ({len(oks_sigmas)}) should be equal to num_joints {num_joints}")

        self.oks_sigmas = torch.tensor(oks_sigmas).float()

        self.component_names = list(self.greater_component_is_better.keys())
        self.components = len(self.component_names)

        self.post_prediction_callback = post_prediction_callback
        self.is_distributed = super_gradients.is_distributed()
        self.world_size = None
        self.rank = None
        self.add_state("predictions", default=[], dist_reduce_fx=None)

    def reset(self) -> None:
        self.predictions = []

    def update(
        self,
        preds,
        target,
        gt_joints: List[np.ndarray],
        gt_iscrowd: List[np.ndarray] = None,
        gt_bboxes: List[np.ndarray] = None,
        gt_areas: List[np.ndarray] = None,
    ):
        """
        Decode the predictions and update the metric.

        :param preds:      Raw output of the model
        :param target:     Targets for model training (rarely used for evaluation)
        :param gt_joints:  List of ground-truth joints for each image in the batch. Each element is a numpy array of shape (num_instances, num_joints, 3).
                           Note that augmentation/preprocessing transformations (affine transforms specifically) must also be applied to gt_joints.
                           This ensures joint coordinates are transformed identically to the image. This differs from COCO evaluation,
                           where predictions are rescaled back to the original size of the image.
                           However, that would make the code much more (unnecessarily) complicated, so we do it differently and evaluate joints in the
                           coordinate system of the predicted image.
        :param gt_iscrowd: Optional argument indicating which instances are annotated with the `iscrowd` flag and are not used for evaluation;
                           If not provided, all instances are considered non-crowd targets.
                           For instance, in CrowdPose all instances are considered "non-crowd".
        :param gt_bboxes:  Bounding boxes of the ground-truth instances (XYWH).
                           This is COCO-specific and is used in OKS computation for instances w/o visible keypoints.
                           If not provided, the bounding box is computed as the minimum bounding box that contains all visible keypoints.
        :param gt_areas:   Area of each ground-truth instance. In COCO this is the area of the corresponding segmentation mask and not of the bounding box,
                           so it cannot be computed programmatically. Its value is used in the object-keypoint similarity (OKS) computation.
                           If not provided, the area is computed as the product of the width and height of the bounding box.
                           (For instance, this is used for the CrowdPose dataset.)
        """
        predicted_poses, predicted_scores = self.post_prediction_callback(preds)  # Decode raw predictions into poses

        if gt_bboxes is None:
            gt_bboxes = [compute_visible_bbox_xywh(torch.tensor(joints[:, :, 0:2]), torch.tensor(joints[:, :, 2])) for joints in gt_joints]

        if gt_areas is None:
            gt_areas = [bboxes[:, 2] * bboxes[:, 3] for bboxes in gt_bboxes]

        if gt_iscrowd is None:
            gt_iscrowd = [[False] * len(x) for x in gt_joints]

        for i in range(len(predicted_poses)):
            self.update_single_image(
                predicted_poses[i], predicted_scores[i], gt_joints[i], gt_areas=gt_areas[i], gt_bboxes=gt_bboxes[i], gt_is_crowd=gt_iscrowd[i]
            )

    def update_single_image(
        self,
        predicted_poses: Union[Tensor, np.ndarray],
        predicted_scores: Union[Tensor, np.ndarray],
        groundtruths: Union[Tensor, np.ndarray],
        gt_bboxes: Union[Tensor, np.ndarray],
        gt_areas: Union[Tensor, np.ndarray],
        gt_is_crowd: Union[Tensor, np.ndarray, List[bool]],
    ):
        if len(predicted_poses) == 0 and len(groundtruths) == 0:
            return
        if len(predicted_poses) != len(predicted_scores):
            raise ValueError("Length of predicted poses and scores should be equal. Got {} and {}".format(len(predicted_poses), len(predicted_scores)))
        if len(groundtruths) != len(gt_areas) != len(gt_bboxes) != len(gt_is_crowd):
            raise ValueError(
                "Length of groundtruths, areas, bboxes and iscrowd should be equal. Got {} and {} and {} and {}".format(
                    len(groundtruths), len(gt_areas), len(gt_bboxes), len(gt_is_crowd)
                )
            )

        predicted_poses = torch.tensor(predicted_poses, dtype=torch.float, device=self.device)
        predicted_scores = torch.tensor(predicted_scores, dtype=torch.float, device=self.device)

        gt_keypoints = torch.tensor(groundtruths, dtype=torch.float, device=self.device)
        gt_areas = torch.tensor(gt_areas, dtype=torch.float, device=self.device)
        gt_bboxes = torch.tensor(gt_bboxes, dtype=torch.float, device=self.device)

        gt_keypoints_xy = gt_keypoints[:, :, 0:2]
        gt_keypoints_visibility = gt_keypoints[:, :, 2]
        gt_all_kpts_invisible = gt_keypoints_visibility.eq(0).all(dim=1)
        gt_is_crowd = torch.tensor(gt_is_crowd, dtype=torch.bool, device=self.device)
        gt_is_ignore = gt_all_kpts_invisible | gt_is_crowd

        targets = gt_keypoints_xy[~gt_is_ignore] if len(groundtruths) else []
        targets_visibilities = gt_keypoints_visibility[~gt_is_ignore] if len(groundtruths) else []
        targets_areas = gt_areas[~gt_is_ignore] if len(groundtruths) else []
        targets_bboxes = gt_bboxes[~gt_is_ignore]
        targets_ignored = gt_is_ignore[~gt_is_ignore]

        crowd_targets = gt_keypoints_xy[gt_is_ignore] if len(groundtruths) else []
        crowd_visibilities = gt_keypoints_visibility[gt_is_ignore] if len(groundtruths) else []
        crowd_targets_areas = gt_areas[gt_is_ignore]
        crowd_targets_bboxes = gt_bboxes[gt_is_ignore]

        preds_matched, preds_to_ignore, preds_scores, num_targets = compute_img_keypoint_matching(
            predicted_poses,
            predicted_scores,
            #
            targets=targets,
            targets_visibilities=targets_visibilities,
            targets_areas=targets_areas,
            targets_bboxes=targets_bboxes,
            targets_ignored=targets_ignored,
            #
            crowd_targets=crowd_targets,
            crowd_visibilities=crowd_visibilities,
            crowd_targets_areas=crowd_targets_areas,
            crowd_targets_bboxes=crowd_targets_bboxes,
            #
            iou_thresholds=self.iou_thresholds.to(self.device),
            sigmas=self.oks_sigmas.to(self.device),
            top_k=self.max_objects_per_image,
        )

        self.predictions.append((preds_matched, preds_to_ignore, preds_scores, num_targets))

    def _sync_dist(self, dist_sync_fn=None, process_group=None):
        """
        When in distributed mode, stats are aggregated after each forward pass to the metric state. Since these all have
        different sizes, we override the synchronization function, which works only for tensors, and use
        all_gather_object instead.

        @param dist_sync_fn:
        @return:
        """
        if self.world_size is None:
            self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
        if self.rank is None:
            self.rank = torch.distributed.get_rank() if self.is_distributed else -1

        if self.is_distributed:
            local_state_dict = self.predictions
            gathered_state_dicts = [None] * self.world_size
            torch.distributed.barrier()
            torch.distributed.all_gather_object(gathered_state_dicts, local_state_dict)
            self.predictions = list(itertools.chain(*gathered_state_dicts))

    def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
        """
        Compute the metrics for all the accumulated results.

        :return: Metrics of interest
        """
        T = len(self.iou_thresholds)
        K = 1  # num categories

        precision = -np.ones((T, K))
        recall = -np.ones((T, K))

        predictions = self.predictions  # All gathered by this time
        if len(predictions) > 0:
            preds_matched = torch.cat([x[0] for x in predictions], dim=0)
            preds_to_ignore = torch.cat([x[1] for x in predictions], dim=0)
            preds_scores = torch.cat([x[2] for x in predictions], dim=0)
            n_targets = sum([x[3] for x in predictions])

            cls_precision, _, cls_recall = compute_detection_metrics_per_cls(
                preds_matched=preds_matched,
                preds_to_ignore=preds_to_ignore,
                preds_scores=preds_scores,
                n_targets=n_targets,
                recall_thresholds=self.recall_thresholds.to(self.device),
                score_threshold=0,
                device=self.device,
            )

            precision[:, 0] = cls_precision.cpu().numpy()
            recall[:, 0] = cls_recall.cpu().numpy()

        def summarize(s):
            if len(s[s > -1]) == 0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s > -1])
            return mean_s

        metrics = {"AP": summarize(precision), "AR": summarize(recall)}

        if self.iou_thresholds_to_report is not None and len(self.iou_thresholds_to_report):
            for t in self.iou_thresholds_to_report:
                mask = np.where(t == self.iou_thresholds)[0]
                metrics[f"AP_{t:.2f}"] = summarize(precision[mask])
                metrics[f"AR_{t:.2f}"] = summarize(recall[mask])

        return metrics
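A minimal usage sketch of the metric above, modeled on the class docstring and the unit test at the bottom of this PR. The identity post_prediction_callback and the random poses, boxes, and areas are illustrative placeholders; a real callback would decode raw model outputs and run NMS:

# Illustrative only; inputs below are made up and not part of the PR.
import numpy as np
from super_gradients.training.metrics import PoseEstimationMetrics

metric = PoseEstimationMetrics(
    post_prediction_callback=lambda preds: preds,  # predictions assumed already decoded to (poses, scores)
    num_joints=17,
    iou_thresholds_to_report=[0.5, 0.75],
)

predicted_poses = [np.random.rand(3, 17, 3) * 100]   # one image, 3 poses, (x, y, confidence) per joint
predicted_scores = [np.random.rand(3)]               # one confidence score per pose
gt_joints = [np.random.rand(2, 17, 3) * 100]         # 2 ground-truth instances, (x, y, visibility) per joint
gt_bboxes = [np.array([[0.0, 0.0, 100.0, 100.0], [50.0, 50.0, 100.0, 100.0]])]  # XYWH per instance
gt_areas = [np.array([5000.0, 4000.0])]

metric.update(
    preds=(predicted_poses, predicted_scores),
    target=None,
    gt_joints=gt_joints,
    gt_bboxes=gt_bboxes,
    gt_areas=gt_areas,
)
print(metric.compute())  # {"AP": ..., "AR": ..., "AP_0.50": ..., "AR_0.50": ..., "AP_0.75": ..., "AR_0.75": ...}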
from typing import Tuple

import numpy as np
import torch
from torch import Tensor


def compute_visible_bbox_xywh(joints: Tensor, visibility_mask: Tensor) -> np.ndarray:
    """
    Compute the bounding box (X,Y,W,H) of the visible joints for each instance.

    :param joints:          [Num Instances, Num Joints, 2+] last channel must have dimension of
                            at least 2 that is considered to contain (X,Y) coordinates of the keypoint
    :param visibility_mask: [Num Instances, Num Joints]
    :return: A numpy array [Num Instances, 4] where the last dimension contains bbox in format XYWH
    """
    visibility_mask = visibility_mask > 0
    initial_value = 1_000_000

    x1 = torch.min(joints[:, :, 0], where=visibility_mask, initial=initial_value, dim=-1)
    y1 = torch.min(joints[:, :, 1], where=visibility_mask, initial=initial_value, dim=-1)

    x1[x1 == initial_value] = 0
    y1[y1 == initial_value] = 0

    x2 = torch.max(joints[:, :, 0], where=visibility_mask, initial=0, dim=-1)
    y2 = torch.max(joints[:, :, 1], where=visibility_mask, initial=0, dim=-1)

    w = x2 - x1
    h = y2 - y1
    return torch.stack([x1, y1, w, h], dim=-1)


def compute_oks(
    pred_joints: Tensor,
    gt_joints: Tensor,
    gt_keypoint_visibility: Tensor,
    sigmas: Tensor,
    gt_areas: Tensor = None,
    gt_bboxes: Tensor = None,
) -> np.ndarray:
    """
    :param pred_joints:            [K, NumJoints, 2] or [K, NumJoints, 3]
    :param gt_joints:              [M, NumJoints, 2]
    :param gt_keypoint_visibility: [M, NumJoints]
    :param gt_areas:               [M] Area of each ground truth instance. COCOEval uses the area of the instance mask to scale OKS, so it must be provided separately.
                                   If None, we will use the area of the bounding box of each instance computed from gt_joints.
    :param gt_bboxes:              [M, 4] Bounding box (X,Y,W,H) of each ground truth instance. If None, we will use the bounding box of each instance computed from gt_joints.
    :param sigmas:                 [NumJoints]
    :return: IoU matrix [K, M]
    """
    ious = torch.zeros((len(pred_joints), len(gt_joints)), device=pred_joints.device)
    vars = (sigmas * 2) ** 2

    if gt_bboxes is None:
        gt_bboxes = compute_visible_bbox_xywh(gt_joints, gt_keypoint_visibility)

    if gt_areas is None:
        gt_areas = gt_bboxes[:, 2] * gt_bboxes[:, 3]

    # compute oks between each detection and ground truth object
    for gt_index, (gt_keypoints, gt_keypoint_visibility, gt_bbox, gt_area) in enumerate(zip(gt_joints, gt_keypoint_visibility, gt_bboxes, gt_areas)):
        # create bounds for ignore regions (double the gt bbox)
        xg = gt_keypoints[:, 0]
        yg = gt_keypoints[:, 1]
        k1 = torch.count_nonzero(gt_keypoint_visibility > 0)

        x0 = gt_bbox[0] - gt_bbox[2]
        x1 = gt_bbox[0] + gt_bbox[2] * 2
        y0 = gt_bbox[1] - gt_bbox[3]
        y1 = gt_bbox[1] + gt_bbox[3] * 2

        for pred_index, pred_keypoints in enumerate(pred_joints):
            xd = pred_keypoints[:, 0]
            yd = pred_keypoints[:, 1]
            if k1 > 0:
                # measure the per-keypoint distance if keypoints visible
                dx = xd - xg
                dy = yd - yg
            else:
                # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
                dx = (x0 - xd).clamp_min(0) + (xd - x1).clamp_min(0)
                dy = (y0 - yd).clamp_min(0) + (yd - y1).clamp_min(0)

            e = (dx**2 + dy**2) / vars / (gt_area + torch.finfo(torch.float64).eps) / 2
            if k1 > 0:
                e = e[gt_keypoint_visibility > 0]
            ious[pred_index, gt_index] = torch.sum(torch.exp(-e)) / e.shape[0]

    return ious


def compute_img_keypoint_matching(
    preds: Tensor,
    pred_scores: Tensor,
    targets: Tensor,
    targets_visibilities: Tensor,
    targets_areas: Tensor,
    targets_bboxes: Tensor,
    targets_ignored: Tensor,
    crowd_targets: Tensor,
    crowd_visibilities: Tensor,
    crowd_targets_areas: Tensor,
    crowd_targets_bboxes: Tensor,
    iou_thresholds: torch.Tensor,
    sigmas: Tensor,
    top_k: int,
) -> Tuple[Tensor, Tensor, Tensor, int]:
    """
    Match predictions and targets (ground truth) with respect to IoU and confidence score for a given image.

    :param preds:                 Tensor of shape (K, NumJoints, 3) - Array of predicted skeletons.
                                  Last dimension encodes X, Y and confidence score of each joint
    :param pred_scores:           Tensor of shape (K) - Confidence scores for each pose
    :param targets:               Target joints (M, NumJoints, 2) - Array of groundtruth skeletons
    :param targets_visibilities:  Visibility status for each keypoint (M, NumJoints).
                                  Values are 0 - invisible, 1 - occluded, 2 - fully visible
    :param targets_areas:         Tensor of shape (M) - Areas of target objects
    :param targets_bboxes:        Tensor of shape (M, 4) - Bounding boxes (XYWH) of targets
    :param targets_ignored:       Tensor of shape (M) - Array of targets that are marked as ignored
                                  (E.g. all keypoints are not visible or the target does not fit the desired area range)
    :param crowd_targets:         Crowd target joints (Mc, NumJoints, 3) - Array of groundtruth skeletons.
                                  Last dimension encodes X, Y and visibility score of each joint:
                                  (0 - invisible, 1 - occluded, 2 - fully visible)
    :param crowd_visibilities:    Visibility status for each keypoint of crowd targets (Mc, NumJoints).
                                  Values are 0 - invisible, 1 - occluded, 2 - fully visible
    :param crowd_targets_areas:   Tensor of shape (Mc) - Areas of crowd target objects
    :param crowd_targets_bboxes:  Tensor of shape (Mc, 4) - Bounding boxes (XYWH) of crowd targets
    :param iou_thresholds:        IoU thresholds to compute the mAP
    :param sigmas:                Tensor of shape (NumJoints) with sigmas for each joint. A sigma value represents how 'hard'
                                  it is to locate the exact groundtruth position of the joint.
    :param top_k:                 Number of predictions to keep, ordered by confidence score

    :return:
        :preds_matched:   Tensor of shape (min(top_k, len(preds)), n_iou_thresholds)
                          True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
        :preds_to_ignore: Tensor of shape (min(top_k, len(preds)), n_iou_thresholds)
                          True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
        :preds_scores:    Tensor of shape (min(top_k, len(preds))) with scores of top-k predictions
        :num_targets:     Number of groundtruth targets (total number of targets minus the number of ignored ones)
    """
    num_iou_thresholds = len(iou_thresholds)

    device = preds.device if torch.is_tensor(preds) else (targets.device if torch.is_tensor(targets) else "cpu")

    if preds is None or len(preds) == 0:
        preds_matched = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_to_ignore = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_scores = torch.zeros((0,), dtype=torch.float, device=device)
        return preds_matched, preds_to_ignore, preds_scores, len(targets)

    preds_matched = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
    targets_matched = torch.zeros(len(targets), num_iou_thresholds, dtype=torch.bool, device=device)
    preds_to_ignore = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)

    # Ignore all but the top_k predictions
    k = min(top_k, len(pred_scores))
    preds_idx_to_use = torch.topk(pred_scores, k=k, sorted=True, largest=True).indices
    preds_to_ignore[:, :] = True
    preds_to_ignore[preds_idx_to_use] = False

    if len(targets) > 0:
        iou = compute_oks(preds[preds_idx_to_use], targets, targets_visibilities, sigmas, gt_areas=targets_areas, gt_bboxes=targets_bboxes)

        # The matching priority is first detection confidence and then IoU value.
        # The detections are already sorted by confidence in NMS, so here for each prediction we order the targets by IoU.
        sorted_iou, target_sorted = iou.sort(descending=True, stable=True)

        # Only iterate over IoU values higher than the min threshold to speed up the process
        for pred_selected_i, target_sorted_i in (sorted_iou > iou_thresholds[0]).nonzero(as_tuple=False):
            # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
            pred_i = preds_idx_to_use[pred_selected_i]
            target_i = target_sorted[pred_selected_i, target_sorted_i]

            # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
            is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > iou_thresholds

            # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
            are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])

            # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
            are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)

            is_matching_with_ignore = are_candidates_free & are_candidates_good & targets_ignored[target_i]

            if preds_matched[pred_i].any() and is_matching_with_ignore.any():
                continue

            # For every threshold (j) where target_i and pred_i can be matched together (are_candidates_good[j]==True)
            # fill the matching placeholders with True
            targets_matched[target_i, are_candidates_good] = True
            preds_matched[pred_i, are_candidates_good] = True

            preds_to_ignore[pred_i] = torch.logical_or(preds_to_ignore[pred_i], is_matching_with_ignore)

            # When all the targets are matched with a prediction for every IoU threshold, stop.
            if targets_matched.all():
                break

    # Crowd targets can be matched with many predictions.
    # Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.
    if len(crowd_targets) > 0:
        # shape = (n_preds_to_use x n_crowd_targets)
        ioa = compute_oks(
            preds[preds_idx_to_use],
            crowd_targets,
            crowd_visibilities,
            sigmas,
            gt_areas=crowd_targets_areas,
            gt_bboxes=crowd_targets_bboxes,
        )

        # For each prediction, we keep its highest score with any crowd target (of the same class)
        # shape = (n_preds_to_use)
        best_ioa, _ = ioa.max(1)

        # If a prediction has IoA higher than the threshold (with any target of the same class), then there is a match
        # shape = (n_preds_to_use x iou_thresholds)
        is_matching_with_crowd = best_ioa.view(-1, 1) > iou_thresholds.view(1, -1)

        preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

    num_targets = len(targets) - torch.count_nonzero(targets_ignored)
    return preds_matched[preds_idx_to_use], preds_to_ignore[preds_idx_to_use], pred_scores[preds_idx_to_use], num_targets.item()
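compute_oks above mirrors the COCO object-keypoint similarity: squared joint distances are scaled by per-joint sigmas and the instance area, passed through exp(-e), and averaged over visible joints. A toy sketch of calling it directly (all values below are made up; bboxes and areas are passed explicitly):

# Toy example, not part of the PR; sigmas, coordinates, bboxes and areas are invented.
import torch
from super_gradients.training.metrics.pose_estimation_utils import compute_oks

num_joints = 17
sigmas = torch.full((num_joints,), 0.05)                    # per-joint "hardness" factors
gt_joints = torch.rand(1, num_joints, 2) * 100              # one ground-truth skeleton in pixel coords
gt_visibility = torch.ones(1, num_joints)                   # all joints visible
pred_joints = gt_joints + torch.randn(1, num_joints, 2)     # prediction = ground truth + small noise
gt_bboxes = torch.tensor([[0.0, 0.0, 100.0, 100.0]])        # XYWH
gt_areas = torch.tensor([10000.0])

oks = compute_oks(pred_joints, gt_joints, gt_visibility, sigmas, gt_areas=gt_areas, gt_bboxes=gt_bboxes)
print(oks)  # [K, M] similarity matrix, here 1x1 with a value close to 1.0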
This file is too large to be shown
@@ -57,6 +57,7 @@ from tests.unit_tests.config_inspector_test import ConfigInspectTest
 from tests.unit_tests.repvgg_block_tests import TestRepVGGBlock
 from tests.unit_tests.training_utils_test import TestTrainingUtils
 from tests.unit_tests.dekr_loss_test import DEKRLossTest
+from tests.unit_tests.pose_estimation_metrics_test import TestPoseEstimationMetrics


 class CoreUnitTestSuiteRunner:
@@ -125,6 +126,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestTransforms))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPPYOLOE))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DEKRLossTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationMetrics))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
import collections
import os.path
import random
import tempfile
import unittest
from pprint import pprint
from typing import List, Tuple

import json_tricks as json
import numpy as np
import torch.cuda
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from super_gradients.training.datasets.pose_estimation_datasets.coco_utils import (
    remove_duplicate_annotations,
    make_keypoints_outside_image_invisible,
    remove_crowd_annotations,
)
from super_gradients.training.metrics.pose_estimation_metrics import PoseEstimationMetrics


class TestPoseEstimationMetrics(unittest.TestCase):
    def _load_coco_groundtruth(self, with_crowd: bool, with_duplicates: bool, with_invisible_keypoitns: bool):
        gt_annotations_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data/coco2017/annotations/person_keypoints_val2017.json")
        assert os.path.isfile(gt_annotations_path)

        gt = COCO(gt_annotations_path)
        if not with_duplicates:
            gt = remove_duplicate_annotations(gt)

        if not with_invisible_keypoitns:
            gt = make_keypoints_outside_image_invisible(gt)

        if not with_crowd:
            gt = remove_crowd_annotations(gt)

        return gt

    def _internal_compare_method(self, with_crowd: bool, with_duplicates: bool, with_invisible_keypoitns: bool, device: str):
        random.seed(0)
        np.random.seed(0)

        # Load groundtruth annotations
        gt = self._load_coco_groundtruth(with_crowd, with_duplicates, with_invisible_keypoitns)

        # Generate predictions by randomly dropping some instances and adding noise to the remaining poses
        (
            predicted_poses,
            predicted_scores,
            groundtruths_poses,
            groundtruths_iscrowd,
            groundtruths_areas,
            groundtruths_bboxes,
            image_ids,
        ) = self.generate_noised_predictions(gt, instance_drop_probability=0.1, pose_offset=1)

        # Compute metrics using SG implementation
        def convert_predictions_to_target_format(preds):
            # This is our predictions decode function. Here it's a no-op since we pass decoded predictions as the input,
            # but in real life this post-processing callback should be doing actual pose decoding & NMS
            return preds

        sg_metrics = PoseEstimationMetrics(
            post_prediction_callback=convert_predictions_to_target_format,
            num_joints=17,
            max_objects_per_image=20,
            iou_thresholds_to_report=(0.5, 0.75),
        ).to(device)

        sg_metrics.update(
            preds=(predicted_poses, predicted_scores),
            target=None,
            gt_joints=groundtruths_poses,
            gt_iscrowd=groundtruths_iscrowd,
            gt_areas=groundtruths_areas,
            gt_bboxes=groundtruths_bboxes,
        )

        actual_metrics = sg_metrics.compute()
        pprint(actual_metrics)

        coco_pred = self._coco_convert_predictions_to_dict(predicted_poses, predicted_scores, image_ids)

        with tempfile.TemporaryDirectory() as td:
            res_file = os.path.join(td, "keypoints_coco2017_results.json")

            with open(res_file, "w") as f:
                json.dump(coco_pred, f, sort_keys=True, indent=4)

            coco_dt = self._load_coco_groundtruth(with_crowd, with_duplicates, with_invisible_keypoitns)
            coco_dt = coco_dt.loadRes(res_file)

            coco_evaluator = COCOeval(gt, coco_dt, iouType="keypoints")
            coco_evaluator.evaluate()  # run per image evaluation
            coco_evaluator.accumulate()  # accumulate per image results
            coco_evaluator.summarize()  # display summary metrics of results

            expected_metrics = coco_evaluator.stats

        self.assertAlmostEquals(expected_metrics[0], actual_metrics["AP"], delta=0.002)
        self.assertAlmostEquals(expected_metrics[5], actual_metrics["AR"], delta=0.002)

    def test_compare_pycocotools_with_our_implementation_no_crowd(self):
        for device in ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]:
            self._internal_compare_method(False, True, True, device)

    def test_compare_pycocotools_with_our_implementation_no_duplicates(self):
        for device in ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]:
            self._internal_compare_method(True, False, True, device)

    def test_compare_pycocotools_with_our_implementation_no_invisible(self):
        for device in ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]:
            self._internal_compare_method(True, True, False, device)

    def test_metric_works_on_empty_predictions(self):
        # Compute metrics using SG implementation
        def convert_predictions_to_target_format(preds):
            # This is our predictions decode function. Here it's a no-op since we pass decoded predictions as the input,
            # but in real life this post-processing callback should be doing actual pose decoding & NMS
            return preds

        sg_metrics = PoseEstimationMetrics(
            post_prediction_callback=convert_predictions_to_target_format,
            num_joints=17,
            max_objects_per_image=20,
            iou_thresholds=None,
            oks_sigmas=None,
        )

        actual_metrics = sg_metrics.compute()
        pprint(actual_metrics)

        self.assertEqual(-1, actual_metrics["AP"])
        self.assertEqual(-1, actual_metrics["AR"])

    def generate_noised_predictions(self, coco: COCO, instance_drop_probability: float, pose_offset: float) -> Tuple[List, List, List]:
        """
        :param coco:
        :return: List of tuples (poses, image_id)
        """
        image_ids = []
        predicted_poses = []
        predicted_scores = []

        groundtruths_poses = []
        groundtruths_iscrowd = []
        groundtruths_areas = []
        groundtruths_bboxes = []

        for image_id, image_info in coco.imgs.items():
            image_id_int = int(image_id)
            image_width = image_info["width"]
            image_height = image_info["height"]

            ann_ids = coco.getAnnIds(imgIds=image_id_int)
            anns = coco.loadAnns(ann_ids)

            image_pred_keypoints = []
            image_gt_keypoints = []
            image_gt_iscrowd = []
            image_gt_areas = []
            image_gt_bboxes = []

            for ann in anns:
                gt_keypoints = np.array(ann["keypoints"]).reshape(-1, 3).astype(np.float32)
                image_gt_keypoints.append(gt_keypoints)
                image_gt_iscrowd.append(ann["iscrowd"])
                image_gt_areas.append(ann["area"])
                image_gt_bboxes.append(ann["bbox"])

                if np.random.rand() < instance_drop_probability:
                    continue

                keypoints = gt_keypoints.copy()
                if pose_offset > 0:
                    keypoints[:, 0] += (2 * np.random.randn() - 1) * pose_offset
                    keypoints[:, 1] += (2 * np.random.randn() - 1) * pose_offset
                    keypoints[:, 0] = np.clip(keypoints[:, 0], 0, image_width)
                    keypoints[:, 1] = np.clip(keypoints[:, 1], 0, image_height)

                # Apply random score for visible keypoints
                keypoints[:, 2] = (keypoints[:, 2] > 0) * np.random.randn(len(keypoints))
                image_pred_keypoints.append(keypoints)

            image_ids.append(image_id_int)
            predicted_poses.append(image_pred_keypoints)
            predicted_scores.append(np.random.rand(len(image_pred_keypoints)))

            groundtruths_poses.append(image_gt_keypoints)
            groundtruths_iscrowd.append(np.array(image_gt_iscrowd, dtype=bool))
            groundtruths_areas.append(np.array(image_gt_areas))
            groundtruths_bboxes.append(np.array(image_gt_bboxes))

        return predicted_poses, predicted_scores, groundtruths_poses, groundtruths_iscrowd, groundtruths_areas, groundtruths_bboxes, image_ids

    def _coco_convert_predictions_to_dict(self, predicted_poses, predicted_scores, image_ids):
        kpts = collections.defaultdict(list)
        for poses, scores, image_id_int in zip(predicted_poses, predicted_scores, image_ids):
            for person_index, kpt in enumerate(poses):
                area = (np.max(kpt[:, 0]) - np.min(kpt[:, 0])) * (np.max(kpt[:, 1]) - np.min(kpt[:, 1]))
                kpt = self._coco_process_keypoints(kpt)
                kpts[image_id_int].append({"keypoints": kpt[:, 0:3], "score": float(scores[person_index]), "image": image_id_int, "area": area})

        oks_nmsed_kpts = []
        # image x person x (keypoints)
        for img in kpts.keys():
            # person x (keypoints)
            img_kpts = kpts[img]
            # do not use nms, keep all detections
            keep = []
            if len(keep) == 0:
                oks_nmsed_kpts.append(img_kpts)
            else:
                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])

        classes = ["__background__", "person"]
        _class_to_coco_ind = {cls: i for i, cls in enumerate(classes)}

        data_pack = [
            {"cat_id": _class_to_coco_ind[cls], "cls_ind": cls_ind, "cls": cls, "ann_type": "keypoints", "keypoints": oks_nmsed_kpts}
            for cls_ind, cls in enumerate(classes)
            if not cls == "__background__"
        ]

        results = self._coco_keypoint_results_one_category_kernel(data_pack[0], num_joints=17)
        return results

    def _coco_keypoint_results_one_category_kernel(self, data_pack, num_joints: int):
        cat_id = data_pack["cat_id"]
        keypoints = data_pack["keypoints"]
        cat_results = []

        for img_kpts in keypoints:
            if len(img_kpts) == 0:
                continue

            _key_points = np.array([img_kpts[k]["keypoints"] for k in range(len(img_kpts))])
            key_points = np.zeros((_key_points.shape[0], num_joints * 3), dtype=np.float32)

            for ipt in range(num_joints):
                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
                # keypoints score.
                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]

            for k in range(len(img_kpts)):
                kpt = key_points[k].reshape((num_joints, 3))
                left_top = np.amin(kpt, axis=0)
                right_bottom = np.amax(kpt, axis=0)

                w = right_bottom[0] - left_top[0]
                h = right_bottom[1] - left_top[1]

                cat_results.append(
                    {
                        "image_id": img_kpts[k]["image"],
                        "category_id": cat_id,
                        "keypoints": list(key_points[k]),
                        "score": img_kpts[k]["score"],
                        "bbox": list([left_top[0], left_top[1], w, h]),
                    }
                )
        return cat_results

    def _coco_process_keypoints(self, keypoints):
        tmp = keypoints.copy()
        if keypoints[:, 2].max() > 0:
            num_keypoints = keypoints.shape[0]
            for i in range(num_keypoints):
                tmp[i][0:3] = [float(keypoints[i][0]), float(keypoints[i][1]), float(keypoints[i][2])]
        return tmp


if __name__ == "__main__":
    unittest.main()