yolo_v5_loss.py

from typing import List, Tuple, Union

import torch
from torch import nn
from torch.nn.modules.loss import _Loss

from super_gradients.training.losses.focal_loss import FocalLoss
from super_gradients.training.utils.detection_utils import calculate_bbox_iou_elementwise, Anchors


class YoLoV5DetectionLoss(_Loss):
    """
    Calculate YOLO V5 loss:
    L = L_objectness + L_boxes + L_classification
    """

    def __init__(self, anchors: Anchors,
                 cls_pos_weight: Union[float, List[float]] = 1.0, obj_pos_weight: float = 1.0,
                 obj_loss_gain: float = 1.0, box_loss_gain: float = 0.05, cls_loss_gain: float = 0.5,
                 focal_loss_gamma: float = 0.0,
                 cls_objectness_weights: Union[List[float], torch.Tensor] = None):
        """
        :param anchors:                 the anchors of the model (same anchors used for training)
        :param cls_pos_weight:          pos_weight for BCE in L_classification,
                                        can be one value for all positives or a list of weights for each class
        :param obj_pos_weight:          pos_weight for BCE in L_objectness
        :param obj_loss_gain:           coefficient for L_objectness
        :param box_loss_gain:           coefficient for L_boxes
        :param cls_loss_gain:           coefficient for L_classification
        :param focal_loss_gamma:        gamma for a focal loss, 0 to train with a usual BCE
        :param cls_objectness_weights:  class-based weights for L_objectness that will be applied in each cell that
                                        has a GT assigned to it.
                                        NOTE: the default objectness loss weight in each cell is 1.
        """
        super(YoLoV5DetectionLoss, self).__init__()

        self.cls_pos_weight = cls_pos_weight
        self.obj_pos_weight = obj_pos_weight
        self.obj_loss_gain = obj_loss_gain
        self.box_loss_gain = box_loss_gain
        self.cls_loss_gain = cls_loss_gain
        self.focal_loss_gamma = focal_loss_gamma
        self.anchors = anchors

        self.cls_obj_weights = cls_objectness_weights
        if isinstance(cls_objectness_weights, list):
            self.cls_obj_weights = torch.nn.Parameter(torch.tensor(cls_objectness_weights))
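
    # Illustrative note on cls_objectness_weights (the values below are made up, not defaults of any recipe):
    # passing cls_objectness_weights=[1.0, 2.0, 0.5] for a 3-class model makes the objectness BCE of cells whose
    # assigned GT is class 1 count twice as much and cells assigned class 2 count half as much, while cells with
    # no GT assigned keep the default weight of 1 (see the weighted objectness sum in compute_loss below).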

    def forward(self, model_output, targets):
        if isinstance(model_output, tuple) and len(model_output) == 2:
            # in test/eval mode the Yolo V5 model outputs a tuple where the second item is the raw predictions
            _, predictions = model_output
        else:
            predictions = model_output

        return self.compute_loss(predictions, targets)

    def build_targets(self, predictions: List[torch.Tensor], targets: torch.Tensor, anchor_threshold=4.0) \
            -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Tuple[torch.Tensor]], List[torch.Tensor]]:
        """
        Assign targets to anchors to use in L_boxes & L_classification calculation:
            * each target can be assigned to a few anchors,
              all anchors that are within [1/anchor_threshold, anchor_threshold] times target size range
            * each anchor can be assigned to a few targets

        :param predictions:       Yolo predictions
        :param targets:           ground truth targets
        :param anchor_threshold:  ratio defining a size range of an appropriate anchor
        :return:                  each of the 4 outputs contains one element for each Yolo output,
                                  correspondences are raveled over the whole batch and all anchors:
                                    * classes of the targets;
                                    * boxes of the targets;
                                    * image id in a batch, anchor id, grid y, grid x coordinates;
                                    * anchor sizes.
                                  All the above can be indexed in parallel to get the selected correspondences
        """
        num_anchors, num_targets = self.anchors.num_anchors, targets.shape[0]
        target_classes, target_boxes, indices, anchors = [], [], [], []
        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain

        anchor_indices = torch.arange(num_anchors, device=targets.device)
        anchor_indices = anchor_indices.float().view(num_anchors, 1).repeat(1, num_targets)
        # repeat all targets for each anchor and append a corresponding anchor index
        targets = torch.cat((targets.repeat(num_anchors, 1, 1), anchor_indices[:, :, None]), 2)

        bias = 0.5
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            ], device=targets.device).float() * bias  # offsets

        for i in range(self.anchors.detection_layers_num):
            anch = self.anchors.anchors[i]
            gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]  # xyxy gain

            # Convert target coordinates from [0, 1] range to coordinates in [0, GridY], [0, GridX] ranges
            t = targets * gain
            if num_targets:
                # Match: filter targets by anchor size ratio
                r = t[:, :, 4:6] / anch[:, None]  # wh ratio
                filtered_targets_ids = torch.max(r, 1. / r).max(2)[0] < anchor_threshold  # compare
                t = t[filtered_targets_ids]

                # Find coordinates of targets on a grid
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < bias) & (gxy > 1.)).T
                l, m = ((gxi % 1. < bias) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # prevent coordinates from going out of bounds
            gi, gj = gi.clamp_(0, gain[2] - 1), gj.clamp_(0, gain[3] - 1)

            # Append
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj, gi))  # image, anchor, grid indices
            target_boxes.append(torch.cat((gxy - gij, gwh), 1))  # box
            anchors.append(anch[a])  # anchors
            target_classes.append(c)  # class

        return target_classes, target_boxes, indices, anchors
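
    # A small worked example of the anchor matching rule above (numbers are illustrative, not from any real config):
    # with anchor_threshold=4.0, a target of size (w=30, h=40) in grid units checked against an anchor of (10, 13)
    # gives wh ratios r = (3.0, ~3.08), so max(r, 1/r).max() = ~3.08 < 4.0 and the anchor keeps the target;
    # an anchor of (5, 6) gives a max ratio of 40 / 6 = ~6.67 >= 4.0, so the target is filtered out for it.
    # Each surviving target is then duplicated into up to two neighbouring cells via the `off` offsets above,
    # so a single GT box can supervise several (cell, anchor) pairs.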

    def compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor, giou_loss_ratio: float = 1.0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        L = L_objectness + L_boxes + L_classification
        where:
            * L_boxes and L_classification are calculated only between anchors and targets that suit them;
            * L_objectness is calculated on all anchors.

        L_classification:
            for anchors that have suitable ground truths in their grid locations add BCEs
            to force max probability for each GT class in a multi-label way
            Coef: self.cls_loss_gain
        L_boxes:
            for anchors that have suitable ground truths in their grid locations
            add (1 - IoU), IoU between a predicted box and each GT box, force maximum IoU
            Coef: self.box_loss_gain
        L_objectness:
            for each anchor add BCE to force a prediction of (1 - giou_loss_ratio) + giou_loss_ratio * IoU,
            IoU between a predicted box and a random GT in it
            Coef: self.obj_loss_gain, loss from each YOLO grid is additionally multiplied by balance = [4.0, 1.0, 0.4]
            to balance different contributions coming from different numbers of grid cells

        :param predictions:      output from all Yolo levels, each of shape
                                 [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]
        :param targets:          [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h
        :param giou_loss_ratio:  a coefficient in L_objectness defining what should be predicted as objectness
                                 in a cell with a target: can be a value in the [IoU, 1] range
        :return:                 loss, all losses separately in a detached tensor
        """
        device = targets.device
        loss_classification, loss_boxes, loss_objectivness = \
            torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        target_classes, target_boxes, indices, anchors = self.build_targets(predictions, targets)  # targets

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.cls_pos_weight])).to(device)
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.obj_pos_weight]), reduction='none').to(device)

        # Focal loss
        if self.focal_loss_gamma > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, self.focal_loss_gamma), FocalLoss(BCEobj, self.focal_loss_gamma)

        # Losses
        num_targets = 0
        num_predictions = len(predictions)
        balance = [4.0, 1.0, 0.4] if num_predictions == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6
        for i, prediction in enumerate(predictions):  # layer index, layer predictions
            image, anchor, grid_y, grid_x = indices[i]
            target_obj = torch.zeros_like(prediction[..., 0], device=device)
            weight_obj = torch.ones_like(prediction[..., 0], device=device)

            n = image.shape[0]  # number of targets
            if n:
                num_targets += n  # cumulative targets
                ps = prediction[image, anchor, grid_y, grid_x]  # prediction subset corresponding to targets

                # Boxes loss
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
                iou = calculate_bbox_iou_elementwise(pbox.T, target_boxes[i], x1y1x2y2=False, CIoU=True)
                loss_boxes += (1.0 - iou).mean()  # iou loss

                # Objectness loss target
                target_obj[image, anchor, grid_y, grid_x] = \
                    (1.0 - giou_loss_ratio) + giou_loss_ratio * iou.detach().clamp(0).type(target_obj.dtype)

                # Weights for weighted objectness
                if self.cls_obj_weights is not None:
                    # NOTE: for grid cells that have a few ground truths with different classes assigned to them
                    # the objectness weight will be picked randomly from one of these classes
                    weight_obj[image, anchor, grid_y, grid_x] = self.cls_obj_weights[target_classes[i]]

                # Classification loss
                if ps.shape[1] > 6:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], 0, device=device)  # targets
                    t[range(n), target_classes[i]] = 1
                    loss_classification += BCEcls(ps[:, 5:], t)  # BCE

            # Objectness loss
            loss_obj_cur_head = BCEobj(prediction[..., 4], target_obj)
            loss_obj_cur_head = torch.sum(loss_obj_cur_head * weight_obj / torch.sum(weight_obj))
            loss_objectivness += loss_obj_cur_head * balance[i]  # obj loss

        batch_size = prediction.shape[0]  # batch size
        loss = loss_boxes * self.box_loss_gain + loss_objectivness * self.obj_loss_gain \
            + loss_classification * self.cls_loss_gain
        # IMPORTANT: box, obj and cls losses are logged scaled by their gains in ultralytics
        # and are logged unscaled in our codebase
        return loss * batch_size, torch.cat((loss_boxes, loss_objectivness, loss_classification, loss)).detach()
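
A minimal usage sketch (not part of the file above): it assumes the Anchors helper accepts a nested list of per-level anchor sizes plus per-level strides, which matches common SuperGradients versions but should be checked against your installed release; the prediction shapes and the 80-class count are made-up placeholders that follow the compute_loss docstring.

import torch
from super_gradients.training.utils.detection_utils import Anchors

# assumption: Anchors(anchors_list, strides) -- verify the signature in your SuperGradients version
anchors = Anchors([[10, 13, 16, 30, 33, 23],
                   [30, 61, 62, 45, 59, 119],
                   [116, 90, 156, 198, 373, 326]], [8, 16, 32])
criterion = YoLoV5DetectionLoss(anchors, box_loss_gain=0.05, cls_loss_gain=0.5, obj_loss_gain=1.0)

batch_size, num_classes = 2, 80
# raw predictions: one tensor per detection level, [Batch x Num_Anchors x GridY x GridX x (4 + 1 + Num_classes)];
# requires_grad stands in for real network outputs so that loss.backward() has something to propagate to
predictions = [torch.randn(batch_size, 3, s, s, 5 + num_classes, requires_grad=True) for s in (80, 40, 20)]
# targets: [image id in a batch, class, x, y, w, h] with box coordinates normalized to [0, 1]
targets = torch.tensor([[0, 17, 0.50, 0.40, 0.20, 0.30],
                        [1, 3, 0.25, 0.60, 0.10, 0.15]])

loss, loss_items = criterion(predictions, targets)  # loss is already multiplied by the batch size
loss.backward()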