Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_metrics.py 7.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
  1. import numpy as np
  2. import torch
  3. from torchmetrics import Metric
  4. from super_gradients.training.utils.detection_utils import calc_batch_prediction_accuracy, DetectionPostPredictionCallback, \
  5. IouThreshold
  6. import super_gradients
  7. def compute_ap(recall, precision, method: str = 'interp'):
  8. """ Compute the average precision, given the recall and precision curves.
  9. Source: https://github.com/rbgirshick/py-faster-rcnn.
  10. # Arguments
  11. :param recall: The recall curve - ndarray [1, points in curve]
  12. :param precision: The precision curve - ndarray [1, points in curve]
  13. :param method: 'continuous', 'interp'
  14. # Returns
  15. The average precision as computed in py-faster-rcnn.
  16. """
  17. # IN ORDER TO CALCULATE, WE HAVE TO MAKE SURE THE CURVES GO ALL THE WAY TO THE AXES (FROM X=0 TO Y=0)
  18. # THIS IS HOW IT IS COMPUTED IN ORIGINAL REPO - A MORE CORRECT COMPUTE WOULD BE ([0.], recall, [recall[-1] + 1E-3])
  19. wrapped_recall = np.concatenate(([0.], recall, [1.0]))
  20. wrapped_precision = np.concatenate(([1.], precision, [0.]))
  21. # COMPUTE THE PRECISION ENVELOPE
  22. wrapped_precision = np.flip(np.maximum.accumulate(np.flip(wrapped_precision)))
  23. # INTEGRATE AREA UNDER CURVE
  24. if method == 'interp':
  25. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  26. ap = np.trapz(np.interp(x, wrapped_recall, wrapped_precision), x) # integrate
  27. else: # 'continuous'
  28. i = np.where(wrapped_recall[1:] != wrapped_recall[:-1])[0] # points where x axis (recall) changes
  29. ap = np.sum((wrapped_recall[i + 1] - wrapped_recall[i]) * wrapped_precision[i + 1]) # area under curve
  30. return ap
  31. def ap_per_class(tp, conf, pred_cls, target_cls):
  32. """ Compute the average precision, given the recall and precision curves.
  33. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  34. # Arguments
  35. tp: True positives (nparray, nx1 or nx10).
  36. conf: Objectness value from 0-1 (nparray).
  37. pred_cls: Predicted object classes (nparray).
  38. target_cls: True object classes (nparray).
  39. # Returns
  40. The average precision as computed in py-faster-rcnn.
  41. """
  42. # SORT BY OBJECTNESS
  43. i = np.argsort(-conf)
  44. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  45. # FIND UNIQUE CLASSES
  46. unique_classes = np.unique(target_cls)
  47. # CREATE PRECISION-RECALL CURVE AND COMPUTE AP FOR EACH CLASS
  48. pr_score = 0.1 # SCORE TO EVALUATE P AND R https://github.com/ultralytics/yolov3/issues/898
  49. s = [unique_classes.shape[0], tp.shape[1]] # NUMBER CLASS, NUMBER IOU THRESHOLDS (I.E. 10 FOR MAP0.5...0.95)
  50. ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
  51. for ci, c in enumerate(unique_classes):
  52. i = pred_cls == c
  53. ground_truth_num = (target_cls == c).sum() # NUMBER OF GROUND TRUTH OBJECTS
  54. predictions_num = i.sum() # NUMBER OF PREDICTED OBJECTS
  55. if predictions_num == 0 or ground_truth_num == 0:
  56. continue
  57. else:
  58. # ACCUMULATE FPS AND TPS
  59. fpc = (1 - tp[i]).cumsum(0)
  60. tpc = tp[i].cumsum(0)
  61. # RECALL
  62. recall = tpc / (ground_truth_num + 1e-16) # RECALL CURVE
  63. r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # R AT PR_SCORE, NEGATIVE X, XP BECAUSE XP DECREASES
  64. # PRECISION
  65. precision = tpc / (tpc + fpc) # precision curve
  66. p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # P AT PR_SCORE
  67. # AP FROM RECALL-PRECISION CURVE
  68. for j in range(tp.shape[1]):
  69. ap[ci, j] = compute_ap(recall[:, j], precision[:, j])
  70. # COMPUTE F1 SCORE (HARMONIC MEAN OF PRECISION AND RECALL)
  71. f1 = 2 * p * r / (p + r + 1e-16)
  72. return p, r, ap, f1, unique_classes.astype('int32')
  73. class DetectionMetrics(Metric):
  74. def __init__(self, num_cls,
  75. post_prediction_callback: DetectionPostPredictionCallback = None,
  76. iou_thres: IouThreshold = IouThreshold.MAP_05_TO_095,
  77. dist_sync_on_step=False):
  78. """
  79. @param post_prediction_callback:
  80. @param iou_thres:
  81. @param dist_sync_on_step:
  82. """
  83. super().__init__(dist_sync_on_step=dist_sync_on_step)
  84. self.num_cls = num_cls
  85. self.iou_thres = iou_thres
  86. self.map_str = 'mAP@%.1f' % iou_thres[0] if not iou_thres.is_range() else 'mAP@%.2f:%.2f' % iou_thres
  87. self.component_names = ["Precision", "Recall", self.map_str, "F1"]
  88. self.components = len(self.component_names)
  89. self.post_prediction_callback = post_prediction_callback
  90. self.is_distributed = super_gradients.is_distributed()
  91. self.world_size = None
  92. self.rank = None
  93. self.add_state("metrics", default=[], dist_reduce_fx=None)
  94. def update(self, preds: torch.Tensor, target: torch.Tensor, device, inputs):
  95. preds = self.post_prediction_callback(preds, device=device)
  96. _, _, height, width = inputs.shape
  97. metrics, batch_images_counter = calc_batch_prediction_accuracy(preds, target, height, width,
  98. self.iou_thres)
  99. acc_metrics = getattr(self, "metrics")
  100. setattr(self, "metrics", acc_metrics + metrics)
  101. def compute(self):
  102. precision, recall, f1, mean_precision, mean_recall, mean_ap, mf1 = 0., 0., 0., 0., 0., 0., 0.
  103. metrics = getattr(self, "metrics")
  104. metrics = [np.concatenate(x, 0) for x in list(zip(*metrics))]
  105. if len(metrics):
  106. precision, recall, average_precision, f1, ap_class = ap_per_class(*metrics)
  107. if self.iou_thres.is_range():
  108. precision, recall, average_precision, f1 = precision[:, 0], recall[:, 0], average_precision.mean(
  109. 1), average_precision[:, 0]
  110. mean_precision, mean_recall, mean_ap, mf1 = precision.mean(), recall.mean(), average_precision.mean(), f1.mean()
  111. return {"Precision": mean_precision, "Recall": mean_recall, self.map_str: mean_ap, "F1": mf1}
  112. def _sync_dist(self, dist_sync_fn=None, process_group=None):
  113. """
  114. When in distributed mode, stats are aggregated after each forward pass to the metric state. Since these have all
  115. different sizes we override the synchronization function since it works only for tensors (and use
  116. all_gather_object)
  117. @param dist_sync_fn:
  118. @return:
  119. """
  120. if self.world_size is None:
  121. self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
  122. if self.rank is None:
  123. self.rank = torch.distributed.get_rank() if self.is_distributed else -1
  124. if self.is_distributed:
  125. local_state_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
  126. gathered_state_dicts = [None] * self.world_size
  127. torch.distributed.barrier()
  128. torch.distributed.all_gather_object(gathered_state_dicts, local_state_dict)
  129. metrics = []
  130. for state_dict in gathered_state_dicts:
  131. metrics += state_dict["metrics"]
  132. setattr(self, "metrics", metrics)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...