#481 Feature/alg 287 refactor ssd

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/ALG-287_refactor-ssd
@@ -3,9 +3,12 @@ from abc import abstractmethod, ABC
 
 import torch
 from torch import nn
+from omegaconf.listconfig import ListConfig
 from omegaconf import DictConfig
 
 from super_gradients.training.utils.utils import HpmStruct
+from super_gradients.training.models import MobileNet, MobileNetV2, InvertedResidual
+from super_gradients.training.utils.module_utils import MultiOutputModule
 import super_gradients.common.factories.detection_modules_factory as det_factory
 
 
@@ -24,6 +27,9 @@ class BaseDetectionModule(nn.Module, ABC):
     @property
     @abstractmethod
     def out_channels(self) -> Union[List[int], int]:
+        """
+        :return: channels of tensor(s) that will be returned by the module in forward
+        """
         raise NotImplementedError()
 
 
@@ -167,7 +173,217 @@ class NHeads(BaseDetectionModule):
         return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)
 
 
+class MultiOutputBackbone(BaseDetectionModule):
+    """
+    Defines a backbone using MultiOutputModule with the interface of BaseDetectionModule
+    """
+
+    def __init__(self, in_channels: int, backbone: nn.Module, out_layers: List):
+        super().__init__(in_channels)
+        self.multi_output_backbone = MultiOutputModule(backbone, out_layers)
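+        # infer output channels once by running a dummy forward pass on a small (64x64) input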
+        self._out_channels = [x.shape[1] for x in self.forward(torch.empty((1, in_channels, 64, 64)))]
+
+    @property
+    def out_channels(self) -> Union[List[int], int]:
+        return self._out_channels
+
+    def forward(self, x):
+        return self.multi_output_backbone(x)
+
+
+class MobileNetV1Backbone(MultiOutputBackbone):
+    """MobileNetV1 backbone with an option to return output of any layer"""
+
+    def __init__(self, in_channels: int, out_layers: List):
+        backbone = MobileNet(backbone_mode=True, num_classes=None, in_channels=in_channels)
+        super().__init__(in_channels, backbone, out_layers)
+
+
+class MobileNetV2Backbone(MultiOutputBackbone):
+    """MobileNetV2 backbone with an option to return output of any layer"""
+
+    def __init__(self, in_channels: int, out_layers: List, width_mult: float = 1.0, structure: List[List] = None, grouped_conv_size: int = 1):
+        backbone = MobileNetV2(
+            backbone_mode=True,
+            num_classes=None,
+            dropout=0.0,
+            width_mult=width_mult,
+            structure=structure,
+            grouped_conv_size=grouped_conv_size,
+            in_channels=in_channels,
+        )
+        super().__init__(in_channels, backbone, out_layers)
+
+
+class SSDNeck(BaseDetectionModule, ABC):
+    """
+    SSD neck which returns:
+     * the outputs of the backbone, unchanged
+     * the output of each of a configurable number of additional blocks added after the last backbone stage
+    It has no skip connections to the backbone.
+    """
+
+    def __init__(self, in_channels: Union[int, List[int]], blocks_out_channels: List[int], **kwargs):
+        in_channels = in_channels if isinstance(in_channels, (list, ListConfig)) else [in_channels]
+        super().__init__(in_channels)
+        self.neck_blocks = nn.ModuleList(self.create_blocks(in_channels[-1], blocks_out_channels, **kwargs))
+        self._out_channels = in_channels + list(blocks_out_channels)
+
+    @property
+    def out_channels(self) -> Union[List[int], int]:
+        return self._out_channels
+
+    @abstractmethod
+    def create_blocks(self, in_channels: int, blocks_out_channels, **kwargs):
+        raise NotImplementedError()
+
+    def forward(self, inputs):
+        outputs = inputs if isinstance(inputs, list) else [inputs]
+
+        x = outputs[-1]
+        for block in self.neck_blocks:
+            x = block(x)
+            outputs.append(x)
+
+        return outputs
+
+
+class SSDInvertedResidualNeck(SSDNeck):
+    """
+    Consecutive InvertedResidual blocks each starting with stride 2
+    """
+
+    def create_blocks(self, prev_channels: int, blocks_out_channels: List[int], expand_ratios: List[float], grouped_conv_size: int):
+        neck_blocks = []
+        for i in range(len(blocks_out_channels)):
+            out_channels = blocks_out_channels[i]
+            neck_blocks.append(InvertedResidual(prev_channels, out_channels, stride=2, expand_ratio=expand_ratios[i], grouped_conv_size=grouped_conv_size))
+            prev_channels = out_channels
+        return neck_blocks
+
+
+class SSDBottleneckNeck(SSDNeck):
+    """
+    Consecutive bottleneck blocks
+    """
+
+    def create_blocks(self, prev_channels: int, blocks_out_channels: List[int], bottleneck_channels: List[int], kernel_sizes: List[int], strides: List[int]):
+        neck_blocks = []
+        for i in range(len(blocks_out_channels)):
+            mid_channels = bottleneck_channels[i]
+            out_channels = blocks_out_channels[i]
+            kernel_size = kernel_sizes[i]
+            stride = strides[i]
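+            # a 3x3 stride-2 conv with padding 1 halves the feature map; stride-1 blocks use no padding and shrink it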
+            padding = 1 if stride == 2 else 0
+            neck_blocks.append(
+                nn.Sequential(
+                    nn.Conv2d(prev_channels, mid_channels, kernel_size=1, bias=False),
+                    nn.BatchNorm2d(mid_channels),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
+                    nn.BatchNorm2d(out_channels),
+                    nn.ReLU(inplace=True),
+                )
+            )
+            prev_channels = out_channels
+        return neck_blocks
+
+
+def SeperableConv2d(in_channels: int, out_channels: int, kernel_size: int = 1, stride: int = 1, padding: int = 0, bias: bool = True):
+    """Depthwise Conv2d and Pointwise Conv2d."""
+    return nn.Sequential(
+        nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, padding=padding, bias=bias),
+        nn.BatchNorm2d(in_channels),
+        nn.ReLU(),
+        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
+    )
+
+
+class SSDHead(BaseDetectionModule):
+    """
+    A one-layer conv head attached to each input feature map,
+    implemented as two parallel branches: localization and classification
+    """
+
+    def __init__(self, in_channels: Union[int, List[int]], num_classes, anchors, lite):
+        in_channels = in_channels if isinstance(in_channels, (list, ListConfig)) else [in_channels]
+        super().__init__(in_channels)
+
+        self.num_classes = num_classes
+        self.dboxes_xy = nn.Parameter(anchors("xywh")[:, :2], requires_grad=False)
+        self.dboxes_wh = nn.Parameter(anchors("xywh")[:, 2:], requires_grad=False)
+        scale_xy = anchors.scale_xy
+        scale_wh = anchors.scale_wh
+        scales = torch.tensor([scale_xy, scale_xy, scale_wh, scale_wh])
+        self.scales = nn.Parameter(scales, requires_grad=False)
+        self.img_size = nn.Parameter(torch.tensor([anchors.fig_size]), requires_grad=False)
+        self.num_anchors = anchors.num_anchors
+
+        loc_blocks = []
+        conf_blocks = []
+
+        for i, (num_anch, in_c) in enumerate(zip(self.num_anchors, in_channels)):
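+            # in lite mode, all but the last (smallest) feature map get depthwise-separable convs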
+            conv = SeperableConv2d if lite and i < len(self.num_anchors) - 1 else nn.Conv2d
+            loc_blocks.append(conv(in_c, num_anch * 4, kernel_size=3, padding=1))
+            conf_blocks.append(conv(in_c, num_anch * (self.num_classes + 1), kernel_size=3, padding=1))
+
+        self.loc = nn.ModuleList(loc_blocks)
+        self.conf = nn.ModuleList(conf_blocks)
+
+    @property
+    def out_channels(self) -> Union[List[int], int]:
+        return None
+
+    def forward(self, inputs):
+        inputs = inputs if isinstance(inputs, list) else [inputs]
+
+        preds = []
+        for i in range(len(inputs)):
+            boxes_preds = self.loc[i](inputs[i])
+            class_preds = self.conf[i](inputs[i])
+            preds.append([boxes_preds, class_preds])
+
+        return self.combine_preds(preds)
+
+    def combine_preds(self, preds):
+        batch_size = preds[0][0].shape[0]
+
+        for i in range(len(preds)):
+            box_pred_map, conf_pred_map = preds[i]
+            preds[i][0] = box_pred_map.view(batch_size, 4, -1)
+            preds[i][1] = conf_pred_map.view(batch_size, self.num_classes + 1, -1)
+
+        locs, confs = list(zip(*preds))
+        locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
+
+        if self.training:
+            # FOR 300X300 INPUT - RETURN N_BATCH X 8732 X {N_LABELS, N_LOCS} RESULTS
+            return locs, confs
+        else:
+            bboxes_in = locs.permute(0, 2, 1)
+            scores_in = confs.permute(0, 2, 1)
+
+            bboxes_in *= self.scales
+
+            # CONVERT RELATIVE LOCATIONS INTO ABSOLUTE LOCATION (OUTPUT LOCATIONS ARE RELATIVE TO THE DBOXES)
+            xy = (bboxes_in[:, :, :2] * self.dboxes_wh + self.dboxes_xy) * self.img_size
+            wh = (bboxes_in[:, :, 2:].exp() * self.dboxes_wh) * self.img_size
+
+            # REPLACE THE CONFIDENCE OF CLASS NONE WITH OBJECT CONFIDENCE
+            # SSD DOES NOT OUTPUT OBJECT CONFIDENCE, REQUIRED FOR THE NMS
+            scores_in = torch.softmax(scores_in, dim=-1)
+            classes_conf = scores_in[:, :, 1:]
+            obj_conf = torch.max(classes_conf, dim=2)[0].unsqueeze(dim=-1)
+
+            return torch.cat((xy, wh, obj_conf, classes_conf), dim=2), (locs, confs)
+
+
 ALL_DETECTION_MODULES = {
+    "MobileNetV1Backbone": MobileNetV1Backbone,
+    "MobileNetV2Backbone": MobileNetV2Backbone,
+    "SSDInvertedResidualNeck": SSDInvertedResidualNeck,
+    "SSDBottleneckNeck": SSDBottleneckNeck,
+    "SSDHead": SSDHead,
     "NStageBackbone": NStageBackbone,
     "NStageBackbone": NStageBackbone,
     "PANNeck": PANNeck,
     "PANNeck": PANNeck,
     "NHeads": NHeads,
     "NHeads": NHeads,
arch_params/ssd_lite_mobilenetv2_arch_params.yaml:

backbone:
  MobileNetV2Backbone:
    width_mult: 1.
    structure:
    grouped_conv_size: 1
    out_layers: [['features', 14, 'conv', 2], ['features', 18]]
neck:
  SSDInvertedResidualNeck:
    blocks_out_channels: [512, 256, 256, 64]
    expand_ratios: [0.2, 0.25, 0.5, 0.25]
    grouped_conv_size: 1
heads:
  SSDHead:
    num_classes: 80
    lite: True
    anchors:
      _target_: super_gradients.training.utils.ssd_utils.DefaultBoxes
      fig_size: 320
      feat_size: [20, 10, 5, 3, 2, 1]
      scales: [32, 82, 133, 184, 235, 285, 336]
      aspect_ratios: [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]]
      scale_xy: 0.1
      scale_wh: 0.2
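A minimal sketch of how this recipe is consumed (the import path of the refactored model file is an assumption; see the ssd.py diff below): the defaults are deep-copied into an HpmStruct and any user overrides are merged on top.

    from super_gradients.training.utils.utils import HpmStruct
    from super_gradients.training.models.detection_models.ssd import SSDLiteMobileNetV2  # path assumed

    # an empty HpmStruct overrides nothing, so the recipe defaults above are used as-is
    model = SSDLiteMobileNetV2(arch_params=HpmStruct())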
arch_params/ssd_mobilenetv1_arch_params.yaml:

backbone:
  MobileNetV1Backbone:
    out_layers: [['layers', 9]]
neck:
  SSDBottleneckNeck:
    blocks_out_channels: [1024, 512, 256, 256, 256]
    bottleneck_channels: [256, 256, 128, 128, 128]
    strides: [2, 2, 2, 1, 1]
    kernel_sizes: [3, 3, 3, 3, 2]
heads:
  SSDHead:
    num_classes: 80
    lite: False
    anchors:
      _target_: super_gradients.training.utils.ssd_utils.DefaultBoxes
      fig_size: 320
      feat_size: [40, 20, 10, 5, 3, 2]
      scales: [22, 48, 106, 163, 221, 278, 336]
      aspect_ratios: [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
      scale_xy: 0.1
      scale_wh: 0.2
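A worked trace of the channel flow in this recipe: SSDNeck returns the backbone outputs followed by the output of each neck block, so the head receives the concatenation of the two channel lists. The 512 for MobileNetV1's ['layers', 9] output is taken from the hard-coded out_channels removed in the ssd.py diff below.

    backbone_out = [512]                      # output channels of MobileNetV1 at ['layers', 9]
    neck_blocks = [1024, 512, 256, 256, 256]  # blocks_out_channels above
    head_in = backbone_out + neck_blocks      # [512, 1024, 512, 256, 256, 256], one per feat_size level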
@@ -1,219 +1,26 @@
-import torch
-import torch.nn as nn
+import copy
+from typing import Union
 
-from super_gradients.training.models import MobileNet, SgModule, MobileNetV2, InvertedResidual
+from omegaconf import DictConfig
 
-from super_gradients.training.utils import HpmStruct, utils
-from super_gradients.training.utils.module_utils import MultiOutputModule
-from super_gradients.training.utils.ssd_utils import DefaultBoxes
+from super_gradients.training.utils.hydra_utils import load_arch_params
+from super_gradients.training.utils.utils import HpmStruct
+from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
 
-DEFAULT_SSD_ARCH_PARAMS = {
-    "additional_blocks_bottleneck_channels": [256, 256, 128, 128, 128]
-}
 
-DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS = {
-    "out_channels": [512, 1024, 512, 256, 256, 256],
-    "kernel_sizes": [3, 3, 3, 3, 2],
-    "anchors": DefaultBoxes(fig_size=320, feat_size=[40, 20, 10, 5, 3, 2], scales=[22, 48, 106, 163, 221, 278, 336],
-                            aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]], scale_xy=0.1, scale_wh=0.2)
-}
+DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS = load_arch_params("ssd_mobilenetv1_arch_params")
+DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS = load_arch_params("ssd_lite_mobilenetv2_arch_params")
 
-DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS = {
-    "out_channels": [576, 1280, 512, 256, 256, 64],
-    "expand_ratios": [0.2, 0.25, 0.5, 0.25],
-    "lite": True,
-    "width_mult": 1.0,
-    # "output_paths": [[7,'conv',2], [14, 'conv', 2]], output paths for a model with output levels of stride 8 plus
-    "output_paths": [[14, 'conv', 2], 18],
-    "anchors": DefaultBoxes(fig_size=320, feat_size=[20, 10, 5, 3, 2, 1], scales=[32, 82, 133, 184, 235, 285, 336],
-                            aspect_ratios=[[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]], scale_xy=0.1, scale_wh=0.2)
-}
 
+class SSDMobileNetV1(CustomizableDetector):
+    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
+        merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS))
+        merged_arch_params.override(**arch_params.to_dict())
+        super().__init__(merged_arch_params, in_channels=in_channels)
 
-def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True):
-    """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
-    """
-    return nn.Sequential(
-        nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
-                  groups=in_channels, stride=stride, padding=padding, bias=bias),
-        nn.BatchNorm2d(in_channels),
-        nn.ReLU(),
-        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
-    )
 
-
-class SSD(SgModule):
-    """
-    paper: https://arxiv.org/pdf/1512.02325.pdf
-    based on code: https://github.com/NVIDIA/DeepLearningExamples
-    """
-
-    def __init__(self, backbone, arch_params):
-        super().__init__()
-
-        self.arch_params = HpmStruct(**DEFAULT_SSD_ARCH_PARAMS)
-        self.arch_params.override(**arch_params.to_dict())
-
-        paths = utils.get_param(self.arch_params, 'output_paths')
-        if paths is not None:
-            self.backbone = MultiOutputModule(backbone, paths)
-        else:
-            self.backbone = backbone
-
-        # num classes in a dataset
-        # the model will predict self.num_classes + 1 values to also include background
-        self.num_classes = self.arch_params.num_classes
-        self.dboxes_xy = nn.Parameter(self.arch_params.anchors('xywh')[:, :2], requires_grad=False)
-        self.dboxes_wh = nn.Parameter(self.arch_params.anchors('xywh')[:, 2:], requires_grad=False)
-        scale_xy = self.arch_params.anchors.scale_xy
-        scale_wh = self.arch_params.anchors.scale_wh
-        scales = torch.tensor([scale_xy, scale_xy, scale_wh, scale_wh])
-        self.scales = nn.Parameter(scales, requires_grad=False)
-        self.img_size = nn.Parameter(torch.tensor([self.arch_params.anchors.fig_size]), requires_grad=False)
-        self.num_anchors = self.arch_params.anchors.num_anchors
-
-        self._build_additional_blocks()
-        self._build_detecting_branches()
-        self._init_weights()
-
-    def _build_detecting_branches(self, build_loc=True):
-        """Add localization and classification branches
-
-        :param build_loc: whether to build localization branch;
-                          called with False in replace_head(...), in such case only classification branch is rebuilt
-        """
-        if build_loc:
-            self.loc = []
-        self.conf = []
-
-        out_channels = self.arch_params.out_channels
-        lite = utils.get_param(self.arch_params, 'lite', False)
-        for i, (nd, oc) in enumerate(zip(self.num_anchors, out_channels)):
-            conv = SeperableConv2d if lite and i < len(self.num_anchors) - 1 else nn.Conv2d
-            if build_loc:
-                self.loc.append(conv(oc, nd * 4, kernel_size=3, padding=1))
-            self.conf.append(conv(oc, nd * (self.num_classes + 1), kernel_size=3, padding=1))
-
-        if build_loc:
-            self.loc = nn.ModuleList(self.loc)
-        self.conf = nn.ModuleList(self.conf)
-
-    def _build_additional_blocks(self):
-        input_size = self.arch_params.out_channels
-        kernel_sizes = self.arch_params.kernel_sizes
-        bottleneck_channels = self.arch_params.additional_blocks_bottleneck_channels
-
-        self.additional_blocks = []
-        for i, (input_size, output_size, channels, kernel_size) in enumerate(
-                zip(input_size[:-1], input_size[1:], bottleneck_channels, kernel_sizes)):
-            if i < 3:
-                middle_layer = nn.Conv2d(channels, output_size, kernel_size=kernel_size, padding=1, stride=2,
-                                         bias=False)
-            else:
-                middle_layer = nn.Conv2d(channels, output_size, kernel_size=kernel_size, bias=False)
-
-            layer = nn.Sequential(
-                nn.Conv2d(input_size, channels, kernel_size=1, bias=False),
-                nn.BatchNorm2d(channels),
-                nn.ReLU(inplace=True),
-                middle_layer,
-                nn.BatchNorm2d(output_size),
-                nn.ReLU(inplace=True),
-            )
-
-            self.additional_blocks.append(layer)
-
-        self.additional_blocks = nn.ModuleList(self.additional_blocks)
-
-    def _init_weights(self):
-        layers = [*self.additional_blocks, *self.loc, *self.conf]
-        for layer in layers:
-            for param in layer.parameters():
-                if param.dim() > 1:
-                    nn.init.xavier_uniform_(param)
-
-    def bbox_view(self, feature_maps):
-        """ Shape the classifier to the view of bboxes """
-        ret = []
-        for features, loc, conf in zip(feature_maps, self.loc, self.conf):
-            boxes_preds = loc(features).view(features.size(0), 4, -1)
-            class_preds = conf(features).view(features.size(0), self.num_classes + 1, -1)
-            ret.append((boxes_preds, class_preds))
-
-        locs, confs = list(zip(*ret))
-        locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
-        return locs, confs
-
-    def forward(self, x):
-        x = self.backbone(x)
-
-        # IF THE BACKBONE IS A MultiOutputModule WE GET A LIST, OTHERWISE WE WRAP IT IN A LIST
-        detection_feed = x if isinstance(x, list) else [x]
-        x = detection_feed[-1]
-
-        for block in self.additional_blocks:
-            x = block(x)
-            detection_feed.append(x)
-
-        # detection_feed are FEATURE MAPS: i.e. FOR 300X300 INPUT - 38X38X4, 19X19X6, 10X10X6, 5X5X6, 3X3X4, 1X1X4
-        locs, confs = self.bbox_view(detection_feed)
-
-        if self.training:
-            # FOR 300X300 INPUT - RETURN N_BATCH X 8732 X {N_LABELS, N_LOCS} RESULTS
-            return locs, confs
-        else:
-            bboxes_in = locs.permute(0, 2, 1)
-            scores_in = confs.permute(0, 2, 1)
-
-            bboxes_in *= self.scales
-
-            # CONVERT RELATIVE LOCATIONS INTO ABSOLUTE LOCATION (OUTPUT LOCATIONS ARE RELATIVE TO THE DBOXES)
-            xy = (bboxes_in[:, :, :2] * self.dboxes_wh + self.dboxes_xy) * self.img_size
-            wh = (bboxes_in[:, :, 2:].exp() * self.dboxes_wh) * self.img_size
-
-            # REPLACE THE CONFIDENCE OF CLASS NONE WITH OBJECT CONFIDENCE
-            # SSD DOES NOT OUTPUT OBJECT CONFIDENCE, REQUIRED FOR THE NMS
-            scores_in = torch.softmax(scores_in, dim=-1)
-            classes_conf = scores_in[:, :, 1:]
-            obj_conf = torch.max(classes_conf, dim=2)[0].unsqueeze(dim=-1)
-
-            return torch.cat((xy, wh, obj_conf, classes_conf), dim=2), (locs, confs)
-
-    def replace_head(self, new_num_classes):
-        del self.conf
-        self.arch_params.num_classes = new_num_classes
-        self.num_classes = new_num_classes
-        self._build_detecting_branches(build_loc=False)
-
-
-class SSDMobileNetV1(SSD):
-    """
-    paper: http://ceur-ws.org/Vol-2500/paper_5.pdf
-    """
-
-    def __init__(self, arch_params: HpmStruct):
-        self.arch_params = HpmStruct(**DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS)
-        self.arch_params.override(**arch_params.to_dict())
-        mobilenet_backbone = MobileNet(num_classes=None, backbone_mode=True, up_to_layer=10)
-        super().__init__(backbone=mobilenet_backbone, arch_params=self.arch_params)
-
-
-class SSDLiteMobileNetV2(SSD):
-    def __init__(self, arch_params: HpmStruct):
-        self.arch_params = HpmStruct(**DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS)
-        self.arch_params.override(**arch_params.to_dict())
-        self.arch_params.out_channels[0] = int(round(self.arch_params.out_channels[0] * self.arch_params.width_mult))
-        mobilenetv2 = MobileNetV2(num_classes=None, dropout=0.,
-                                  backbone_mode=True, width_mult=self.arch_params.width_mult)
-        super().__init__(backbone=mobilenetv2.features, arch_params=self.arch_params)
-
-    # OVERRIDE THE DEFAULT FUNCTION FROM SSD. ADD THE SDD BLOCKS AFTER THE BACKBONE.
-    def _build_additional_blocks(self):
-        channels = self.arch_params.out_channels
-        expand_ratios = self.arch_params.expand_ratios
-        self.additional_blocks = []
-        for in_channels, out_channels, expand_ratio in zip(channels[1:-1], channels[2:], expand_ratios):
-            self.additional_blocks.append(
-                InvertedResidual(in_channels, out_channels, stride=2, expand_ratio=expand_ratio))
-
-        self.additional_blocks = nn.ModuleList(self.additional_blocks)
+class SSDLiteMobileNetV2(CustomizableDetector):
+    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
+        merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS))
+        merged_arch_params.override(**arch_params.to_dict())
+        super().__init__(merged_arch_params, in_channels=in_channels)
@@ -1,6 +1,9 @@
+import os
 from pathlib import Path
 from typing import List
+import pkg_resources
 
+import hydra
 from hydra import initialize_config_dir, compose
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf, open_dict, DictConfig
@@ -62,3 +65,15 @@ def normalize_path(path: str) -> str:
     :return: Output path string with all \\ symbols replaces with /.
     """
     return path.replace("\\", "/")
+
+
+def load_arch_params(config_name: str) -> DictConfig:
+    """
+    :param config_name: name of a yaml with arch parameters
+    """
+    GlobalHydra.instance().clear()
+    sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
+    dataset_config = os.path.join("arch_params", config_name)
+    with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir)):
+        # config is relative to a module
+        return hydra.utils.instantiate(compose(config_name=normalize_path(dataset_config)).arch_params)
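A minimal usage sketch of the new helper (the nested key access assumes the recipe layout shown above, and hydra.utils.instantiate resolves the anchors' _target_ into a DefaultBoxes instance):

    from super_gradients.training.utils.hydra_utils import load_arch_params

    arch_params = load_arch_params("ssd_lite_mobilenetv2_arch_params")
    print(arch_params.heads.SSDHead.num_classes)  # 80 per the recipe above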