ssd_loss.py 4.5 KB

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
from super_gradients.training.utils.ssd_utils import DefaultBoxes


class SSDLoss(_Loss):
    """
    Implements the loss as the sum of the following:
    1. Confidence Loss: all labels, with hard negative mining
    2. Localization Loss: only on positive labels
    """

    def __init__(self, dboxes: DefaultBoxes, alpha: float = 1.0):
        super(SSDLoss, self).__init__()
        self.scale_xy = 1.0 / dboxes.scale_xy
        self.scale_wh = 1.0 / dboxes.scale_wh
        self.alpha = alpha
        self.sl1_loss = nn.SmoothL1Loss(reduction="none")
        self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0), requires_grad=False)
        self.con_loss = nn.CrossEntropyLoss(reduction="none")

    def _norm_relative_bbox(self, loc):
        """
        Convert bbox locations into locations relative to the dboxes, normalized by the dbox width and height.
        :param loc: a tensor of shape [batch, 4, num_boxes]
        """
        gxy = self.scale_xy * (loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, :]
        gwh = self.scale_wh * (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log()
        return torch.cat((gxy, gwh), dim=1).contiguous()

    def match_dboxes(self, targets):
        """
        Convert ground-truth boxes into a tensor with the same size as the dboxes. Each gt bbox is matched to every
        default box that overlaps it with IoU > 0.5, so a gt bbox can be duplicated across several default boxes.
        :param targets: a tensor containing the boxes for a single image. shape [num_boxes, 6] (index in batch, label, x, y, w, h)
        :return: two tensors
                 boxes  - shape of dboxes [4, num_dboxes] (x, y, w, h)
                 labels - shape [num_dboxes]
        """
        target_locations = self.dboxes.data.clone().squeeze()
        target_labels = torch.zeros((self.dboxes.data.shape[2])).to(self.dboxes.device)

        if len(targets) > 0:
            boxes = targets[:, 2:]
            ious = calculate_bbox_iou_matrix(boxes, self.dboxes.data.squeeze().T, x1y1x2y2=False)

            values, indices = torch.max(ious, dim=0)
            mask = values > 0.5

            target_locations[:, mask] = targets[indices[mask], 2:].T
            target_labels[mask] = targets[indices[mask], 1]

        return target_locations, target_labels

    def forward(self, predictions, targets):
        """
        Compute the loss.
        :param predictions: predictions tensor coming from the network, shape [N, num_classes + 4, num_dboxes],
                            where the first four items are (x, y, w, h) and the rest are class confidences
        :param targets: targets for the batch, shape [num_targets, 6] (index in batch, label, x, y, w, h)
        """
        batch_target_locations = []
        batch_target_labels = []
        (ploc, plabel) = predictions
        targets = targets.to(self.dboxes.device)
        for i in range(ploc.shape[0]):
            target_locations, target_labels = self.match_dboxes(targets[targets[:, 0] == i])
            batch_target_locations.append(target_locations)
            batch_target_labels.append(target_labels)
        batch_target_locations = torch.stack(batch_target_locations)
        batch_target_labels = torch.stack(batch_target_labels).type(torch.long)

        mask = batch_target_labels > 0
        pos_num = mask.sum(dim=1)

        vec_gd = self._norm_relative_bbox(batch_target_locations)

        # SUM ON FOUR COORDINATES, AND MASK
        sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
        sl1 = (mask.float() * sl1).sum(dim=1)

        # HARD NEGATIVE MINING
        con = self.con_loss(plabel, batch_target_labels)

        # POSITIVE BOXES WILL NEVER BE SELECTED AS NEGATIVES
        con_neg = con.clone()
        con_neg[mask] = 0
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)

        # NUMBER OF NEGATIVES IS THREE TIMES THE NUMBER OF POSITIVES
        neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = con_rank < neg_num

        closs = (con * (mask.float() + neg_mask.float())).sum(dim=1)

        # AVOID NO OBJECT DETECTED
        total_loss = (2 - self.alpha) * sl1 + self.alpha * closs
        num_mask = (pos_num > 0).float()
        pos_num = pos_num.float().clamp(min=1e-6)
        ret = (total_loss * num_mask / pos_num).mean(dim=0)
        return ret, torch.cat((sl1.mean().unsqueeze(0), closs.mean().unsqueeze(0), ret.unsqueeze(0))).detach()
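The least obvious part of forward is the hard-negative-mining block: the confidence loss of the positive boxes is zeroed, the remaining losses are argsorted twice to turn them into per-box ranks, and only boxes whose rank is below 3 * pos_num contribute as negatives. Below is a minimal, self-contained sketch of that rank trick on dummy numbers (plain PyTorch only; the tensors and their values are illustrative and not taken from this file).

import torch

# Toy per-box confidence losses for one image with 8 default boxes.
con = torch.tensor([[0.2, 3.0, 0.1, 2.5, 0.05, 1.0, 0.4, 0.3]])
mask = torch.tensor([[False, True, False, False, False, False, False, False]])  # one positive box
pos_num = mask.sum(dim=1)  # tensor([1])

# Zero out positives so they never compete for a "hard negative" slot.
con_neg = con.clone()
con_neg[mask] = 0

# Double argsort: con_rank[i, j] is the rank of box j when the boxes of
# image i are ordered from largest to smallest loss (rank 0 = hardest negative).
_, con_idx = con_neg.sort(dim=1, descending=True)
_, con_rank = con_idx.sort(dim=1)

# Keep at most 3 negatives per positive, as in SSDLoss.forward.
neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)  # tensor([[3]])
neg_mask = con_rank < neg_num
print(neg_mask)  # True only for the 3 highest-loss negative boxes (indices 3, 5, 6)

Because the ranking is done per row, the selection stays fully vectorized across the batch: no Python loop over images is needed, and the zeroed positives end up with the worst ranks, so they are excluded whenever enough real negatives exist.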