Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ssd.py 7.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
  1. import torch
  2. import torch.nn as nn
  3. from super_gradients.training.models import MobileNet, SgModule, MobileNetV2, InvertedResidual
  4. from super_gradients.training.utils import HpmStruct, utils
  5. from super_gradients.training.utils.module_utils import MultiOutputModule
  # Default architecture parameters. Each model's __init__ layers the user-supplied arch_params on top of these.
# Shared by every SSD variant.
#   num_defaults[i]: default boxes per spatial location of the i-th detection feature map
#                    (the loc head predicts num_defaults[i] * 4 values per location).
#   additional_blocks_bottleneck_channels[i]: width of the 1x1 bottleneck inside the i-th extra block.
DEFAULT_SSD_ARCH_PARAMS = dict(
    num_defaults=[4, 6, 6, 6, 4, 4],
    additional_blocks_bottleneck_channels=[256, 256, 128, 128, 128],
)

# SSD on a MobileNet (V1) backbone.
#   out_channels: channels of each detection feature map, in the order they are fed to the heads.
#   kernel_sizes: kernel size of the middle conv in each extra block.
DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS = dict(
    out_channels=[512, 1024, 512, 256, 256, 256],
    kernel_sizes=[3, 3, 3, 3, 2],
)

# SSDLite on a MobileNetV2 backbone.
#   expand_ratios: expansion ratio of each InvertedResidual extra block.
#   lite: use depthwise-separable convolutions in the prediction heads.
#   output_paths: backbone locations whose outputs are tapped as detection maps (see MultiOutputModule).
DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS = dict(
    out_channels=[576, 1280, 512, 256, 256, 64],
    expand_ratios=[0.2, 0.25, 0.5, 0.25],
    num_defaults=[6, 6, 6, 6, 6, 6],
    lite=True,
    width_mult=1.0,
    output_paths=[[14, 'conv', 2], 18],
)
  # Depthwise-separable convolution used by the "lite" SSD heads.
def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True):
    """Build a depthwise-separable drop-in replacement for ``nn.Conv2d``.

    Layers, in order: depthwise conv (``groups=in_channels``, channel count unchanged) -> BatchNorm2d
    -> ReLU -> 1x1 pointwise conv to ``out_channels``.

    ``kernel_size``, ``stride``, ``padding`` and ``bias`` apply to the depthwise conv only; the
    pointwise conv always keeps nn.Conv2d's default bias.
    """
    depthwise = nn.Conv2d(
        in_channels=in_channels,
        out_channels=in_channels,
        kernel_size=kernel_size,
        groups=in_channels,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
    return nn.Sequential(depthwise, nn.BatchNorm2d(in_channels), nn.ReLU(), pointwise)
  # SSD detection head on top of an arbitrary backbone.
class SSD(SgModule):
    """Single Shot MultiBox Detector.

    paper: https://arxiv.org/pdf/1512.02325.pdf
    based on code: https://github.com/NVIDIA/DeepLearningExamples

    The backbone's feature map(s) are extended by ``additional_blocks``. Every resulting map is fed to
    a localization head (``loc``) and a classification head (``conf``).

    ``arch_params`` is layered over ``DEFAULT_SSD_ARCH_PARAMS``. Fields read here:
        num_classes: number of classes including the extra "no class" label.
        out_channels: channels of every detection feature map, in the order they reach the heads.
        num_defaults: default boxes per spatial location of each detection feature map.
        additional_blocks_bottleneck_channels, kernel_sizes: used by ``_build_additional_blocks``.
        output_paths (optional): if set, the backbone is wrapped in ``MultiOutputModule`` so that it
            returns several feature maps.
        lite (optional, default False): use ``SeperableConv2d`` in every head except the last one.
    """

    def __init__(self, backbone, arch_params):
        super().__init__()
        self.arch_params = HpmStruct(**DEFAULT_SSD_ARCH_PARAMS)
        self.arch_params.override(**arch_params.to_dict())

        paths = utils.get_param(self.arch_params, 'output_paths')
        if paths is not None:
            # Tap several backbone layers so each becomes a detection feature map.
            self.backbone = MultiOutputModule(backbone, paths)
        else:
            self.backbone = backbone

        # Read from the merged params, consistently with 'output_paths' above.
        lite = utils.get_param(self.arch_params, 'lite', False)
        # NUMBER OF CLASSES + 1 NO_CLASS
        self.num_classes = self.arch_params.num_classes

        self._build_additional_blocks()
        self._build_location_and_conf_branches(self.arch_params.out_channels, lite)
        self._init_weights()

    def _build_location_and_conf_branches(self, out_channels, lite: bool):
        """Add the SSD localization and classification heads after the backbone.

        One ``loc`` conv (``num_defaults[i] * 4`` outputs) and one ``conf`` conv
        (``num_defaults[i] * num_classes`` outputs) are created per detection feature map. With
        ``lite`` every head except the last one is a ``SeperableConv2d``; the last is always a plain
        ``nn.Conv2d``.
        """
        self.num_defaults = self.arch_params.num_defaults
        conv_to_use = SeperableConv2d if lite else nn.Conv2d
        last_index = len(self.num_defaults) - 1
        loc, conf = [], []
        for i, (nd, oc) in enumerate(zip(self.num_defaults, out_channels)):
            conv = conv_to_use if i < last_index else nn.Conv2d
            loc.append(conv(oc, nd * 4, kernel_size=3, padding=1))
            conf.append(conv(oc, nd * self.num_classes, kernel_size=3, padding=1))
        self.loc = nn.ModuleList(loc)
        self.conf = nn.ModuleList(conf)

    def _build_additional_blocks(self):
        """Build the extra feature layers that follow the backbone.

        Block ``i`` maps ``out_channels[i]`` to ``out_channels[i + 1]``. It does so with a 1x1 conv down
        to ``additional_blocks_bottleneck_channels[i]``, then a ``kernel_sizes[i]`` conv. Each conv is
        followed by BatchNorm + ReLU. The first three blocks use stride 2 / padding 1; the remaining
        ones use nn.Conv2d's default stride 1 / padding 0.
        """
        out_channels = self.arch_params.out_channels
        kernel_sizes = self.arch_params.kernel_sizes
        bottleneck_channels = self.arch_params.additional_blocks_bottleneck_channels
        blocks = []
        for i, (in_ch, out_ch, mid_ch, kernel_size) in enumerate(
                zip(out_channels[:-1], out_channels[1:], bottleneck_channels, kernel_sizes)):
            if i < 3:
                middle_layer = nn.Conv2d(mid_ch, out_ch, kernel_size=kernel_size, padding=1, stride=2,
                                         bias=False)
            else:
                middle_layer = nn.Conv2d(mid_ch, out_ch, kernel_size=kernel_size, bias=False)
            blocks.append(nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                middle_layer,
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))
        self.additional_blocks = nn.ModuleList(blocks)

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension (conv kernels) in the
        extra blocks and heads. 1-D parameters (biases, BatchNorm affine) keep their default init."""
        layers = [*self.additional_blocks, *self.loc, *self.conf]
        for layer in layers:
            for param in layer.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)

    def bbox_view(self, src, loc, conf):
        """Shape the head outputs into per-box predictions.

        Applies ``loc[i]`` / ``conf[i]`` to ``src[i]`` and concatenates along the box dimension.

        Returns:
            locs of shape (N, 4, total_boxes) and confs of shape (N, num_classes, total_boxes).
        """
        ret = []
        for s, l, c in zip(src, loc, conf):
            ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.num_classes, -1)))
        locs, confs = list(zip(*ret))
        locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
        return locs, confs

    def forward(self, x):
        x = self.backbone(x)
        # A MultiOutputModule backbone returns several feature maps; a plain backbone returns one.
        # Copy into a fresh list so the appends below never mutate a container owned by the backbone.
        detection_feed = list(x) if isinstance(x, (list, tuple)) else [x]
        x = detection_feed[-1]
        for block in self.additional_blocks:
            x = block(x)
            detection_feed.append(x)
        # FEATURE MAPS: i.e. FOR 300X300 INPUT - 38X38X4, 19X19X6, 10X10X6, 5X5X6, 3X3X4, 1X1X4
        locs, confs = self.bbox_view(detection_feed, self.loc, self.conf)
        # FOR 300X300 INPUT - locs IS N_BATCH X 4 X 8732, confs IS N_BATCH X N_LABELS X 8732
        return locs, confs
  # SSD with a MobileNet (V1) backbone.
class SSDMobileNetV1(SSD):
    """SSD detector on a MobileNet (V1) backbone.

    paper: http://ceur-ws.org/Vol-2500/paper_5.pdf

    ``arch_params`` is layered over ``DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS`` and the merged struct is
    handed to ``SSD``. The backbone is MobileNet truncated at ``up_to_layer=10``.
    """

    def __init__(self, arch_params: HpmStruct):
        merged_params = HpmStruct(**DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS)
        self.arch_params = merged_params
        merged_params.override(**arch_params.to_dict())
        backbone = MobileNet(num_classes=None, backbone_mode=True, up_to_layer=10)
        super().__init__(backbone=backbone, arch_params=self.arch_params)
  # SSDLite with a MobileNetV2 backbone.
class SSDLiteMobileNetV2(SSD):
    """SSDLite detector on a MobileNetV2 backbone.

    ``arch_params`` is layered over ``DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS``. Backbone feature maps
    are tapped through ``output_paths`` and extended with ``InvertedResidual`` blocks.
    """

    def __init__(self, arch_params: HpmStruct):
        self.arch_params = HpmStruct(**DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS)
        self.arch_params.override(**arch_params.to_dict())
        # Scale the first detection map's channel count by width_mult.
        # Build a new list rather than assigning in place: the list object may be shared with the
        # module-level defaults or with the caller's arch_params (HpmStruct(**...) presumably does not
        # copy it). In-place edits would then compound across instances.
        out_channels = list(self.arch_params.out_channels)
        out_channels[0] = int(round(out_channels[0] * self.arch_params.width_mult))
        self.arch_params.override(out_channels=out_channels)
        mobilenetv2 = MobileNetV2(num_classes=None, backbone_mode=True, width_mult=self.arch_params.width_mult)
        super().__init__(backbone=mobilenetv2.features, arch_params=self.arch_params)

    # OVERRIDE THE DEFAULT FUNCTION FROM SSD. ADD THE SSD BLOCKS AFTER THE BACKBONE.
    def _build_additional_blocks(self):
        """Build one stride-2 ``InvertedResidual`` per consecutive pair of ``out_channels[1:]``.

        Block ``i`` uses ``expand_ratios[i]``. ``out_channels[0]`` is not an input here because the
        extra blocks start from the last backbone output.
        """
        channels = self.arch_params.out_channels
        expand_ratios = self.arch_params.expand_ratios
        self.additional_blocks = []
        for in_channels, out_channels, expand_ratio in zip(channels[1:-1], channels[2:], expand_ratios):
            self.additional_blocks.append(
                InvertedResidual(in_channels, out_channels, stride=2, expand_ratio=expand_ratio))
        self.additional_blocks = nn.ModuleList(self.additional_blocks)
Tip!

Press p (or the left arrow key) to see the previous file, or n (or the right arrow key) to see the next file.

Comments

Loading...