|
@@ -84,13 +84,12 @@ class DetectionMetrics(Metric):
|
|
Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.
|
|
Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.
|
|
|
|
|
|
:param preds : Raw output of the model, the format might change from one model to another, but has to fit
|
|
:param preds : Raw output of the model, the format might change from one model to another, but has to fit
|
|
- the input format of the post_prediction_callback
|
|
|
|
- :param target: Targets for all images of shape (total_num_targets, 6)
|
|
|
|
- format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
|
|
|
|
|
|
+ the input format of the post_prediction_callback (cx,cy,wh)
|
|
|
|
+ :param target: Targets for all images of shape (total_num_targets, 6) LABEL_CXCYWH
|
|
|
|
+ format: (index, label, cx, cy, w, h)
|
|
:param device: Device to run on
|
|
:param device: Device to run on
|
|
:param inputs: Input image tensor of shape (batch_size, n_img, height, width)
|
|
:param inputs: Input image tensor of shape (batch_size, n_img, height, width)
|
|
- :param crowd_targets: Crowd targets for all images of shape (total_num_targets, 6)
|
|
|
|
- format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
|
|
|
|
|
|
+ :param crowd_targets: Crowd targets for all images of shape (total_num_targets, 6), LABEL_CXCYWH
|
|
"""
|
|
"""
|
|
self.iou_thresholds = self.iou_thresholds.to(device)
|
|
self.iou_thresholds = self.iou_thresholds.to(device)
|
|
_, _, height, width = inputs.shape
|
|
_, _, height, width = inputs.shape
|