Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#427 Feature/alg 743 new detection base

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/ALG-743_new-detection-base
@@ -8,6 +8,8 @@ from super_gradients.common.factories.optimizers_type_factory import OptimizersT
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
+from super_gradients.common.factories.detection_modules_factory import DetectionModulesFactory
+
 
 
 __all__ = [
 __all__ = [
     "CallbacksFactory",
     "CallbacksFactory",
@@ -20,4 +22,5 @@ __all__ = [
     "ActivationsTypeFactory",
     "ActivationsTypeFactory",
     "TypeFactory",
     "TypeFactory",
     "BaseFactory",
     "BaseFactory",
+    "DetectionModulesFactory",
 ]
 ]
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
  1. from typing import Union, Any
  2. from omegaconf import DictConfig
  3. from super_gradients.training.utils import HpmStruct
  4. from super_gradients.common.factories.base_factory import BaseFactory
  5. from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
  6. class DetectionModulesFactory(BaseFactory):
  7. def __init__(self):
  8. super().__init__(ALL_DETECTION_MODULES)
  9. @staticmethod
  10. def insert_module_param(conf: Union[str, dict, HpmStruct, DictConfig], name: str, value: Any):
  11. """
  12. Assign a new parameter for the module
  13. :param conf: a module config, either {type_name(str): {parameters...}} or just type_name(str)
  14. :param name: parameter name
  15. :param value: parameter value
  16. :return: an update config {type_name(str): {name: value, parameters...}}
  17. """
  18. if isinstance(conf, str):
  19. return {conf: {name: value}}
  20. cls_type = list(conf.keys())[0]
  21. conf[cls_type][name] = value
  22. return conf
Discard
@@ -1,4 +1,4 @@
-from super_gradients.common.registry.registry import register_model, register_metric, register_loss
+from super_gradients.common.registry.registry import register_model, register_metric, register_loss, register_detection_module
 
 
 
 
-__all__ = ['register_model', 'register_metric', 'register_loss']
+__all__ = ["register_model", "register_detection_module", "register_metric", "register_loss"]
Discard
@@ -5,6 +5,7 @@ from super_gradients.training.dataloaders.dataloaders import ALL_DATALOADERS
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.metrics.all_metrics import METRICS
 from super_gradients.training.metrics.all_metrics import METRICS
 from super_gradients.training.losses.all_losses import LOSSES
 from super_gradients.training.losses.all_losses import LOSSES
+from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
 
 
 
 
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
@@ -14,6 +15,7 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
     :param registry: The registry (maps name to object that you register)
     :param registry: The registry (maps name to object that you register)
     :return:         Register function
     :return:         Register function
     """
     """
+
     def register(name: Optional[str] = None) -> Callable:
     def register(name: Optional[str] = None) -> Callable:
         """
         """
         Set up a register decorator.
         Set up a register decorator.
@@ -21,6 +23,7 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
         :param name: If specified, the decorated object will be registered with this name.
         :param name: If specified, the decorated object will be registered with this name.
         :return:     Decorator that registers the callable.
         :return:     Decorator that registers the callable.
         """
         """
+
         def decorator(cls: Callable) -> Callable:
         def decorator(cls: Callable) -> Callable:
             """Register the decorated callable"""
             """Register the decorated callable"""
             cls_name = name if name is not None else cls.__name__
             cls_name = name if name is not None else cls.__name__
@@ -31,11 +34,14 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 
 
             registry[cls_name] = cls
             registry[cls_name] = cls
             return cls
             return cls
+
         return decorator
         return decorator
+
     return register
     return register
 
 
 
 
 register_model = create_register_decorator(registry=ARCHITECTURES)
 register_model = create_register_decorator(registry=ARCHITECTURES)
+register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
 register_metric = create_register_decorator(registry=METRICS)
 register_metric = create_register_decorator(registry=METRICS)
 register_loss = create_register_decorator(registry=LOSSES)
 register_loss = create_register_decorator(registry=LOSSES)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
  1. from abc import abstractmethod, ABC
  2. from typing import Union, List
  3. from torch import nn
  4. class BaseDetectionModule(nn.Module, ABC):
  5. """
  6. An interface for a module that is easy to integrate into a model with complex connections
  7. """
  8. def __init__(self, in_channels: Union[List[int], int], **kwargs):
  9. """
  10. :param in_channels: defines channels of tensor(s) that will be accepted by a module in forward
  11. """
  12. super().__init__()
  13. self.in_channels = in_channels
  14. @property
  15. @abstractmethod
  16. def out_channels(self) -> Union[List[int], int]:
  17. """
  18. :return: channels of tensor(s) that will be returned by a module in forward
  19. """
  20. raise NotImplementedError()
  21. ALL_DETECTION_MODULES = {}
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
  1. """
  2. A base for a detection network built according to the following scheme:
  3. * constructed from nested arch_params;
  4. * inside arch_params each nested level (module) has an explicit type and its required parameters
  5. * each module accepts in_channels and other parameters
  6. * each module defines out_channels property on construction
  7. """
  8. from typing import Union, List
  9. import torch
  10. from torch import nn
  11. from omegaconf import DictConfig
  12. from super_gradients.training.utils.utils import HpmStruct, get_param
  13. from super_gradients.training.models.sg_module import SgModule
  14. from super_gradients.common.factories import DetectionModulesFactory
  15. from super_gradients.modules.detection_modules import BaseDetectionModule
  16. from super_gradients.common.registry import register_detection_module
  17. @register_detection_module("NStageBackbone")
  18. class NStageBackbone(BaseDetectionModule):
  19. """
  20. A backbone with a stem -> N stages -> context module
  21. Returns outputs of the layers listed in out_layers
  22. """
  23. def __init__(
  24. self,
  25. in_channels: int,
  26. out_layers: List[str],
  27. stem: Union[str, HpmStruct, DictConfig],
  28. stages: Union[str, HpmStruct, DictConfig],
  29. context_module: Union[str, HpmStruct, DictConfig],
  30. ):
  31. """
  32. :param out_layers: names of layers to output from the following options: 'stem', 'stageN', 'context_module'
  33. """
  34. super().__init__(in_channels)
  35. factory = DetectionModulesFactory()
  36. self.num_stages = len(stages)
  37. self.stem = factory.get(factory.insert_module_param(stem, "in_channels", in_channels))
  38. prev_channels = self.stem.out_channels
  39. for i in range(self.num_stages):
  40. new_stage = factory.get(factory.insert_module_param(stages[i], "in_channels", prev_channels))
  41. setattr(self, f"stage{i + 1}", new_stage)
  42. prev_channels = new_stage.out_channels
  43. self.context_module = factory.get(factory.get(factory.insert_module_param(context_module, "in_channels", prev_channels)))
  44. self.out_layers = out_layers
  45. self._out_channels = self._define_out_channels()
  46. def _define_out_channels(self):
  47. out_channels = []
  48. for layer in self.out_layers:
  49. out_channels.append(getattr(self, layer).out_channels)
  50. return out_channels
  51. @property
  52. def out_channels(self):
  53. return self._out_channels
  54. def forward(self, x):
  55. outputs = []
  56. all_layers = ["stem"] + [f"stage{i}" for i in range(1, self.num_stages + 1)] + ["context_module"]
  57. for layer in all_layers:
  58. x = getattr(self, layer)(x)
  59. if layer in self.out_layers:
  60. outputs.append(x)
  61. return outputs
  62. @register_detection_module("PANNeck")
  63. class PANNeck(BaseDetectionModule):
  64. """
  65. A PAN (path aggregation network) neck with 4 stages (2 up-sampling and 2 down-sampling stages)
  66. Returns outputs of neck stage 2, stage 3, stage 4
  67. """
  68. def __init__(
  69. self,
  70. in_channels: List[int],
  71. neck1: Union[str, HpmStruct, DictConfig],
  72. neck2: Union[str, HpmStruct, DictConfig],
  73. neck3: Union[str, HpmStruct, DictConfig],
  74. neck4: Union[str, HpmStruct, DictConfig],
  75. ):
  76. super().__init__(in_channels)
  77. c3_out_channels, c4_out_channels, c5_out_channels = in_channels
  78. factory = DetectionModulesFactory()
  79. self.neck1 = factory.get(factory.insert_module_param(neck1, "in_channels", [c5_out_channels, c4_out_channels]))
  80. self.neck2 = factory.get(factory.insert_module_param(neck2, "in_channels", [self.neck1.out_channels[1], c3_out_channels]))
  81. self.neck3 = factory.get(factory.insert_module_param(neck3, "in_channels", [self.neck2.out_channels[1], self.neck2.out_channels[0]]))
  82. self.neck4 = factory.get(factory.insert_module_param(neck4, "in_channels", [self.neck3.out_channels, self.neck1.out_channels[0]]))
  83. self._out_channels = [
  84. self.neck2.out_channels[1],
  85. self.neck3.out_channels,
  86. self.neck4.out_channels,
  87. ]
  88. @property
  89. def out_channels(self):
  90. return self._out_channels
  91. def forward(self, inputs):
  92. c3, c4, c5 = inputs
  93. x_n1_inter, x = self.neck1([c5, c4])
  94. x_n2_inter, p3 = self.neck2([x, c3])
  95. p4 = self.neck3([p3, x_n2_inter])
  96. p5 = self.neck4([p4, x_n1_inter])
  97. return p3, p4, p5
  98. @register_detection_module("NHeads")
  99. class NHeads(BaseDetectionModule):
  100. """
  101. Apply N heads in parallel and combine predictions into the shape expected by SG detection losses
  102. """
  103. def __init__(self, in_channels: List[int], num_classes: int, heads_list: Union[str, HpmStruct, DictConfig]):
  104. super().__init__(in_channels)
  105. factory = DetectionModulesFactory()
  106. heads_list = self._pass_num_classes(heads_list, factory, num_classes)
  107. self.num_heads = len(heads_list)
  108. for i in range(self.num_heads):
  109. new_head = factory.get(factory.insert_module_param(heads_list[i], "in_channels", in_channels[i]))
  110. setattr(self, f"head{i + 1}", new_head)
  111. @staticmethod
  112. def _pass_num_classes(heads_list, factory, num_classes):
  113. for i in range(len(heads_list)):
  114. heads_list[i] = factory.insert_module_param(heads_list[i], "num_classes", num_classes)
  115. return heads_list
  116. @property
  117. def out_channels(self):
  118. return None
  119. def forward(self, inputs):
  120. outputs = []
  121. for i in range(self.num_heads):
  122. outputs.append(getattr(self, f"head{i + 1}")(inputs[i]))
  123. return self.combine_preds(outputs)
  124. def combine_preds(self, preds):
  125. outputs = []
  126. outputs_logits = []
  127. for output, output_logits in preds:
  128. outputs.append(output)
  129. outputs_logits.append(output_logits)
  130. return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)
  131. class CustomizableDetector(SgModule):
  132. """
  133. A customizable detector with backbone -> neck -> heads
  134. Each submodule with its parameters must be defined explicitly.
  135. Modules should follow the interface of BaseDetectionModule
  136. """
  137. def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
  138. """
  139. :param type_mapping: can be passed to resolve string type names in arch_params to actual types
  140. """
  141. super().__init__()
  142. factory = DetectionModulesFactory()
  143. # move num_classes into heads params
  144. if get_param(arch_params, "num_classes"):
  145. arch_params.heads = factory.insert_module_param(arch_params.heads, "num_classes", arch_params.num_classes)
  146. self.arch_params = arch_params
  147. self.backbone = factory.get(factory.insert_module_param(arch_params.backbone, "in_channels", in_channels))
  148. self.neck = factory.get(factory.insert_module_param(arch_params.neck, "in_channels", self.backbone.out_channels))
  149. self.heads = factory.get(factory.insert_module_param(arch_params.heads, "in_channels", self.neck.out_channels))
  150. self._initialize_weights(arch_params)
  151. def forward(self, x):
  152. x = self.backbone(x)
  153. x = self.neck(x)
  154. return self.heads(x)
  155. def _initialize_weights(self, arch_params: Union[HpmStruct, DictConfig]):
  156. bn_eps = get_param(arch_params, "bn_eps", None)
  157. bn_momentum = get_param(arch_params, "bn_momentum", None)
  158. inplace_act = get_param(arch_params, "inplace_act", True)
  159. for m in self.modules():
  160. t = type(m)
  161. if t is nn.BatchNorm2d:
  162. m.eps = bn_eps if bn_eps else m.eps
  163. m.momentum = bn_momentum if bn_momentum else m.momentum
  164. elif inplace_act and t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, nn.Mish]:
  165. m.inplace = True
  166. def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwargs):
  167. for module in self.modules():
  168. if module != self and hasattr(module, "prep_model_for_conversion"):
  169. module.prep_model_for_conversion(input_size, **kwargs)
  170. def replace_head(self, new_num_classes: int = None, new_head: nn.Module = None):
  171. if new_num_classes is None and new_head is None:
  172. raise ValueError("At least one of new_num_classes, new_head must be given to replace output layer.")
  173. if new_head is not None:
  174. self.heads = new_head
  175. else:
  176. self.arch_params.heads.num_classes = new_num_classes
  177. self.heads = self.factory.get(self.arch_params.heads, self.neck.out_channels)
  178. self._initialize_weights(self.arch_params)
Discard