from typing import List, Tuple, Union

import torch
from torch import nn
from torch.nn.modules.loss import _Loss

from super_gradients.training.losses.focal_loss import FocalLoss
from super_gradients.training.utils.detection_utils import calculate_bbox_iou_elementwise, Anchors


class YoLoV5DetectionLoss(_Loss):
    """
    Calculate the YOLO V5 loss:
    L = L_objectness + L_boxes + L_classification
    """
    def __init__(self, anchors: Anchors,
                 cls_pos_weight: Union[float, List[float]] = 1.0, obj_pos_weight: float = 1.0,
                 obj_loss_gain: float = 1.0, box_loss_gain: float = 0.05, cls_loss_gain: float = 0.5,
                 focal_loss_gamma: float = 0.0,
                 cls_objectness_weights: Union[List[float], torch.Tensor] = None):
        """
        :param anchors: the anchors of the model (the same anchors used for training)
        :param cls_pos_weight: pos_weight for the BCE in L_classification;
                               either one value for all positives or a list of per-class weights
        :param obj_pos_weight: pos_weight for the BCE in L_objectness
        :param obj_loss_gain: coefficient for L_objectness
        :param box_loss_gain: coefficient for L_boxes
        :param cls_loss_gain: coefficient for L_classification
        :param focal_loss_gamma: gamma for a focal loss; 0 to train with plain BCE
        :param cls_objectness_weights: class-based weights for L_objectness, applied in each cell
                                       that has a GT assigned to it.
                                       Note: the default objectness-loss weight in each cell is 1.
        """
        super(YoLoV5DetectionLoss, self).__init__()
        self.cls_pos_weight = cls_pos_weight
        self.obj_pos_weight = obj_pos_weight
        self.obj_loss_gain = obj_loss_gain
        self.box_loss_gain = box_loss_gain
        self.cls_loss_gain = cls_loss_gain
        self.focal_loss_gamma = focal_loss_gamma
        self.anchors = anchors

        self.cls_obj_weights = cls_objectness_weights
        if isinstance(cls_objectness_weights, list):
            self.cls_obj_weights = torch.nn.Parameter(torch.tensor(cls_objectness_weights))
    def forward(self, model_output, targets):
        if isinstance(model_output, tuple) and len(model_output) == 2:
            # in test/eval mode the YOLO v5 model outputs a tuple whose second item is the raw predictions
            _, predictions = model_output
        else:
            predictions = model_output
        return self.compute_loss(predictions, targets)
    def build_targets(self, predictions: List[torch.Tensor], targets: torch.Tensor, anchor_threshold=4.0) \
            -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Tuple[torch.Tensor, ...]], List[torch.Tensor]]:
        """
        Assign targets to anchors for use in the L_boxes & L_classification calculation:
          * each target can be assigned to several anchors:
            all anchors within [1/anchor_threshold, anchor_threshold] times the target size;
          * each anchor can be assigned to several targets.

        :param predictions: Yolo predictions
        :param targets: ground truth targets
        :param anchor_threshold: ratio defining the size range of an appropriate anchor
        :return: each of the 4 outputs contains one element per Yolo output level,
                 with correspondences raveled over the whole batch and all anchors:
                   * classes of the targets;
                   * boxes of the targets;
                   * image id in a batch, anchor id, grid y, grid x coordinates;
                   * anchor sizes.
                 All of the above can be indexed in parallel to get the selected correspondences
        """
        num_anchors, num_targets = self.anchors.num_anchors, targets.shape[0]
        target_classes, target_boxes, indices, anchors = [], [], [], []
        gain = torch.ones(7, device=targets.device)  # normalized-to-gridspace gain
        anchor_indices = torch.arange(num_anchors, device=targets.device)
        anchor_indices = anchor_indices.float().view(num_anchors, 1).repeat(1, num_targets)
        # repeat all targets for each anchor and append the corresponding anchor index
        targets = torch.cat((targets.repeat(num_anchors, 1, 1), anchor_indices[:, :, None]), 2)

        bias = 0.5
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j, k, l, m
                            ], device=targets.device).float() * bias  # offsets

        for i in range(self.anchors.detection_layers_num):
            anch = self.anchors.anchors[i]
            gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]  # xyxy gain

            # Convert target coordinates from the [0, 1] range to the [0, GridY], [0, GridX] ranges
            t = targets * gain
            if num_targets:
                # Match: filter targets by anchor size ratio
                r = t[:, :, 4:6] / anch[:, None]  # wh ratio
                filtered_targets_ids = torch.max(r, 1. / r).max(2)[0] < anchor_threshold  # compare
                t = t[filtered_targets_ids]

                # Find coordinates of targets on the grid
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < bias) & (gxy > 1.)).T
                l, m = ((gxi % 1. < bias) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # prevent coordinates from going out of bounds
            gi, gj = gi.clamp_(0, gain[2] - 1), gj.clamp_(0, gain[3] - 1)

            # Append
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj, gi))  # image, anchor, grid indices
            target_boxes.append(torch.cat((gxy - gij, gwh), 1))  # box
            anchors.append(anch[a])  # anchors
            target_classes.append(c)  # class

        return target_classes, target_boxes, indices, anchors
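
    # Editorial worked example of the neighbor-cell assignment above
    # (illustrative numbers only, derived from the logic of build_targets):
    # a target centered at grid coordinates gxy = (10.3, 20.7) has fractional parts (0.3, 0.7).
    # Besides its own cell (10, 20), it is matched to the left neighbor, because 0.3 < bias:
    #     gij = (gxy - (0.5, 0)).long() = (9, 20)
    # and to the neighbor at y + 1, because the inverted fraction (GridY - 20.7) % 1 = 0.3 < bias:
    #     gij = (gxy - (0, -0.5)).long() = (10, 21)
    # So every target is assigned to its own cell plus up to two adjacent cells,
    # at most tripling the number of positive anchor-target pairs.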
    def compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor, giou_loss_ratio: float = 1.0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        L = L_objectness + L_boxes + L_classification
        where:
          * L_boxes and L_classification are calculated only between anchors and the targets assigned to them;
          * L_objectness is calculated on all anchors.

        L_classification:
            for anchors that have suitable ground truths in their grid locations, add BCEs
            to force the maximum probability for each GT class in a multi-label way.
            Coefficient: self.cls_loss_gain
        L_boxes:
            for anchors that have suitable ground truths in their grid locations,
            add (1 - IoU) between the predicted box and each GT box, forcing maximum IoU.
            Coefficient: self.box_loss_gain
        L_objectness:
            for each anchor, add a BCE forcing a prediction of (1 - giou_loss_ratio) + giou_loss_ratio * IoU,
            where IoU is between the predicted box and a random GT in its cell.
            Coefficient: self.obj_loss_gain; the loss from each YOLO grid is additionally multiplied by
            balance = [4.0, 1.0, 0.4] to balance the contributions coming from different numbers of grid cells

        :param predictions: output from all Yolo levels, each of shape
                            [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]
        :param targets: [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h
        :param giou_loss_ratio: a coefficient in L_objectness defining what should be predicted as objectness
                                in a cell with a target, a value in the [IoU, 1] range
        :return: loss, all losses separately in a detached tensor
        """
        device = targets.device
        loss_classification, loss_boxes, loss_objectness = \
            torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        target_classes, target_boxes, indices, anchors = self.build_targets(predictions, targets)  # targets

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.cls_pos_weight])).to(device)
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.obj_pos_weight]), reduction='none').to(device)

        # Focal loss
        if self.focal_loss_gamma > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, self.focal_loss_gamma), FocalLoss(BCEobj, self.focal_loss_gamma)

        # Losses
        num_targets = 0
        num_predictions = len(predictions)
        balance = [4.0, 1.0, 0.4] if num_predictions == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-P5 or P3-P6
        for i, prediction in enumerate(predictions):  # layer index, layer predictions
            image, anchor, grid_y, grid_x = indices[i]
            target_obj = torch.zeros_like(prediction[..., 0], device=device)
            weight_obj = torch.ones_like(prediction[..., 0], device=device)

            n = image.shape[0]  # number of targets
            if n:
                num_targets += n  # cumulative targets
                ps = prediction[image, anchor, grid_y, grid_x]  # prediction subset corresponding to targets

                # Boxes loss
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
                iou = calculate_bbox_iou_elementwise(pbox.T, target_boxes[i], x1y1x2y2=False, CIoU=True)
                loss_boxes += (1.0 - iou).mean()  # iou loss

                # Objectness loss target
                target_obj[image, anchor, grid_y, grid_x] = \
                    (1.0 - giou_loss_ratio) + giou_loss_ratio * iou.detach().clamp(0).type(target_obj.dtype)

                # Weights for weighted objectness
                if self.cls_obj_weights is not None:
                    # NOTE: for grid cells that have several ground truths with different classes assigned to them,
                    # the objectness weight is picked randomly from one of these classes
                    weight_obj[image, anchor, grid_y, grid_x] = self.cls_obj_weights[target_classes[i]]

                # Classification loss
                if ps.shape[1] > 6:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], 0, device=device)  # targets
                    t[range(n), target_classes[i]] = 1
                    loss_classification += BCEcls(ps[:, 5:], t)  # BCE

            # Objectness loss
            loss_obj_cur_head = BCEobj(prediction[..., 4], target_obj)
            loss_obj_cur_head = torch.sum(loss_obj_cur_head * weight_obj / torch.sum(weight_obj))
            loss_objectness += loss_obj_cur_head * balance[i]  # obj loss

        batch_size = prediction.shape[0]  # batch size
        loss = loss_boxes * self.box_loss_gain + loss_objectness * self.obj_loss_gain \
            + loss_classification * self.cls_loss_gain
        # IMPORTANT: the box, obj and cls losses are logged scaled by their gains in ultralytics
        # and are logged unscaled in our codebase
        return loss * batch_size, torch.cat((loss_boxes, loss_objectness, loss_classification, loss)).detach()
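

# ---------------------------------------------------------------------------
# Minimal usage sketch (editorial addition, not part of the library).
# It assumes Anchors can be built from per-level (w, h) anchor lists plus the
# matching strides, as in the standard YOLO v5 COCO configuration; check the
# super_gradients source for the exact Anchors signature before relying on it.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_classes, batch_size = 80, 2
    anchors = Anchors([[10, 13, 16, 30, 33, 23],        # P3/8
                       [30, 61, 62, 45, 59, 119],       # P4/16
                       [116, 90, 156, 198, 373, 326]],  # P5/32
                      strides=[8, 16, 32])
    criterion = YoLoV5DetectionLoss(anchors)

    # raw predictions, one tensor per level: [B x Num_Anchors x GridY x GridX x (4 + 1 + Num_classes)]
    predictions = [torch.randn(batch_size, 3, grid, grid, 5 + num_classes) for grid in (80, 40, 20)]
    # targets: (image id in batch, class, x, y, w, h), box coordinates normalized to [0, 1]
    targets = torch.tensor([[0., 1., 0.50, 0.50, 0.20, 0.30],
                            [1., 7., 0.30, 0.60, 0.10, 0.10]])

    loss, loss_items = criterion(predictions, targets)
    print(loss.item(), loss_items)  # total loss and detached (box, obj, cls, total) components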