1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
|
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from super_gradients.training.models import BasicBlock, Bottleneck, SgModule, HpmStruct
- """
- paper: Deep Dual-resolution Networks for Real-time and
- Accurate Semantic Segmentation of Road Scenes ( https://arxiv.org/pdf/2101.06085.pdf )
- code from git repo: https://github.com/ydhongHIT/DDRNet
- """
def ConvBN(in_channels: int, out_channels: int, kernel_size: int, bias=True, stride=1, padding=0, add_relu=False):
    """
    Build a Conv2d -> BatchNorm2d stack, optionally followed by an in-place ReLU.

    :param in_channels: number of input channels of the convolution
    :param out_channels: number of output channels of the convolution (also the BatchNorm width)
    :param kernel_size: kernel size of the convolution
    :param bias: whether the convolution has a bias term
    :param stride: stride of the convolution
    :param padding: padding of the convolution
    :param add_relu: when True an nn.ReLU(inplace=True) is appended after the BatchNorm
    :return: an nn.Sequential holding the layers in the order conv, batch-norm, (relu)
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias, stride=stride, padding=padding)
    norm = nn.BatchNorm2d(out_channels)
    if not add_relu:
        return nn.Sequential(conv, norm)
    return nn.Sequential(conv, norm, nn.ReLU(inplace=True))
def _make_layer(block, in_planes, planes, num_blocks, stride=1, expansion=1):
    """
    Stack blocks of type `block` into a single nn.Sequential stage.

    :param block: block class, called as block(in_planes, planes, stride, final_relu=..., expansion=...)
    :param in_planes: input width of the first block
    :param planes: `planes` argument handed to every block
    :param num_blocks: number of blocks in the stage (the first block is always created)
    :param stride: stride of the first block; every following block uses stride 1
    :param expansion: expansion passed to every block; blocks after the first receive
                      planes * expansion as their input width
    :return: nn.Sequential of the created blocks. Only the last block is built with final_relu=False
    """
    # the first block is the only one that may change the resolution (stride) and the width
    stage = [block(in_planes, planes, stride, final_relu=num_blocks > 1, expansion=expansion)]

    widened = planes * expansion
    last_index = num_blocks - 1
    stage.extend(
        block(widened, planes, stride=1, final_relu=(index != last_index), expansion=expansion)
        for index in range(1, num_blocks)
    )
    return nn.Sequential(*stage)
class DAPPMBranch(nn.Module):
    def __init__(self, kernel_size: int, stride: int, in_planes: int, branch_planes: int, inter_mode: str = 'bilinear'):
        """
        A single branch of the DAPPM module: (optional pooling) -> BN -> ReLU -> 1x1 conv -> resize to input size.

        :param kernel_size: kernel size of the average pooling (ignored when stride is 0 or 1)
        :param stride: selects the pooling that opens the branch
                       stride=0: nn.AdaptiveAvgPool2d pools the whole input down to 1x1
                       stride=1: no pooling at all
                       stride>1: nn.AvgPool2d with this stride (its padding is set to the stride as well)
        :param in_planes: number of channels of the branch input
        :param branch_planes: number of channels produced by the 1x1 convolution
        :param inter_mode: interpolation mode used to resize back to the input resolution
        """
        super().__init__()

        if stride == 0:
            # global pooling: the whole feature map collapses to 1x1
            pooling = [nn.AdaptiveAvgPool2d((1, 1))]
        elif stride == 1:
            # full-resolution branch, nothing to pool
            pooling = []
        else:
            pooling = [nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=stride)]

        self.down_scale = nn.Sequential(
            *pooling,
            nn.BatchNorm2d(in_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes, branch_planes, kernel_size=1, bias=False),
        )
        self.up_scale = UpscaleOnline(inter_mode)

        # the fusion convolution is only created for pooled branches (stride != 1)
        if stride != 1:
            self.process = nn.Sequential(
                nn.BatchNorm2d(branch_planes),
                nn.ReLU(inplace=True),
                nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
            )

    def forward(self, x):
        """
        Run the branch. Every branch except the first also receives the output of the previous branch.

        :param x: either the DAPPM input tensor (first branch), or a list [dappm_input, previous_branch_output]
        :return: the resized branch output; when a previous output was supplied, it is added first and the sum
                 goes through `self.process`
        """
        previous = None
        if isinstance(x, list):
            x, previous = x[0], x[1]

        height, width = x.shape[-2], x.shape[-1]
        resized = self.up_scale(self.down_scale(x), output_height=height, output_width=width)

        if previous is None:
            return resized
        return self.process(resized + previous)
class DAPPM(nn.Module):
    def __init__(self, in_planes: int, branch_planes: int, out_planes: int,
                 kernel_sizes: list, strides: list, inter_mode: str = 'bilinear'):
        """
        Pyramid pooling module made of chained DAPPMBranch modules.

        :param in_planes: number of channels of the input
        :param branch_planes: number of channels produced by every branch
        :param out_planes: number of channels of the module output
        :param kernel_sizes: pooling kernel size per branch
        :param strides: pooling stride per branch (must have the same length as kernel_sizes)
        :param inter_mode: interpolation mode used by the branches
        """
        super().__init__()

        assert len(kernel_sizes) == len(strides), 'len of kernel_sizes and strides must be the same'

        self.branches = nn.ModuleList([
            DAPPMBranch(kernel_size=pool_kernel, stride=pool_stride,
                        in_planes=in_planes, branch_planes=branch_planes, inter_mode=inter_mode)
            for pool_kernel, pool_stride in zip(kernel_sizes, strides)
        ])

        # all branch outputs are concatenated on the channel axis before compression
        concat_planes = branch_planes * len(self.branches)
        self.compression = nn.Sequential(
            nn.BatchNorm2d(concat_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(concat_planes, out_planes, kernel_size=1, bias=False),
        )
        self.shortcut = nn.Sequential(
            nn.BatchNorm2d(in_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
        )

    def forward(self, x):
        """
        :param x: input feature map
        :return: compression(concat of all branch outputs) + shortcut(x)
        """
        branch_outputs = []
        for branch in self.branches:
            # the first branch sees only x, later ones also get the output of the branch before them
            branch_input = [x, branch_outputs[-1]] if branch_outputs else x
            branch_outputs.append(branch(branch_input))

        fused = self.compression(torch.cat(branch_outputs, 1))
        return fused + self.shortcut(x)
class SegmentHead(nn.Module):
    def __init__(self, in_planes: int, inter_planes: int, out_planes: int, scale_factor: int,
                 inter_mode: str = 'bilinear'):
        """
        Last stage of the segmentation network.
        Reduces the number of output planes (usually to num_classes) while increasing the size by scale_factor

        :param in_planes: width of input
        :param inter_planes: width of internal conv. must be a multiple of scale_factor**2 when
                             inter_mode=pixel_shuffle
        :param out_planes: output width. NOTE: only used when inter_mode is not pixel_shuffle; with pixel_shuffle
                           the second conv keeps inter_planes channels and nn.PixelShuffle then divides the
                           channel count by scale_factor**2
        :param scale_factor: scaling factor
        :param inter_mode: one of nearest, linear, bilinear, bicubic, trilinear, area or pixel_shuffle.
                           when set to pixel_shuffle, an nn.PixelShuffle will be used for scaling
        :raises AssertionError: when inter_mode=pixel_shuffle and inter_planes is not a multiple of scale_factor**2
        """
        super().__init__()

        if inter_mode == 'pixel_shuffle':
            # '**' (power) is required here: '^' is bitwise XOR in Python, so 'scale_factor ^ 2' gave 0 for
            # scale_factor=2 (ZeroDivisionError) and 1 for scale_factor=3 (check always passed)
            assert inter_planes % (scale_factor ** 2) == 0, \
                'when using pixel_shuffle, inter_planes must be a multiple of scale_factor^2'

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.relu = nn.ReLU(inplace=True)

        if inter_mode == 'pixel_shuffle':
            self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=1, padding=0, bias=True)
            self.upscale = nn.PixelShuffle(scale_factor)
        else:
            self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=1, padding=0, bias=True)
            self.upscale = nn.Upsample(scale_factor=scale_factor, mode=inter_mode)

        self.scale_factor = scale_factor

    def forward(self, x):
        """
        :param x: input feature map with in_planes channels
        :return: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv, followed by the configured upscaling
        """
        x = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(x)))
        out = self.upscale(out)

        return out
class UpscaleOnline(nn.Module):
    """
    Resize module for cases where the target size is only known once the input arrives.
    Only the interpolation mode is fixed at construction; the output height and width are given per call.
    """

    def __init__(self, mode='bilinear'):
        """
        :param mode: interpolation mode forwarded to torch.nn.functional.interpolate
        """
        super().__init__()
        self.mode = mode

    def forward(self, x, output_height: int, output_width: int):
        """
        :param x: tensor to resize
        :param output_height: height of the result
        :param output_width: width of the result
        :return: x interpolated to (output_height, output_width) using the configured mode
        """
        target_size = [output_height, output_width]
        return F.interpolate(x, size=target_size, mode=self.mode)
class DDRBackBoneBase(nn.Module):
    """A base class defining functions that must be supported by DDRBackBones """

    def validate_backbone_attributes(self):
        """
        Check that the backbone exposes every attribute the DDR network reads from it.

        :raises AssertionError: when one of stem, layer1..layer4 or input_channels is missing
        """
        expected_attributes = ['stem', 'layer1', 'layer2', 'layer3', 'layer4', 'input_channels']
        for attribute in expected_attributes:
            assert hasattr(self, attribute), f'Invalid backbone - attribute \'{attribute}\' is missing'

    def get_backbone_output_number_of_channels(self):
        """Return a dictionary with the number of output channels of layer2, layer3 and layer4, used to determine
        the in_channels of the skip and compress layers.

        The channels are measured by pushing a random 1 x input_channels x 320 x 320 tensor through the backbone.
        The probe runs in eval mode and without autograd so that it does not touch BatchNorm running statistics
        (which matters for pretrained backbones) and does not build a graph; the training flag of every
        sub-module is restored afterwards.

        :return: dict with the keys 'layer2', 'layer3' and 'layer4'
        """
        output_shapes = {}

        # remember the mode of every sub-module individually, so partially frozen backbones are restored exactly
        training_flags = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                x = torch.randn(1, self.input_channels, 320, 320)
                x = self.stem(x)
                x = self.layer1(x)
                for layer_name in ('layer2', 'layer3', 'layer4'):
                    x = getattr(self, layer_name)(x)
                    output_shapes[layer_name] = x.shape[1]
        finally:
            for module, was_training in training_flags:
                module.training = was_training

        return output_shapes
class BasicDDRBackBone(DDRBackBoneBase):
    def __init__(self, block: nn.Module.__class__, width: int, layers: list, input_channels: int):
        """
        Backbone built from a two-convolution stem followed by four stages of `block`.

        :param block: block class used for all four stages
        :param width: number of channels of the stem and of layer1; layer2/3/4 use width * 2/4/8
        :param layers: number of blocks per stage, entries 0..3 are used for layer1..layer4
        :param input_channels: number of channels of the network input
        """
        super().__init__()
        self.input_channels = input_channels

        # both stem convolutions are 3x3 with stride 2, followed by BN and ReLU
        stem_conv = dict(kernel_size=3, stride=2, padding=1, add_relu=True)
        self.stem = nn.Sequential(
            ConvBN(in_channels=input_channels, out_channels=width, **stem_conv),
            ConvBN(in_channels=width, out_channels=width, **stem_conv),
        )

        self.layer1 = _make_layer(block=block, in_planes=width, planes=width, num_blocks=layers[0])

        # layer2..layer4: each stage uses stride 2 and doubles the width of the stage before it
        stage_in = width
        for stage_index in (1, 2, 3):
            stage_out = width * (2 ** stage_index)
            stage = _make_layer(block=block, in_planes=stage_in, planes=stage_out,
                                num_blocks=layers[stage_index], stride=2)
            setattr(self, f'layer{stage_index + 1}', stage)
            stage_in = stage_out
class RegnetDDRBackBone(DDRBackBoneBase):
    """
    Adapter exposing the parts of a Regnet module under the attribute names a DDR backbone must provide
    """

    def __init__(self, regnet_module: nn.Module.__class__):
        """
        :param regnet_module: object whose `net` attribute holds `stem` (with a `conv` sub-module) and
                              `stage_0` .. `stage_3`; these are re-used as stem and layer1 .. layer4
        """
        super().__init__()
        regnet = regnet_module.net
        stem = regnet.stem

        # the input width is read from the first convolution of the Regnet stem
        self.input_channels = stem.conv.in_channels
        self.stem = stem
        self.layer1 = regnet.stage_0
        self.layer2 = regnet.stage_1
        self.layer3 = regnet.stage_2
        self.layer4 = regnet.stage_3
class DDRNet(SgModule):
    def __init__(self, backbone: DDRBackBoneBase.__class__, additional_layers: list, upscale_module: nn.Module,
                 num_classes: int,
                 highres_planes: int, spp_width: int, head_width: int, aux_head: bool = False,
                 ssp_inter_mode: str = 'bilinear',
                 segmentation_inter_mode: str = 'bilinear', skip_block: nn.Module.__class__ = None,
                 layer5_block: nn.Module.__class__ = Bottleneck, layer5_bottleneck_expansion: int = 2,
                 classification_mode=False, spp_kernel_sizes: list = None,
                 spp_strides: list = None):
        """
        :param backbone: the low resolution branch of DDR, expected to have specific attributes in the class
        :param additional_layers: list of num blocks for the highres stage and layer5.
                                  index 0 -> layer5, index 1 -> layer3_skip, index 2 -> layer4_skip,
                                  index 3 -> layer5_skip
        :param upscale_module: upscale to use in the backbone (DAPPM and Segmentation head are using bilinear
                               interpolation). called as upscale_module(x, height, width)
        :param num_classes: number of classes
        :param highres_planes: number of channels in the high resolution net
        :param spp_width: width of every branch of the DAPPM (spp) module
        :param head_width: internal width of the segmentation head(s)
        :param aux_head: add a second segmentation head (fed from after compress3 + upscale). this head can be used
                         during training (see paper https://arxiv.org/pdf/2101.06085.pdf for details)
        :param ssp_inter_mode: the interpolation used in the SPP block
        :param segmentation_inter_mode: the interpolation used in the segmentation head
        :param skip_block: the block class used for layer3_skip and layer4_skip.
                           NOTE(review): the default None is passed to _make_layer as is, so a block class
                           presumably has to be supplied by the caller - confirm
        :param layer5_block: type of block to use in layer5 and layer5_skip
        :param layer5_bottleneck_expansion: determines the expansion rate for Bottleneck block
        :param classification_mode: build the classification variant (average pool + linear head) instead of the
                                    segmentation variant. cannot be combined with aux_head
        :param spp_kernel_sizes: list of kernel sizes for the spp module pooling (None -> [1, 5, 9, 17, 0])
        :param spp_strides: list of strides for the spp module pooling (None -> [1, 2, 4, 8, 0])
        """
        super().__init__()

        # None sentinels instead of list defaults, so no list object is shared between instances
        if spp_kernel_sizes is None:
            spp_kernel_sizes = [1, 5, 9, 17, 0]
        if spp_strides is None:
            spp_strides = [1, 2, 4, 8, 0]

        self.aux_head = aux_head
        self.upscale = upscale_module
        self.ssp_inter_mode = ssp_inter_mode
        self.segmentation_inter_mode = segmentation_inter_mode
        self.relu = nn.ReLU(inplace=False)
        self.classification_mode = classification_mode

        assert not (aux_head and classification_mode), "auxiliary head cannot be used in classification mode"

        assert isinstance(backbone, DDRBackBoneBase), 'The backbone must inherit from DDRBackBoneBase'
        self.backbone = backbone
        self.backbone.validate_backbone_attributes()
        out_chan_backbone = self.backbone.get_backbone_output_number_of_channels()

        # low -> high resolution fusion: 1x1 convs reducing the backbone width to the high resolution width
        self.compression3 = ConvBN(in_channels=out_chan_backbone['layer3'], out_channels=highres_planes,
                                   kernel_size=1, bias=False)
        self.compression4 = ConvBN(in_channels=out_chan_backbone['layer4'], out_channels=highres_planes,
                                   kernel_size=1, bias=False)

        # high -> low resolution fusion: strided 3x3 convs (one for layer3, two for layer4)
        self.down3 = ConvBN(in_channels=highres_planes, out_channels=out_chan_backbone['layer3'], kernel_size=3,
                            stride=2, padding=1, bias=False)
        self.down4 = nn.Sequential(
            ConvBN(in_channels=highres_planes, out_channels=highres_planes * 2, kernel_size=3, stride=2, padding=1,
                   bias=False, add_relu=True),
            ConvBN(in_channels=highres_planes * 2, out_channels=out_chan_backbone['layer4'], kernel_size=3,
                   stride=2, padding=1, bias=False))

        # the high resolution branch
        self.layer3_skip = _make_layer(block=skip_block, in_planes=out_chan_backbone['layer2'],
                                       planes=highres_planes, num_blocks=additional_layers[1])
        self.layer4_skip = _make_layer(block=skip_block, in_planes=highres_planes, planes=highres_planes,
                                       num_blocks=additional_layers[2])
        self.layer5_skip = _make_layer(block=layer5_block, in_planes=highres_planes, planes=highres_planes,
                                       num_blocks=additional_layers[3], expansion=layer5_bottleneck_expansion)

        # when training the backbones on Imagenet:
        #  - layer 5 has stride 1
        #  - a new high_to_low_fusion is added with two 3x3 convs with stride 2 (and double the width)
        #  - a classification head is placed instead of the segmentation head
        if self.classification_mode:
            self.layer5 = _make_layer(block=layer5_block, in_planes=out_chan_backbone['layer4'],
                                      planes=out_chan_backbone['layer4'], num_blocks=additional_layers[0],
                                      expansion=layer5_bottleneck_expansion)

            highres_planes_out = highres_planes * layer5_bottleneck_expansion
            self.high_to_low_fusion = nn.Sequential(
                ConvBN(in_channels=highres_planes_out, out_channels=highres_planes_out * 2,
                       kernel_size=3, stride=2, padding=1, add_relu=True),
                ConvBN(in_channels=highres_planes_out * 2,
                       out_channels=out_chan_backbone['layer4'] * layer5_bottleneck_expansion,
                       kernel_size=3, stride=2, padding=1, add_relu=True))

            self.average_pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(in_features=out_chan_backbone['layer4'] * layer5_bottleneck_expansion,
                                out_features=num_classes)
        else:
            self.layer5 = _make_layer(block=layer5_block, in_planes=out_chan_backbone['layer4'],
                                      planes=out_chan_backbone['layer4'], num_blocks=additional_layers[0],
                                      stride=2, expansion=layer5_bottleneck_expansion)

            self.spp = DAPPM(in_planes=out_chan_backbone['layer4'] * layer5_bottleneck_expansion,
                             branch_planes=spp_width, out_planes=highres_planes * layer5_bottleneck_expansion,
                             inter_mode=self.ssp_inter_mode, kernel_sizes=spp_kernel_sizes, strides=spp_strides)

            if self.aux_head:
                self.seghead_extra = SegmentHead(highres_planes, head_width, num_classes, 8,
                                                 inter_mode=self.segmentation_inter_mode)

            self.final_layer = SegmentHead(highres_planes * layer5_bottleneck_expansion,
                                           head_width, num_classes, 8, inter_mode=self.segmentation_inter_mode)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; BatchNorm2d weight is set to 1 and bias to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        :param x: input batch
        :return: classification mode: the output of the linear head, always of shape (batch, num_classes).
                 segmentation mode: the output of the final segmentation head, or the tuple
                 (final output, auxiliary head output) when aux_head is set
        """
        # target size used when upscaling the low resolution features into the high resolution branch
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        x = self.backbone.stem(x)
        x = self.backbone.layer1(x)
        out_layer2 = self.backbone.layer2(self.relu(x))

        # stage 3: both branches run on layer2's output, then exchange information
        out_layer3 = self.backbone.layer3(self.relu(out_layer2))
        out_layer3_skip = self.layer3_skip(self.relu(out_layer2))

        x = out_layer3 + self.down3(self.relu(out_layer3_skip))
        x_skip = out_layer3_skip + self.upscale(self.compression3(self.relu(out_layer3)),
                                                height_output, width_output)

        # save for auxiliary head
        if self.aux_head:
            aux_input = x_skip

        # stage 4: same pattern as stage 3
        out_layer4 = self.backbone.layer4(self.relu(x))
        out_layer4_skip = self.layer4_skip(self.relu(x_skip))

        x = out_layer4 + self.down4(self.relu(out_layer4_skip))
        x_skip = out_layer4_skip + self.upscale(self.compression4(self.relu(out_layer4)),
                                                height_output, width_output)

        out_layer5_skip = self.layer5_skip(self.relu(x_skip))

        if self.classification_mode:
            x_skip = self.high_to_low_fusion(self.relu(out_layer5_skip))
            x = self.layer5(self.relu(x))
            x = self.average_pool(x + x_skip)
            # flatten from dim 1 rather than squeeze(): squeeze() would also drop the batch dimension
            # for a batch of size 1
            x = self.fc(torch.flatten(x, 1))
            return x

        x = self.upscale(self.spp(self.layer5(self.relu(x))), height_output, width_output)
        x = self.final_layer(x + out_layer5_skip)

        if self.aux_head:
            x_extra = self.seghead_extra(aux_input)
            return x, x_extra
        return x
class DDRNetCustom(DDRNet):
    def __init__(self, arch_params: HpmStruct):
        """
        Build a DDRNet from the attributes of arch_params.

        :param arch_params: must provide every attribute listed in the mapping below
        """
        # (DDRNet constructor keyword, arch_params attribute). The names are identical except for the
        # spp/head widths, which are called '*_planes' in arch_params
        keyword_to_attribute = (
            ('backbone', 'backbone'),
            ('additional_layers', 'additional_layers'),
            ('upscale_module', 'upscale_module'),
            ('num_classes', 'num_classes'),
            ('highres_planes', 'highres_planes'),
            ('spp_width', 'spp_planes'),
            ('head_width', 'head_planes'),
            ('aux_head', 'aux_head'),
            ('ssp_inter_mode', 'ssp_inter_mode'),
            ('segmentation_inter_mode', 'segmentation_inter_mode'),
            ('skip_block', 'skip_block'),
            ('layer5_block', 'layer5_block'),
            ('layer5_bottleneck_expansion', 'layer5_bottleneck_expansion'),
            ('classification_mode', 'classification_mode'),
            ('spp_kernel_sizes', 'spp_kernel_sizes'),
            ('spp_strides', 'spp_strides'),
        )
        ddrnet_kwargs = {keyword: getattr(arch_params, attribute) for keyword, attribute in keyword_to_attribute}
        super().__init__(**ddrnet_kwargs)
# Default architecture parameters of DDRNet23
DEFAULT_DDRNET_23_PARAMS = dict(
    input_channels=3,
    block=BasicBlock,
    skip_block=BasicBlock,
    layer5_block=Bottleneck,
    layer5_bottleneck_expansion=2,
    # the first four entries are the backbone stages, the last four are the additional layers
    layers=[2, 2, 2, 2, 1, 2, 2, 1],
    upscale_module=UpscaleOnline(),
    planes=64,
    highres_planes=128,
    head_planes=128,
    aux_head=False,
    segmentation_inter_mode='bilinear',
    classification_mode=False,
    spp_planes=128,
    ssp_inter_mode='bilinear',
    spp_kernel_sizes=[1, 5, 9, 17, 0],
    spp_strides=[1, 2, 4, 8, 0],
)

# The slim variant only differs from DDRNet23 by halving the three widths
DEFAULT_DDRNET_23_SLIM_PARAMS = dict(
    DEFAULT_DDRNET_23_PARAMS,
    planes=32,
    highres_planes=64,
    head_planes=64,
)
class DDRNet23(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        """
        DDRNet23: DEFAULT_DDRNET_23_PARAMS overridden by arch_params, on top of a BasicDDRBackBone.

        :param arch_params: overrides for the default parameters
        """
        merged_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
        merged_params.override(**arch_params.to_dict())

        # the first four entries of 'layers' describe the backbone, the rest are the additional layers
        all_layers = merged_params.layers
        backbone_layers = all_layers[:4]
        merged_params.additional_layers = all_layers[4:]

        # build the backbone and insert it into the merged params
        merged_params.backbone = BasicDDRBackBone(block=merged_params.block,
                                                  width=merged_params.planes,
                                                  layers=backbone_layers,
                                                  input_channels=merged_params.input_channels)
        super().__init__(merged_params)
class DDRNet23Slim(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        """
        DDRNet23Slim: DEFAULT_DDRNET_23_SLIM_PARAMS overridden by arch_params, on top of a BasicDDRBackBone.

        :param arch_params: overrides for the default parameters
        """
        merged_params = HpmStruct(**DEFAULT_DDRNET_23_SLIM_PARAMS)
        merged_params.override(**arch_params.to_dict())

        # the first four entries of 'layers' describe the backbone, the rest are the additional layers
        all_layers = merged_params.layers
        backbone_layers = all_layers[:4]
        merged_params.additional_layers = all_layers[4:]

        # build the backbone and insert it into the merged params
        merged_params.backbone = BasicDDRBackBone(block=merged_params.block,
                                                  width=merged_params.planes,
                                                  layers=backbone_layers,
                                                  input_channels=merged_params.input_channels)
        super().__init__(merged_params)
class AnyBackBoneDDRNet23(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        """
        DDRNet23 built on a backbone supplied by the caller.

        :param arch_params: overrides for DEFAULT_DDRNET_23_PARAMS. Must contain 'backbone'. 'layers' must hold
                            4 or 8 numbers; the last 4 are used as the additional layers. When 'input_channels'
                            is given it must match the backbone's input_channels
        :raises AssertionError: when one of the requirements above is violated
        """
        user_params = arch_params.to_dict()
        _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
        _arch_params.override(**user_params)

        assert len(_arch_params.layers) == 4 or len(_arch_params.layers) == 8, \
            'The length of \'arch_params.layers\' must be 4 or 8'
        # TAKE THE LAST 4 NUMBERS AS THE ADDITIONAL LAYERS SPECIFICATION
        _arch_params.additional_layers = _arch_params.layers[-4:]

        assert hasattr(_arch_params, 'backbone'), 'AnyBackBoneDDRNet_23 requires having a backbone in arch_params'

        # Only validate a value the caller actually supplied. The merged params always contain the default
        # input_channels, so checking them would reject any backbone whose input width differs from that
        # default even when the caller never specified 'input_channels'
        if 'input_channels' in user_params:
            assert _arch_params.backbone.input_channels == _arch_params.input_channels, \
                '\'input_channels\' was given in arch_params with a different value than existing in the backbone'

        super().__init__(_arch_params)
|