Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_v3_loss.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. import torch
  2. from torch import nn
  3. from torch.nn.modules.loss import _Loss
  4. from super_gradients.training.utils.detection_utils import build_detection_targets, calculate_bbox_iou_elementwise
  5. class YoLoV3DetectionLoss(_Loss):
  6. """
  7. YoLoV3DetectionLoss - Loss Class for Object Detection
  8. """
  9. def __init__(self, model: nn.Module, cls_pw: float = 1., obj_pw: float = 1., giou: float = 3.54, obj: float = 64.3,
  10. cls: float = 37.4):
  11. super(YoLoV3DetectionLoss, self).__init__()
  12. self.model = model
  13. self.cls_pw = cls_pw
  14. self.obj_pw = obj_pw
  15. self.giou = giou
  16. self.obj = obj
  17. self.cls = cls
  18. self.classes_num = self.model.net.module.num_classes
  19. def forward(self, model_output, targets):
  20. if isinstance(model_output, tuple) and len(model_output) == 2:
  21. # in test/eval mode the Yolo v3 model output a tuple where the second item is the raw predictions
  22. _, predictions = model_output
  23. else:
  24. predictions = model_output
  25. detection_targets = build_detection_targets(self.model.net.module, targets)
  26. float_tensor = torch.cuda.FloatTensor if predictions[0].is_cuda else torch.Tensor
  27. class_loss, giou_loss, objectness_loss = float_tensor([0]), float_tensor([0]), float_tensor([0])
  28. target_class, target_box, indices, anchor_vec = detection_targets
  29. reduction = 'mean' # Loss reduction (sum or mean)
  30. # DEFINE CRITERIA
  31. BCEcls = nn.BCEWithLogitsLoss(pos_weight=float_tensor([self.cls_pw]), reduction=reduction)
  32. BCEobj = nn.BCEWithLogitsLoss(pos_weight=float_tensor([self.obj_pw]), reduction=reduction)
  33. # COMPUTE THE LOSSES BASED ON EACH ONE OF THE YOLO LAYERS PREDICTIONS
  34. grid_points_num, targets_num = 0, 0
  35. for yolo_layer_index, yolo_layer_prediction in enumerate(predictions):
  36. image, anchor, grid_y, grid_x = indices[yolo_layer_index]
  37. target_object = torch.zeros_like(yolo_layer_prediction[..., 0])
  38. grid_points_num += target_object.numel()
  39. # COMPUTE LOSSES
  40. nb = len(image)
  41. if nb: # number of targets
  42. targets_num += nb
  43. predictions_for_targets = yolo_layer_prediction[image, anchor, grid_y, grid_x]
  44. target_object[image, anchor, grid_y, grid_x] = 1.0
  45. # GIoU LOSS CALCULATION
  46. pxy = torch.sigmoid(
  47. predictions_for_targets[:, 0:2]) # pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy)
  48. bbox_prediction = torch.cat(
  49. (pxy, torch.exp(predictions_for_targets[:, 2:4]).clamp(max=1E3) * anchor_vec[yolo_layer_index]), 1)
  50. giou = 1.0 - calculate_bbox_iou_elementwise(bbox_prediction.t(), target_box[yolo_layer_index],
  51. x1y1x2y2=False, GIoU=True)
  52. giou_loss += giou.sum() if reduction == 'sum' else giou.mean()
  53. # ONLY RELEVANT TO MULTIPLE CLASSES
  54. if self.classes_num > 1:
  55. class_targets = torch.zeros_like(predictions_for_targets[:, 5:])
  56. class_targets[range(nb), target_class[yolo_layer_index]] = 1.0
  57. class_loss += BCEcls(predictions_for_targets[:, 5:], class_targets)
  58. objectness_loss += BCEobj(yolo_layer_prediction[..., 4], target_object)
  59. if reduction == 'sum':
  60. giou_loss *= 3 / targets_num
  61. objectness_loss *= 3 / grid_points_num
  62. class_loss *= 3 / targets_num / self.classes_num
  63. loss = giou_loss * self.giou + objectness_loss * self.obj + class_loss * self.cls
  64. return loss, torch.cat((giou_loss, objectness_loss, class_loss, loss)).detach()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...