#875 Feature/SG-761 YOLO-NAS

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-761-yolo-nas
41 changed files with 1714 additions and 97 deletions
 1. src/super_gradients/common/object_names.py (+5 −2)
 2. src/super_gradients/examples/predict/detection_predict.py (+2 −2)
 3. src/super_gradients/examples/predict/detection_predict_image_folder.py (+2 −2)
 4. src/super_gradients/examples/predict/detection_predict_streaming.py (+2 −2)
 5. src/super_gradients/examples/predict/detection_predict_video.py (+2 −2)
 6. src/super_gradients/module_interfaces/__init__.py (+2 −2)
 7. src/super_gradients/module_interfaces/module_interfaces.py (+28 −0)
 8. src/super_gradients/modules/__init__.py (+27 −0)
 9. src/super_gradients/modules/base_modules.py (+27 −0)
10. src/super_gradients/modules/detection_modules.py (+19 −26)
11. src/super_gradients/modules/head_replacement_utils.py (+47 −0)
12. src/super_gradients/modules/pose_estimation_modules.py (+1 −1)
13. src/super_gradients/recipes/arch_params/yolo_nas_l_arch_params.yaml (+112 −0)
14. src/super_gradients/recipes/arch_params/yolo_nas_m_arch_params.yaml (+112 −0)
15. src/super_gradients/recipes/arch_params/yolo_nas_s_arch_params.yaml (+112 −0)
16. src/super_gradients/recipes/coco2017_yolo_nas_s.yaml (+43 −0)
17. src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml (+6 −5)
18. src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml (+25 −8)
19. src/super_gradients/recipes/roboflow_yolo_nas_m.yaml (+92 −0)
20. src/super_gradients/recipes/roboflow_yolo_nas_s.yaml (+92 −0)
21. src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml (+18 −0)
22. src/super_gradients/recipes/training_hyperparams/coco2017_yolo_nas_train_params.yaml (+56 −0)
23. src/super_gradients/training/dataloaders/__init__.py (+4 −4)
24. src/super_gradients/training/dataloaders/dataloaders.py (+6 −6)
25. src/super_gradients/training/datasets/detection_datasets/roboflow/metadata.py (+2 −2)
26. src/super_gradients/training/models/__init__.py (+27 −9)
27. src/super_gradients/training/models/detection_models/csp_darknet53.py (+19 −7)
28. src/super_gradients/training/models/detection_models/customizable_detector.py (+4 −0)
29. src/super_gradients/training/models/detection_models/pp_yolo_e/__init__.py (+2 −2)
30. src/super_gradients/training/models/detection_models/pp_yolo_e/pan.py (+4 −4)
31. src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py (+3 −2)
32. src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py (+2 −1)
33. src/super_gradients/training/models/detection_models/yolo_nas/__init__.py (+26 −0)
34. src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py (+270 −0)
35. src/super_gradients/training/models/detection_models/yolo_nas/panneck.py (+64 −0)
36. src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py (+90 −0)
37. src/super_gradients/training/models/detection_models/yolo_nas/yolo_stages.py (+332 −0)
38. src/super_gradients/training/pipelines/pipelines.py (+7 −2)
39. src/super_gradients/training/pretrained_models.py (+4 −0)
40. src/super_gradients/training/processing/processing.py (+6 −6)
41. src/super_gradients/training/utils/checkpoint_utils.py (+10 −0)
src/super_gradients/common/object_names.py

@@ -304,6 +304,9 @@ class Models:
     DEKR_W32_NO_DC = "dekr_w32_no_dc"
     POSE_PP_YOLO_L = "pose_ppyolo_l"
     POSE_DDRNET_39 = "pose_ddrnet39"
+    YOLO_NAS_S = "yolo_nas_s"
+    YOLO_NAS_M = "yolo_nas_m"
+    YOLO_NAS_L = "yolo_nas_l"


 class ConcatenatedTensorFormats:
@@ -326,8 +329,8 @@ class Dataloaders:
     COCO2017_VAL = "coco2017_val"
     COCO2017_TRAIN_YOLOX = "coco2017_train_yolox"
     COCO2017_VAL_YOLOX = "coco2017_val_yolox"
-    COCO2017_TRAIN_DECIYOLO = "coco2017_train_deci_yolo"
-    COCO2017_VAL_DECIYOLO = "coco2017_val_deci_yolo"
+    COCO2017_TRAIN_YOLO_NAS = "coco2017_train_yolo_nas"
+    COCO2017_VAL_YOLO_NAS = "coco2017_val_yolo_nas"
     COCO2017_TRAIN_PPYOLOE = "coco2017_train_ppyoloe"
     COCO2017_VAL_PPYOLOE = "coco2017_val_ppyoloe"
     COCO2017_TRAIN_SSD_LITE_MOBILENET_V2 = "coco2017_train_ssd_lite_mobilenet_v2"
src/super_gradients/examples/predict/detection_predict.py

@@ -1,8 +1,8 @@
 from super_gradients.common.object_names import Models
 from super_gradients.training import models

-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

 IMAGES = [
     "../../../../documentation/source/images/examples/countryside.jpg",
src/super_gradients/examples/predict/detection_predict_image_folder.py

@@ -1,8 +1,8 @@
 from super_gradients.common.object_names import Models
 from super_gradients.training import models

-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

 image_folder_path = "../../../../documentation/source/images/examples"
src/super_gradients/examples/predict/detection_predict_streaming.py

@@ -2,8 +2,8 @@ import torch
 from super_gradients.common.object_names import Models
 from super_gradients.training import models

-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

 # We want to use cuda if available to speed up inference.
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")
src/super_gradients/examples/predict/detection_predict_video.py

@@ -3,8 +3,8 @@ import torch
 from super_gradients.common.object_names import Models
 from super_gradients.training import models

-# Note that currently only YoloX and PPYoloE are supported.
-model = models.get(Models.YOLOX_N, pretrained_weights="coco")
+# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

 # We want to use cuda if available to speed up inference.
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")
src/super_gradients/module_interfaces/__init__.py

@@ -1,3 +1,3 @@
-from .module_interfaces import HasPredict, HasPreprocessingParams
+from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses

-__all__ = ["HasPredict", "HasPreprocessingParams"]
+__all__ = ["HasPredict", "HasPreprocessingParams", "SupportsReplaceNumClasses"]
src/super_gradients/module_interfaces/module_interfaces.py

@@ -1,3 +1,6 @@
+from typing import Callable
+
+from torch import nn
 from typing_extensions import Protocol, runtime_checkable


@@ -31,3 +34,28 @@ class HasPredict(Protocol):

     def predict_webcam(self, *args, **kwargs):
         ...
+
+
+@runtime_checkable
+class SupportsReplaceNumClasses(Protocol):
+    """
+    Protocol interface for modules that support replacing the number of classes.
+    Derived classes should implement the `replace_num_classes` method.
+
+    This interface class serves the purpose of explicitly indicating whether a class supports optimized head replacement:
+
+    >>> class PredictionHead(nn.Module, SupportsReplaceNumClasses):
+    >>>    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module] = None):
+    >>>       ...
+    """
+
+    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
+        """
+        Replace the number of classes in the module.
+
+        :param num_classes: New number of classes.
+        :param compute_new_weights_fn: (callable) An optional function that computes the new weights for the new classes.
+            It takes existing nn.Module and returns a new one.
+        :return: None
+        """
+        ...
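A minimal sketch of how this protocol is meant to be consumed (the DummyHead class and channel sizes below are illustrative, not part of the PR). Because the protocol is decorated with @runtime_checkable, a plain isinstance check is enough and no inheritance is required:

from torch import nn

from super_gradients.module_interfaces import SupportsReplaceNumClasses


class DummyHead(nn.Module):
    # Illustrative head: a single 1x1 conv mapping features to per-class logits.
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.cls_pred = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def replace_num_classes(self, num_classes, compute_new_weights_fn):
        # Delegate construction of the new classifier layer to the provided callable.
        self.cls_pred = compute_new_weights_fn(self.cls_pred, num_classes)


head = DummyHead(in_channels=64, num_classes=80)
assert isinstance(head, SupportsReplaceNumClasses)  # structural check: the method exists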
src/super_gradients/modules/__init__.py

@@ -16,7 +16,23 @@ from super_gradients.modules.skip_connections import (
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.registry.registry import ALL_DETECTION_MODULES

+from super_gradients.modules.base_modules import BaseDetectionModule
+from super_gradients.modules.detection_modules import (
+    PANNeck,
+    NHeads,
+    MultiOutputBackbone,
+    NStageBackbone,
+    MobileNetV1Backbone,
+    MobileNetV2Backbone,
+    SSDNeck,
+    SSDInvertedResidualNeck,
+    SSDBottleneckNeck,
+    SSDHead,
+)
+from super_gradients.module_interfaces import SupportsReplaceNumClasses
+
 __all__ = [
+    "BaseDetectionModule",
     "ALL_DETECTION_MODULES",
     "PixelShuffle",
     "AntiAliasDownsample",
@@ -33,6 +49,17 @@ __all__ = [
     "BackboneInternalSkipConnection",
     "HeadInternalSkipConnection",
     "LightweightDEKRHead",
+    "PANNeck",
+    "NHeads",
+    "MultiOutputBackbone",
+    "NStageBackbone",
+    "MobileNetV1Backbone",
+    "MobileNetV2Backbone",
+    "SSDNeck",
+    "SSDInvertedResidualNeck",
+    "SSDBottleneckNeck",
+    "SSDHead",
+    "SupportsReplaceNumClasses",
 ]

 logger = get_logger(__name__)
src/super_gradients/modules/base_modules.py (new file)

from abc import abstractmethod, ABC
from typing import Union, List

from torch import nn

__all__ = ["BaseDetectionModule"]


class BaseDetectionModule(nn.Module, ABC):
    """
    An interface for a module that is easy to integrate into a model with complex connections
    """

    def __init__(self, in_channels: Union[List[int], int], **kwargs):
        """
        :param in_channels: defines channels of tensor(s) that will be accepted by a module in forward
        """
        super().__init__()
        self.in_channels = in_channels

    @property
    @abstractmethod
    def out_channels(self) -> Union[List[int], int]:
        """
        :return: channels of tensor(s) that will be returned by a module in forward
        """
        raise NotImplementedError()
src/super_gradients/modules/detection_modules.py

@@ -1,37 +1,30 @@
+from abc import ABC, abstractmethod
 from typing import Union, List
-from abc import abstractmethod, ABC

 import torch
-from torch import nn
-from omegaconf.listconfig import ListConfig
 from omegaconf import DictConfig
-
+from omegaconf.listconfig import ListConfig
 from super_gradients.common.registry.registry import register_detection_module
+from super_gradients.modules.base_modules import BaseDetectionModule
+from super_gradients.modules.multi_output_modules import MultiOutputModule
+from super_gradients.training.models import MobileNet, MobileNetV2
 from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
 from super_gradients.training.utils.utils import HpmStruct
-from super_gradients.training.models import MobileNet, MobileNetV2
-from super_gradients.modules.multi_output_modules import MultiOutputModule
-
-
-class BaseDetectionModule(nn.Module, ABC):
-    """
-    An interface for a module that is easy to integrate into a model with complex connections
-    """
-
-    def __init__(self, in_channels: Union[List[int], int], **kwargs):
-        """
-        :param in_channels: defines channels of tensor(s) that will be accepted by a module in forward
-        """
-        super().__init__()
-        self.in_channels = in_channels
+from torch import nn

-    @property
-    @abstractmethod
-    def out_channels(self) -> Union[List[int], int]:
-        """
-        :return: channels of tensor(s) that will be returned by a module in forward
-        """
-        raise NotImplementedError()
+__all__ = [
+    "PANNeck",
+    "NHeads",
+    "MultiOutputBackbone",
+    "NStageBackbone",
+    "MobileNetV1Backbone",
+    "MobileNetV2Backbone",
+    "SSDNeck",
+    "SSDInvertedResidualNeck",
+    "SSDBottleneckNeck",
+    "SSDHead",
+    "BaseDetectionModule",
+]


 @register_detection_module()
src/super_gradients/modules/head_replacement_utils.py (new file)

from typing import Union

import torch
from torch import nn

__all__ = ["replace_num_classes_with_random_weights"]


def replace_num_classes_with_random_weights(module: Union[nn.Conv2d, nn.Linear, nn.Module], num_classes: int) -> nn.Module:
    """
    Replace the number of classes in the module with random weights.
    This is useful for replacing the output layer of a detection/classification head.
    This implementation supports Conv2d and Linear layers.
    Returned module will have the same device and dtype as the original module.
    Random weights are initialized with the same mean and std as the original weights.

    :param module: (nn.Module) Module to replace the number of classes in.
    :param num_classes: New number of classes.
    :return: nn.Module
    """
    if isinstance(module, nn.Conv2d):
        new_module = nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std(dim=(0, 1, 2, 3)).item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())
        return new_module
    elif isinstance(module, nn.Linear):
        new_module = nn.Linear(module.in_features, num_classes, device=module.weight.device, dtype=module.weight.dtype, bias=module.bias is not None)
        # Linear weights are 2D (out_features, in_features), so the std is taken over dims (0, 1).
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std(dim=(0, 1)).item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())
        return new_module
    else:
        raise ValueError(f"Module {module} does not support replacing the number of classes")
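A quick sanity-check sketch for the helper above (layer sizes are illustrative):

from torch import nn

from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights

# e.g. an 80-class classifier branch of a detection head, re-targeted to 20 classes
old = nn.Conv2d(256, 80, kernel_size=1)
new = replace_num_classes_with_random_weights(old, num_classes=20)

assert new.out_channels == 20
assert new.weight.dtype == old.weight.dtype and new.weight.device == old.weight.device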
src/super_gradients/modules/pose_estimation_modules.py

@@ -5,7 +5,7 @@ from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from torch import nn, Tensor

-from super_gradients.modules.detection_modules import BaseDetectionModule
+from super_gradients.modules.base_modules import BaseDetectionModule
 from super_gradients.common.registry.registry import register_detection_module
src/super_gradients/recipes/arch_params/yolo_nas_l_arch_params.yaml (new file)

backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48

    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 96
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 128
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 256
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 512
          concat_intermediates: True

    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]

    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        activation_type: relu
        width_mult: 1
        depth_mult: 1

    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 4
        hidden_channels: 256
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16

    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 1
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 1
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 1
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
src/super_gradients/recipes/arch_params/yolo_nas_m_arch_params.yaml (new file)

backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48

    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 64
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 128
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 256
          concat_intermediates: True
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 384
          concat_intermediates: False

    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]

    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 192
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 3
        hidden_channels: 64
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 192
        activation_type: relu
        width_mult: 1
        depth_mult: 1

    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 3
        hidden_channels: 256
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16

    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 0.75
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
src/super_gradients/recipes/arch_params/yolo_nas_s_arch_params.yaml (new file)

backbone:
  NStageBackbone:
    stem:
      YoloNASStem:
        out_channels: 48

    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 32
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 64
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 96
          concat_intermediates: False
      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 192
          concat_intermediates: False

    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]

    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:
    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 64
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 2
        hidden_channels: 48
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 2
        hidden_channels: 64
        activation_type: relu
        width_mult: 1
        depth_mult: 1

    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 2
        hidden_channels: 64
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  NDFLHeads:
    num_classes: 80
    reg_max: 16

    heads_list:
      - YoloNASDFLHead:
          inter_channels: 128
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 8
      - YoloNASDFLHead:
          inter_channels: 256
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 16
      - YoloNASDFLHead:
          inter_channels: 512
          width_mult: 0.5
          first_conv_group_size: 0
          stride: 32

bn_eps: 1e-3
bn_momentum: 0.03
inplace_act: True

_convert_: all
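The three arch-params files above back the Models.YOLO_NAS_S/M/L names registered in object_names.py. A minimal usage sketch (the num_classes override mirrors the recipes; the exact set of supported kwargs is an assumption):

from super_gradients.common.object_names import Models
from super_gradients.training import models

# COCO-pretrained weights, as in the updated predict examples above.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

# Or a fresh model with a custom class count.
model = models.get(Models.YOLO_NAS_S, num_classes=20)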
src/super_gradients/recipes/coco2017_yolo_nas_s.yaml (new file)

# YoloNAS-S Detection training on COCO2017 Dataset:
# This training recipe is for demonstration purposes only. Pretrained models were trained using a different recipe.
# So it will not be possible to reproduce the results of the pretrained models using this recipe.

# Instructions:
#   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
#   1. Move to the project root (where you will find the ReadMe and src folder)
#   2. Run the command you want:
#      yolo_nas_s: python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_yolo_nas_s
#

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: coco_detection_yolo_nas_dataset_params
  - arch_params: yolo_nas_s_arch_params
  - checkpoint_params: default_checkpoint_params
  - _self_
  - variable_setup

train_dataloader: coco2017_train_yolo_nas
val_dataloader: coco2017_val_yolo_nas

load_checkpoint: False
resume: False

dataset_params:
  train_dataloader_params:
    batch_size: 32

arch_params:
  num_classes: 80

training_hyperparams:
  resume: ${resume}
  mixed_precision: True

architecture: yolo_nas_s

multi_gpu: DDP
num_gpus: 8

experiment_suffix: ""
experiment_name: coco2017_${architecture}${experiment_suffix}
src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml

@@ -30,15 +30,15 @@ train_dataset_params:
         mixup_scale: [ 0.5, 1.5 ]         # random rescale range for the additional sample in mixup
         prob: 0.5                       # probability to apply per-sample mixup
         flip_prob: 0.5                  # probability to apply horizontal flip
-    - DetectionStandardizeImage:
-        max_value: 255.
     - DetectionPaddedRescale:
         input_dim: [640, 640]
         max_targets: 120
         pad_value: 114
+    - DetectionStandardize:
+        max_value: 255.
     - DetectionTargetsFormatTransform:
         max_targets: 256
-        output_format: LABEL_NORMALIZED_CXCYWH
+        output_format: LABEL_CXCYWH

   tight_box_rotation: False
   class_inclusion_list:
@@ -67,13 +67,13 @@ val_dataset_params:
     - DetectionPadToSize:
         output_size: [640, 640]
         pad_value: 114
-    - DetectionStandardizeImage:
+    - DetectionStandardize:
         max_value: 255.
     - DetectionImagePermute
     - DetectionTargetsFormatTransform:
         max_targets: 50
         input_dim: [640, 640]
-        output_format: LABEL_NORMALIZED_CXCYWH
+        output_format: LABEL_CXCYWH
   tight_box_rotation: False
   class_inclusion_list:
   max_num_samples:
@@ -83,6 +83,7 @@ val_dataloader_params:
   batch_size: 25
   num_workers: 8
   drop_last: False
+  shuffle: False
   pin_memory: True
   collate_fn:
     _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml

@@ -9,18 +9,27 @@ train_dataset_params:
   input_dim: [640, 640]
   cache_dir:
   cache: False
+  ignore_empty_annotations: False
   transforms:
+    - DetectionMosaic:
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+        prob: 1.
     - DetectionRandomAffine:
         degrees: 0.                  # rotation degrees, randomly sampled from [-degrees, degrees]
         translate: 0.1                # image translation fraction
         scales: [ 0.5, 1.5 ]              # random rescale range (keeps size by padding/cropping) after mosaic transform.
         shear: 0.0                    # shear degrees, randomly sampled from [-degrees, degrees]
         target_size: ${dataset_params.train_dataset_params.input_dim}
-        filter_box_candidates: True   # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
+        filter_box_candidates: False  # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
         wh_thr: 2                     # edge size threshold when filter_box_candidates = True (pixels)
         area_thr: 0.1                 # threshold for area ratio between original image and the transformed one, when filter_box_candidates = True
         ar_thr: 20                    # aspect ratio threshold when filter_box_candidates = True
         border_value: 128
+#    - DetectionMixup:
+#        input_dim: ${dataset_params.train_dataset_params.input_dim}
+#        mixup_scale: [ 0.5, 1.5 ]         # random rescale range for the additional sample in mixup
+#        prob: 1.0                       # probability to apply per-sample mixup
+#        flip_prob: 0.5                  # probability to apply horizontal flip
     - DetectionHSV:
         prob: 1.0                       # probability to apply HSV transform
         hgain: 5                        # HSV transform hue gain (randomly sampled from [-hgain, hgain])
@@ -30,8 +39,11 @@ train_dataset_params:
         prob: 0.5                       # probability to apply horizontal flip
     - DetectionPaddedRescale:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
-        max_targets: 120
+        max_targets: 300
+    - DetectionStandardize:
+        max_value: 255.
     - DetectionTargetsFormatTransform:
+        max_targets: 300
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         output_format: LABEL_CXCYWH
   tight_box_rotation: False
@@ -43,8 +55,8 @@ train_dataset_params:
 train_dataloader_params:
   shuffle: True
   batch_size: 16
-  num_workers: 0
-  sampler:
+  min_samples: 512
+  num_workers: 4
   drop_last: False
   pin_memory: True
   worker_init_fn:
@@ -60,11 +72,16 @@ val_dataset_params:
   input_dim: [640, 640]
   cache_dir:
   cache: False
+  ignore_empty_annotations: False
   transforms:
   - DetectionPaddedRescale:
       input_dim: ${dataset_params.val_dataset_params.input_dim}
+      max_targets: 300
+      pad_value: 114
+  - DetectionStandardize:
+      max_value: 255.
   - DetectionTargetsFormatTransform:
-      max_targets: 50
+      max_targets: 300
       input_dim: ${dataset_params.val_dataset_params.input_dim}
       output_format: LABEL_CXCYWH
   tight_box_rotation: False
@@ -74,10 +91,10 @@ val_dataset_params:
   verbose: 0

 val_dataloader_params:
-  batch_size: 64
-  num_workers: 0
-  sampler:
+  batch_size: 32
+  num_workers: 4
   drop_last: False
+  shuffle: False
   pin_memory: True
   collate_fn: # collate function for valset
     _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
src/super_gradients/recipes/roboflow_yolo_nas_m.yaml (new file)

# A recipe to fine-tune YoloNAS on Roboflow datasets.
# Checkout the datasets at https://universe.roboflow.com/roboflow-100?ref=blog.roboflow.com
#
# `dataset_name` refers to the official name of the dataset.
# You can find it in the url of the dataset: https://universe.roboflow.com/roboflow-100/digits-t2eg6 -> digits-t2eg6
#
# Example: python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_m dataset_name=digits-t2eg6

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: roboflow_detection_dataset_params
  - checkpoint_params: default_checkpoint_params
  - arch_params: yolo_nas_m_arch_params
  - _self_
  - variable_setup

train_dataloader: roboflow_train_yolox
val_dataloader: roboflow_val_yolox

dataset_name: ??? # Placeholder for the name of the dataset you want to use (e.g. "digits-t2eg6")

dataset_params:
  dataset_name: ${dataset_name}
  train_dataloader_params:
    batch_size: 12
  val_dataloader_params:
    batch_size: 16

num_classes: ${roboflow_dataset_num_classes:${dataset_name}}

architecture: yolo_nas_m
arch_params:
  num_classes: ${num_classes}

load_checkpoint: False
checkpoint_params:
  pretrained_weights: coco

result_path: # By defaults saves results in checkpoints directory

resume: False
training_hyperparams:
  resume: ${resume}
  zero_weight_decay_on_bias_and_bn: True

  lr_warmup_epochs: 3
  warmup_mode: linear_epoch_step
  initial_lr: 4e-4
  cosine_final_lr_ratio: 0.1
  optimizer_params:
    weight_decay: 0.0001

  ema: True
  ema_params:
    decay: 0.9

  max_epochs: 100
  mixed_precision: True

  criterion_params:
    num_classes: ${num_classes}

  phase_callbacks: []

  loss:
    ppyoloe_loss:
      num_classes: ${num_classes}
      reg_max: 16

  valid_metrics_list:
    - DetectionMetrics_050:
        score_thres: 0.1
        top_k_predictions: 300
        num_cls: ${num_classes}
        normalize_targets: True
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.01
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7

  metric_to_watch: 'mAP@0.50'

multi_gpu: Off
num_gpus: 1

experiment_suffix: ""
experiment_name: ${architecture}_roboflow_${dataset_name}${experiment_suffix}
src/super_gradients/recipes/roboflow_yolo_nas_s.yaml (new file)

# A recipe to fine-tune YoloNAS on Roboflow datasets.
# Checkout the datasets at https://universe.roboflow.com/roboflow-100?ref=blog.roboflow.com
#
# `dataset_name` refers to the official name of the dataset.
# You can find it in the url of the dataset: https://universe.roboflow.com/roboflow-100/digits-t2eg6 -> digits-t2eg6
#
# Example: python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_s dataset_name=digits-t2eg6

defaults:
  - training_hyperparams: coco2017_yolo_nas_train_params
  - dataset_params: roboflow_detection_dataset_params
  - checkpoint_params: default_checkpoint_params
  - arch_params: yolo_nas_s_arch_params
  - _self_
  - variable_setup

train_dataloader: roboflow_train_yolox
val_dataloader: roboflow_val_yolox

dataset_name: ??? # Placeholder for the name of the dataset you want to use (e.g. "digits-t2eg6")

dataset_params:
  dataset_name: ${dataset_name}
  train_dataloader_params:
    batch_size: 16
  val_dataloader_params:
    batch_size: 16

num_classes: ${roboflow_dataset_num_classes:${dataset_name}}

architecture: yolo_nas_s
arch_params:
  num_classes: ${num_classes}

load_checkpoint: False
checkpoint_params:
  pretrained_weights: coco

result_path: # By defaults saves results in checkpoints directory

resume: False
training_hyperparams:
  resume: ${resume}
  zero_weight_decay_on_bias_and_bn: True

  lr_warmup_epochs: 3
  warmup_mode: linear_epoch_step
  initial_lr: 5e-4
  cosine_final_lr_ratio: 0.1
  optimizer_params:
    weight_decay: 0.0001

  ema: True
  ema_params:
    decay: 0.9

  max_epochs: 100
  mixed_precision: True

  criterion_params:
    num_classes: ${num_classes}

  phase_callbacks: []

  loss:
    ppyoloe_loss:
      num_classes: ${num_classes}
      reg_max: 16

  valid_metrics_list:
    - DetectionMetrics_050:
        score_thres: 0.1
        top_k_predictions: 300
        num_cls: ${num_classes}
        normalize_targets: True
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.01
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7

  metric_to_watch: 'mAP@0.50'

multi_gpu: Off
num_gpus: 1

experiment_suffix: ""
experiment_name: ${architecture}_roboflow_${dataset_name}${experiment_suffix}
src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml (new file)

defaults:
  - roboflow_yolo_nas_s
  - quantization_params: default_quantization_params
  - _self_

checkpoint_params:
  checkpoint_path: ???
  strict_load: no_key_matching

pre_launch_callbacks_list:
  - QATRecipeModificationCallback:
      batch_size_divisor: 2
      max_epochs_divisor: 10
      lr_decay_factor: 0.01
      warmup_epochs_divisor: 10
      cosine_final_lr_ratio: 0.01
      disable_phase_callbacks: True
      disable_augmentations: False
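Note for reviewers: `checkpoint_path: ???` is a mandatory Hydra value, so this recipe cannot run without an explicit override. A plausible launch, assuming the same entry point as the other Roboflow recipes above (the QAT flow may use a dedicated entry point; the dataset name and path are placeholders):

python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_s_qat \
    dataset_name=digits-t2eg6 \
    checkpoint_params.checkpoint_path=/path/to/ckpt_best.pth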
src/super_gradients/recipes/training_hyperparams/coco2017_yolo_nas_train_params.yaml (new file)

defaults:
  - default_train_params

max_epochs: 300

warmup_mode: "linear_batch_step"
warmup_initial_lr: 1e-6
lr_warmup_steps: 1000
lr_warmup_epochs: 0

initial_lr: 2e-4
lr_mode: cosine
cosine_final_lr_ratio: 0.1

zero_weight_decay_on_bias_and_bn: True
batch_accumulate: 1

save_ckpt_epoch_list: [100, 200, 250]

loss:
  ppyoloe_loss:
    use_static_assigner: False
    num_classes: ${arch_params.num_classes}
    reg_max: 16

optimizer: AdamW
optimizer_params:
  weight_decay: 0.00001

ema: True
ema_params:
  decay: 0.9997
  decay_type: threshold

mixed_precision: False
sync_bn: True

valid_metrics_list:
  - DetectionMetrics:
      score_thres: 0.1
      top_k_predictions: 300
      num_cls: ${arch_params.num_classes}
      normalize_targets: True
      post_prediction_callback:
        _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
        score_threshold: 0.01
        nms_top_k: 1000
        max_predictions: 300
        nms_threshold: 0.7

pre_prediction_callback:

metric_to_watch: 'mAP@0.50:0.95'
greater_metric_to_watch_is_better: True

_convert_: all
src/super_gradients/training/dataloaders/__init__.py

@@ -9,8 +9,8 @@ from .dataloaders import (
     coco2017_val_ppyoloe,
     coco2017_pose_train,
     coco2017_pose_val,
-    coco2017_train_deci_yolo,
-    coco2017_val_deci_yolo,
+    coco2017_train_yolo_nas,
+    coco2017_val_yolo_nas,
     imagenet_train,
     imagenet_val,
     imagenet_efficientnet_train,
@@ -68,8 +68,8 @@ __all__ = [
     "coco2017_val_ppyoloe",
     "coco2017_pose_train",
     "coco2017_pose_val",
-    "coco2017_train_deci_yolo",
-    "coco2017_val_deci_yolo",
+    "coco2017_train_yolo_nas",
+    "coco2017_val_yolo_nas",
     "imagenet_train",
     "imagenet_val",
     "imagenet_efficientnet_train",
src/super_gradients/training/dataloaders/dataloaders.py

@@ -172,10 +172,10 @@ def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None) ->
     )


-@register_dataloader(Dataloaders.COCO2017_TRAIN_DECIYOLO)
-def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+@register_dataloader(Dataloaders.COCO2017_TRAIN_YOLO_NAS)
+def coco2017_train_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
     return get_data_loader(
-        config_name="coco_detection_deci_yolo_dataset_params",
+        config_name="coco_detection_yolo_nas_dataset_params",
         dataset_cls=COCODetectionDataset,
         train=True,
         dataset_params=dataset_params,
@@ -183,10 +183,10 @@ def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dic
     )


-@register_dataloader(Dataloaders.COCO2017_VAL_DECIYOLO)
-def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+@register_dataloader(Dataloaders.COCO2017_VAL_YOLO_NAS)
+def coco2017_val_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
     return get_data_loader(
-        config_name="coco_detection_deci_yolo_dataset_params",
+        config_name="coco_detection_yolo_nas_dataset_params",
         dataset_cls=COCODetectionDataset,
         train=False,
         dataset_params=dataset_params,
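A short sketch of the renamed loaders in use (override values are illustrative). Both factories resolve coco_detection_yolo_nas_dataset_params by default, with any passed dicts applied on top of the config:

from super_gradients.training.dataloaders import coco2017_train_yolo_nas, coco2017_val_yolo_nas

train_loader = coco2017_train_yolo_nas(dataloader_params={"batch_size": 16})
val_loader = coco2017_val_yolo_nas()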
src/super_gradients/training/datasets/detection_datasets/roboflow/metadata.py

@@ -33,7 +33,7 @@ DATASETS_METADATA = {
     "underwater-objects-5v7p8": {"category": "underwater", "train": 5320, "test": 760, "valid": 1520, "size": 7600, "num_classes": 5, "num_classes_found": 5},
     "coral-lwptl": {"category": "underwater", "train": 427, "test": 74, "valid": 93, "size": 594, "num_classes": 14, "num_classes_found": 14},
     "tweeter-posts": {"category": "documents", "train": 87, "test": 9, "valid": 21, "size": 117, "num_classes": 2, "num_classes_found": 2},
-    "tweeter-profile": {"category": "documents", "train": 425, "test": 61, "valid": 121, "size": 607, "num_classes": 1, "num_classes_found": 0},
+    "tweeter-profile": {"category": "documents", "train": 425, "test": 61, "valid": 121, "size": 607, "num_classes": 1, "num_classes_found": 1},
     "document-parts": {"category": "documents", "train": 906, "test": 150, "valid": 318, "size": 1374, "num_classes": 2, "num_classes_found": 2},
     "activity-diagrams-qdobr": {"category": "documents", "train": 259, "test": 45, "valid": 74, "size": 378, "num_classes": 19, "num_classes_found": 19},
     "signatures-xc8up": {"category": "documents", "train": 257, "test": 37, "valid": 74, "size": 368, "num_classes": 1, "num_classes_found": 1},
@@ -148,7 +148,7 @@ _NUM_CLASSES_FOUND = {
     "underwater-objects-5v7p8": 5,
     "coral-lwptl": 14,
     "tweeter-posts": 2,
-    "tweeter-profile": 0,
+    "tweeter-profile": 1,
     "document-parts": 2,
     "activity-diagrams-qdobr": 19,
     "signatures-xc8up": 1,
src/super_gradients/training/models/__init__.py

@@ -62,13 +62,26 @@ from super_gradients.training.models.classification_models.vgg import VGG
 from super_gradients.training.models.classification_models.vit import ViT, ViTBase, ViTLarge, ViTHuge

 # Detection models
-from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53
-from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
+from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53, SPP
+from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
 from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base
 from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
 from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX
 from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
+from super_gradients.training.models.detection_models.yolo_nas import (
+    YoloNASStage,
+    YoloNASStem,
+    YoloNASDownStage,
+    YoloNASUpStage,
+    YoloNASBottleneck,
+    YoloNASDFLHead,
+    NDFLHeads,
+    YoloNASPANNeckWithC2,
+    YoloNAS_S,
+    YoloNAS_M,
+    YoloNAS_L,
+)

 # Segmentation models
 from super_gradients.training.models.segmentation_models.shelfnet import (
@@ -96,7 +109,6 @@ from super_gradients.training.models.segmentation_models.stdc import (
     STDCSegmentationBase,
     CustomSTDCSegmentation,
 )
-from super_gradients.training.models.segmentation_models.segformer import SegFormerB0, SegFormerB1, SegFormerB2, SegFormerB3, SegFormerB4, SegFormerB5

 # Pose estimation
 from super_gradients.training.models.pose_estimation_models.pose_ppyolo import PosePPYoloL
@@ -116,6 +128,18 @@ from super_gradients.common.object_names import Models
 from super_gradients.common.registry.registry import ARCHITECTURES

 __all__ = [
+    "SPP",
+    "YoloNAS_S",
+    "YoloNAS_M",
+    "YoloNAS_L",
+    "YoloNASStage",
+    "YoloNASUpStage",
+    "YoloNASStem",
+    "YoloNASDownStage",
+    "YoloNASDFLHead",
+    "YoloNASBottleneck",
+    "NDFLHeads",
+    "YoloNASPANNeckWithC2",
     "SgModule",
     "Beit",
     "BeitLargePatch16_224",
@@ -259,10 +283,4 @@ __all__ = [
     "ARCHITECTURES",
     "Models",
     "user_models",
-    "SegFormerB0",
-    "SegFormerB1",
-    "SegFormerB2",
-    "SegFormerB3",
-    "SegFormerB4",
-    "SegFormerB5",
 ]
src/super_gradients/training/models/detection_models/csp_darknet53.py

@@ -7,9 +7,11 @@ from typing import Tuple, Type
 import torch
 import torch.nn as nn

+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.object_names import Models
-from super_gradients.common.registry.registry import register_model
-from super_gradients.modules import Residual, Conv
+from super_gradients.common.registry.registry import register_model, register_detection_module
+from super_gradients.modules import Residual, Conv, BaseDetectionModule
 from super_gradients.modules.utils import width_multiplier
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.utils.utils import get_param, HpmStruct
@@ -127,13 +129,16 @@ class BottleneckCSP(nn.Module):
         return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


-class SPP(nn.Module):
+@register_detection_module()
+class SPP(BaseDetectionModule):
     # SPATIAL PYRAMID POOLING LAYER
-    def __init__(self, input_channels, output_channels, k: Tuple, activation_type: Type[nn.Module]):
-        super().__init__()
+    @resolve_param("activation_type", ActivationsTypeFactory())
+    def __init__(self, in_channels, output_channels, k: Tuple, activation_type: Type[nn.Module]):
+        super().__init__(in_channels)
+        self._output_channels = output_channels

-        hidden_channels = input_channels // 2
-        self.cv1 = Conv(input_channels, hidden_channels, 1, 1, activation_type)
+        hidden_channels = in_channels // 2
+        self.cv1 = Conv(in_channels, hidden_channels, 1, 1, activation_type)
         self.cv2 = Conv(hidden_channels * (len(k) + 1), output_channels, 1, 1, activation_type)
         self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

@@ -141,6 +146,13 @@
         x = self.cv1(x)
         return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))

+    @property
+    def out_channels(self):
+        """
+        :return: channels of tensor(s) that will be returned by a module in forward
+        """
+        return self._output_channels
+

 class ViewModule(nn.Module):
     """
src/super_gradients/training/models/detection_models/customizable_detector.py

@@ -12,6 +12,8 @@ from omegaconf import DictConfig

 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.processing_factory import ProcessingFactory
+from super_gradients.module_interfaces import SupportsReplaceNumClasses
+from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.models.sg_module import SgModule
 import super_gradients.common.factories.detection_modules_factory as det_factory
@@ -102,6 +104,8 @@ class CustomizableDetector(SgModule):
             raise ValueError("At least one of new_num_classes, new_head must be given to replace output layer.")
         if new_head is not None:
             self.heads = new_head
+        elif isinstance(self.heads, SupportsReplaceNumClasses):
+            self.heads.replace_num_classes(new_num_classes, replace_num_classes_with_random_weights)
         else:
             factory = det_factory.DetectionModulesFactory()
             self.heads_params = factory.insert_module_param(self.heads_params, "num_classes", new_num_classes)
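With this change, heads implementing SupportsReplaceNumClasses keep the pretrained backbone and neck weights and re-initialize only the class-dependent layers. A minimal fine-tuning sketch (assuming the enclosing method is CustomizableDetector.replace_head, consistent with the new_num_classes/new_head parameters visible in this hunk):

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model.replace_head(new_num_classes=3)  # re-targets the head to a 3-class problem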
src/super_gradients/training/models/detection_models/pp_yolo_e/__init__.py

@@ -1,4 +1,4 @@
-from .pp_yolo_e import PPYoloE
+from .pp_yolo_e import PPYoloE, PPYoloE_S, PPYoloE_M, PPYoloE_L, PPYoloE_X
 from .post_prediction_callback import PPYoloEPostPredictionCallback

-__all__ = ["PPYoloE", "PPYoloEPostPredictionCallback"]
+__all__ = ["PPYoloE", "PPYoloEPostPredictionCallback", "PPYoloE_L", "PPYoloE_M", "PPYoloE_S", "PPYoloE_X"]
@@ -10,10 +10,10 @@ from super_gradients.common.factories.activations_type_factory import Activation
 from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBasicBlock
 from super_gradients.modules import ConvBNAct

-__all__ = ["CustomCSPPAN"]
+__all__ = ["PPYoloECSPPAN"]


-class SPP(nn.Module):
+class PPYoloESPP(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -52,7 +52,7 @@ class CSPStage(nn.Module):
         for i in range(n):
             convs.append((str(i), CSPResNetBasicBlock(next_ch_in, ch_mid, activation_type=activation_type, use_residual_connection=False)))
             if i == (n - 1) // 2 and spp:
-                convs.append(("spp", SPP(ch_mid, ch_mid, 1, (5, 9, 13), activation_type=activation_type)))
+                convs.append(("spp", PPYoloESPP(ch_mid, ch_mid, 1, (5, 9, 13), activation_type=activation_type)))
             next_ch_in = ch_mid

         self.convs = nn.Sequential(collections.OrderedDict(convs))
@@ -68,7 +68,7 @@ class CSPStage(nn.Module):


 @register_detection_module()
-class CustomCSPPAN(nn.Module):
+class PPYoloECSPPAN(nn.Module):
     @resolve_param("activation", ActivationsTypeFactory())
     def __init__(
         self,
@@ -1,6 +1,7 @@
 from typing import Union, Optional, List

 from torch import Tensor
+
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.processing_factory import ProcessingFactory
 from super_gradients.common.registry.registry import register_model
@@ -8,7 +9,7 @@ from super_gradients.common.object_names import Models
 from super_gradients.modules import RepVGGBlock
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
-from super_gradients.training.models.detection_models.pp_yolo_e.pan import CustomCSPPAN
+from super_gradients.training.models.detection_models.pp_yolo_e.pan import PPYoloECSPPAN
 from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import PPYOLOEHead
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.models.arch_params_factory import get_arch_params
@@ -26,7 +27,7 @@ class PPYoloE(SgModule):
             arch_params = arch_params.to_dict()

         self.backbone = CSPResNetBackbone(**arch_params["backbone"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
-        self.neck = CustomCSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
+        self.neck = PPYoloECSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
         self.head = PPYOLOEHead(**arch_params["head"], width_mult=arch_params["width_mult"], num_classes=arch_params["num_classes"])

         self._class_names: Optional[List[str]] = None
@@ -175,11 +175,12 @@ class PPYOLOEHead(nn.Module):
     @torch.jit.ignore
     def replace_num_classes(self, num_classes: int):
         bias_cls = bias_init_with_prob(0.01)
+        device = self.pred_cls[0].weight.device
         self.pred_cls = nn.ModuleList()
         self.num_classes = num_classes

         for in_c in self.in_channels:
-            predict_layer = nn.Conv2d(in_c, num_classes, 3, padding=1)
+            predict_layer = nn.Conv2d(in_c, num_classes, 3, padding=1, device=device)
             torch.nn.init.constant_(predict_layer.weight, 0.0)
             torch.nn.init.constant_(predict_layer.bias, bias_cls)
             self.pred_cls.append(predict_layer)
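The captured device matters: recreating pred_cls without it would place the new convolutions on CPU even when the rest of the model lives on GPU, breaking the next forward pass. A standalone sketch of the pattern (not SG code; nn.Conv2d's device kwarg needs torch >= 1.9):

    import torch
    from torch import nn

    old = nn.Conv2d(64, 80, 3, padding=1)
    old = old.cuda() if torch.cuda.is_available() else old
    device = old.weight.device                            # capture before discarding the old layer
    new = nn.Conv2d(64, 20, 3, padding=1, device=device)  # new layer lands on the same device
    assert new.weight.device == old.weight.device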
from super_gradients.training.models.detection_models.yolo_nas.dfl_heads import YoloNASDFLHead, NDFLHeads
from super_gradients.training.models.detection_models.yolo_nas.panneck import YoloNASPANNeckWithC2
from super_gradients.training.models.detection_models.yolo_nas.yolo_stages import (
    YoloNASStage,
    YoloNASStem,
    YoloNASDownStage,
    YoloNASUpStage,
    YoloNASBottleneck,
)
from super_gradients.training.models.detection_models.yolo_nas.yolo_nas_variants import YoloNAS_S, YoloNAS_M, YoloNAS_L

__all__ = [
    "YoloNASBottleneck",
    "YoloNASUpStage",
    "YoloNASDownStage",
    "YoloNASStem",
    "YoloNASStage",
    "NDFLHeads",
    "YoloNASDFLHead",
    "YoloNASPANNeckWithC2",
    "YoloNAS_S",
    "YoloNAS_M",
    "YoloNAS_L",
]
import math
from typing import Tuple, Union, List, Callable, Optional

import torch
from omegaconf import DictConfig
from torch import nn, Tensor

import super_gradients.common.factories.detection_modules_factory as det_factory
from super_gradients.common.registry import register_detection_module
from super_gradients.modules import ConvBNReLU
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.module_interfaces import SupportsReplaceNumClasses
from super_gradients.modules.utils import width_multiplier
from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import generate_anchors_for_grid_cell
from super_gradients.training.utils import HpmStruct, torch_version_is_greater_or_equal
from super_gradients.training.utils.bbox_utils import batch_distance2bbox


@register_detection_module()
class YoloNASDFLHead(BaseDetectionModule, SupportsReplaceNumClasses):
    def __init__(self, in_channels: int, inter_channels: int, width_mult: float, first_conv_group_size: int, num_classes: int, stride: int, reg_max: int):
        """
        Initialize the YoloNASDFLHead
        :param in_channels: Input channels
        :param inter_channels: Intermediate number of channels
        :param width_mult: Width multiplier
        :param first_conv_group_size: Group size
        :param num_classes: Number of detection classes
        :param stride: Output stride for this head
        :param reg_max: Number of bins in the regression head
        """
        super().__init__(in_channels)

        inter_channels = width_multiplier(inter_channels, width_mult, 8)
        if first_conv_group_size == 0:
            groups = 0
        elif first_conv_group_size == -1:
            groups = 1
        else:
            groups = inter_channels // first_conv_group_size

        self.num_classes = num_classes
        self.stem = ConvBNReLU(in_channels, inter_channels, kernel_size=1, stride=1, padding=0, bias=False)

        first_cls_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
        self.cls_convs = nn.Sequential(*first_cls_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

        first_reg_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
        self.reg_convs = nn.Sequential(*first_reg_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

        self.cls_pred = nn.Conv2d(inter_channels, self.num_classes, 1, 1, 0)
        self.reg_pred = nn.Conv2d(inter_channels, 4 * (reg_max + 1), 1, 1, 0)

        self.grid = torch.zeros(1)
        self.stride = stride

        self.prior_prob = 1e-2
        self._initialize_biases()

    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
        self.cls_pred = compute_new_weights_fn(self.cls_pred, num_classes)
        self.num_classes = num_classes

    @property
    def out_channels(self):
        return None

    def forward(self, x):
        x = self.stem(x)

        cls_feat = self.cls_convs(x)
        cls_output = self.cls_pred(cls_feat)

        reg_feat = self.reg_convs(x)
        reg_output = self.reg_pred(reg_feat)

        return reg_output, cls_output

    def _initialize_biases(self):
        prior_bias = -math.log((1 - self.prior_prob) / self.prior_prob)
        torch.nn.init.constant_(self.cls_pred.bias, prior_bias)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        if torch_version_is_greater_or_equal(1, 10):
            # https://github.com/pytorch/pytorch/issues/50276
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
        else:
            yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()


@register_detection_module()
class NDFLHeads(BaseDetectionModule, SupportsReplaceNumClasses):
    def __init__(
        self,
        num_classes: int,
        in_channels: Tuple[int, int, int],
        heads_list: Union[str, HpmStruct, DictConfig],
        grid_cell_scale: float = 5.0,
        grid_cell_offset: float = 0.5,
        reg_max: int = 16,
        eval_size: Optional[Tuple[int, int]] = None,
        width_mult: float = 1.0,
    ):
        """
        Initializes the NDFLHeads module.

        :param num_classes: Number of detection classes
        :param in_channels: Number of channels for each feature map (See width_mult)
        :param grid_cell_scale:
        :param grid_cell_offset:
        :param reg_max: Number of bins in the regression head
        :param eval_size: (rows, cols) Size of the image for evaluation. Setting this value can be beneficial for inference speed,
               since anchors will not be regenerated for each forward call.
        :param width_mult: A scaling factor applied to in_channels.
        """
        super(NDFLHeads, self).__init__(in_channels)
        in_channels = [max(round(c * width_mult), 1) for c in in_channels]

        self.in_channels = tuple(in_channels)
        self.num_classes = num_classes
        self.grid_cell_scale = grid_cell_scale
        self.grid_cell_offset = grid_cell_offset
        self.reg_max = reg_max
        self.eval_size = eval_size

        # Do not apply quantization to this tensor
        proj = torch.linspace(0, self.reg_max, self.reg_max + 1).reshape([1, self.reg_max + 1, 1, 1])
        self.register_buffer("proj_conv", proj, persistent=False)

        self._init_weights()

        factory = det_factory.DetectionModulesFactory()
        heads_list = self._pass_args(heads_list, factory, num_classes, reg_max)

        self.num_heads = len(heads_list)
        fpn_strides: List[int] = []
        for i in range(self.num_heads):
            new_head = factory.get(factory.insert_module_param(heads_list[i], "in_channels", in_channels[i]))
            fpn_strides.append(new_head.stride)
            setattr(self, f"head{i + 1}", new_head)

        self.fpn_strides = tuple(fpn_strides)

    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
        for i in range(self.num_heads):
            head = getattr(self, f"head{i + 1}")
            head.replace_num_classes(num_classes, compute_new_weights_fn)
        self.num_classes = num_classes

    @staticmethod
    def _pass_args(heads_list, factory, num_classes, reg_max):
        for i in range(len(heads_list)):
            heads_list[i] = factory.insert_module_param(heads_list[i], "num_classes", num_classes)
            heads_list[i] = factory.insert_module_param(heads_list[i], "reg_max", reg_max)
        return heads_list

    @torch.jit.ignore
    def cache_anchors(self, input_size: Tuple[int, int]):
        self.eval_size = input_size
        anchor_points, stride_tensor = self._generate_anchors()
        self.anchor_points = anchor_points
        self.stride_tensor = stride_tensor

    @torch.jit.ignore
    def _init_weights(self):
        if self.eval_size:
            anchor_points, stride_tensor = self._generate_anchors()
            self.anchor_points = anchor_points
            self.stride_tensor = stride_tensor

    @torch.jit.ignore
    def forward_train(self, feats: Tuple[Tensor, ...]):
        anchors, anchor_points, num_anchors_list, stride_tensor = generate_anchors_for_grid_cell(
            feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset
        )

        cls_score_list, reg_distri_list = [], []
        for i, feat in enumerate(feats):
            reg_distri, cls_logit = getattr(self, f"head{i + 1}")(feat)
            # cls and reg
            # Note we don't apply sigmoid on class predictions to ensure good numerical stability at loss computation
            cls_score_list.append(torch.permute(cls_logit.flatten(2), [0, 2, 1]))
            reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1]))

        cls_score_list = torch.cat(cls_score_list, dim=1)
        reg_distri_list = torch.cat(reg_distri_list, dim=1)

        return cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor

    def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]:
        cls_score_list, reg_distri_list, reg_dist_reduced_list = [], [], []

        for i, feat in enumerate(feats):
            b, _, h, w = feat.shape
            height_mul_width = h * w
            reg_distri, cls_logit = getattr(self, f"head{i + 1}")(feat)
            reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1]))

            reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, height_mul_width]), [0, 2, 3, 1])
            reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1)

            # cls and reg
            cls_score_list.append(cls_logit.reshape([b, self.num_classes, height_mul_width]))
            reg_dist_reduced_list.append(reg_dist_reduced)

        cls_score_list = torch.cat(cls_score_list, dim=-1)  # [B, C, Anchors]
        cls_score_list = torch.permute(cls_score_list, [0, 2, 1])  # [B, Anchors, C]

        reg_distri_list = torch.cat(reg_distri_list, dim=1)  # [B, Anchors, 4 * (self.reg_max + 1)]
        reg_dist_reduced_list = torch.cat(reg_dist_reduced_list, dim=1)  # [B, Anchors, 4]

        # Decode bboxes
        # Note in eval mode, anchor_points_inference is different from anchor_points computed on train
        if self.eval_size:
            anchor_points_inference, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points_inference, stride_tensor = self._generate_anchors(feats)

        pred_scores = cls_score_list.sigmoid()
        pred_bboxes = batch_distance2bbox(anchor_points_inference, reg_dist_reduced_list) * stride_tensor  # [B, Anchors, 4]

        decoded_predictions = pred_bboxes, pred_scores

        if torch.jit.is_tracing():
            return decoded_predictions

        anchors, anchor_points, num_anchors_list, _ = generate_anchors_for_grid_cell(feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset)

        raw_predictions = cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor
        return decoded_predictions, raw_predictions

    @property
    def out_channels(self):
        return None

    def forward(self, feats: Tuple[Tensor]):
        if self.training:
            return self.forward_train(feats)
        else:
            return self.forward_eval(feats)

    def _generate_anchors(self, feats=None, dtype=torch.float):
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_strides):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = int(self.eval_size[0] / stride)
                w = int(self.eval_size[1] / stride)
            shift_x = torch.arange(end=w) + self.grid_cell_offset
            shift_y = torch.arange(end=h) + self.grid_cell_offset
            if torch_version_is_greater_or_equal(1, 10):
                shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
            else:
                shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

            anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
        anchor_points = torch.cat(anchor_points)
        stride_tensor = torch.cat(stride_tensor)
        if feats is not None:
            anchor_points = anchor_points.to(feats[0].device)
            stride_tensor = stride_tensor.to(feats[0].device)
        return anchor_points, stride_tensor
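The eval path implements distribution-focal decoding: each box side is predicted as reg_max + 1 bin logits, softmaxed, then reduced to an expected distance with the linspace projection stored in proj_conv. A standalone sketch of that reduction (shapes illustrative, not SG code):

    import torch

    reg_max = 16
    logits = torch.randn(1, 4, reg_max + 1, 100)  # [B, 4 sides, bins, anchors]
    proj = torch.linspace(0, reg_max, reg_max + 1).reshape(1, reg_max + 1, 1, 1)
    probs = torch.nn.functional.softmax(logits.permute(0, 2, 3, 1), dim=1)  # [B, bins, anchors, 4]
    distances = torch.nn.functional.conv2d(probs, weight=proj).squeeze(1)   # [B, anchors, 4]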
from typing import Union, List, Tuple

from omegaconf import DictConfig
from torch import Tensor

from super_gradients.common.registry import register_detection_module
from super_gradients.modules.detection_modules import BaseDetectionModule
from super_gradients.training.utils.utils import HpmStruct
import super_gradients.common.factories.detection_modules_factory as det_factory


@register_detection_module("YoloNASPANNeckWithC2")
class YoloNASPANNeckWithC2(BaseDetectionModule):
    """
    A PAN (path aggregation network) neck with 4 stages (2 up-sampling and 2 down-sampling stages)
    where the up-sampling stages include a higher resolution skip
    Returns outputs of neck stage 2, stage 3, stage 4
    """

    def __init__(
        self,
        in_channels: List[int],
        neck1: Union[str, HpmStruct, DictConfig],
        neck2: Union[str, HpmStruct, DictConfig],
        neck3: Union[str, HpmStruct, DictConfig],
        neck4: Union[str, HpmStruct, DictConfig],
    ):
        """
        Initialize the PAN neck

        :param in_channels: Input channels of the 4 feature maps from the backbone
        :param neck1: First neck stage config
        :param neck2: Second neck stage config
        :param neck3: Third neck stage config
        :param neck4: Fourth neck stage config
        """
        super().__init__(in_channels)
        c2_out_channels, c3_out_channels, c4_out_channels, c5_out_channels = in_channels

        factory = det_factory.DetectionModulesFactory()
        self.neck1 = factory.get(factory.insert_module_param(neck1, "in_channels", [c5_out_channels, c4_out_channels, c3_out_channels]))
        self.neck2 = factory.get(factory.insert_module_param(neck2, "in_channels", [self.neck1.out_channels[1], c3_out_channels, c2_out_channels]))
        self.neck3 = factory.get(factory.insert_module_param(neck3, "in_channels", [self.neck2.out_channels[1], self.neck2.out_channels[0]]))
        self.neck4 = factory.get(factory.insert_module_param(neck4, "in_channels", [self.neck3.out_channels, self.neck1.out_channels[0]]))

        self._out_channels = [
            self.neck2.out_channels[1],
            self.neck3.out_channels,
            self.neck4.out_channels,
        ]

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        c2, c3, c4, c5 = inputs

        x_n1_inter, x = self.neck1([c5, c4, c3])
        x_n2_inter, p3 = self.neck2([x, c3, c2])
        p4 = self.neck3([p3, x_n2_inter])
        p5 = self.neck4([p4, x_n1_inter])

        return p3, p4, p5
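For orientation, a hedged sketch of the neck's tensor contract (stride layout assumed from the usual C2-C5 backbone convention; channel symbols are illustrative):

    # inputs:  c2 [B, C2, H/4, W/4],  c3 [B, C3, H/8, W/8],
    #          c4 [B, C4, H/16, W/16], c5 [B, C5, H/32, W/32]
    # outputs: p3 [B, O1, H/8, W/8],  p4 [B, O2, H/16, W/16], p5 [B, O3, H/32, W/32]
    # with (O1, O2, O3) == tuple(neck.out_channels)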
import copy
from typing import Union

from omegaconf import DictConfig

from super_gradients.common.object_names import Models
from super_gradients.common.registry import register_model
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.utils import HpmStruct, get_param
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback


@register_model(Models.YOLO_NAS_S)
class YoloNAS_S(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_s_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes


@register_model(Models.YOLO_NAS_M)
class YoloNAS_M(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_m_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes


@register_model(Models.YOLO_NAS_L)
class YoloNAS_L(CustomizableDetector):
    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
        default_arch_params = get_arch_params("yolo_nas_l_arch_params")
        merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
        merged_arch_params.override(**arch_params.to_dict())
        super().__init__(
            backbone=merged_arch_params.backbone,
            neck=merged_arch_params.neck,
            heads=merged_arch_params.heads,
            num_classes=get_param(merged_arch_params, "num_classes", None),
            in_channels=in_channels,
            bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
            bn_eps=get_param(merged_arch_params, "bn_eps", None),
            inplace_act=get_param(merged_arch_params, "inplace_act", None),
        )

    @staticmethod
    def get_post_prediction_callback(conf: float, iou: float) -> PPYoloEPostPredictionCallback:
        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)

    @property
    def num_classes(self):
        return self.heads.num_classes
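Because the three variants register themselves under Models.YOLO_NAS_*, they are constructed through the usual model factory. A hedged quickstart (image path is illustrative):

    from super_gradients.training import models

    model = models.get("yolo_nas_s", pretrained_weights="coco")
    model.predict("path/to/image.jpg", conf=0.25).show()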
from functools import partial
from typing import Type, List

import torch
from torch import nn, Tensor

from super_gradients.common.registry import register_detection_module
from super_gradients.modules import Residual, BaseDetectionModule
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
from super_gradients.modules import QARepVGGBlock, Conv
from super_gradients.modules.utils import width_multiplier

__all__ = ["YoloNASStage", "YoloNASUpStage", "YoloNASStem", "YoloNASDownStage", "YoloNASBottleneck"]


class YoloNASBottleneck(nn.Module):
    """
    A bottleneck block for YoloNAS. Consists of two consecutive blocks and optional residual connection.
    """

    def __init__(
        self, input_channels: int, output_channels: int, block_type: Type[nn.Module], activation_type: Type[nn.Module], shortcut: bool, use_alpha: bool
    ):
        """
        Initialize the YoloNASBottleneck block

        :param input_channels: Number of input channels
        :param output_channels: Number of output channels
        :param block_type: Type of the convolutional block
        :param activation_type: Activation type for the convolutional block
        :param shortcut: If True, adds the residual connection from input to output.
        :param use_alpha: If True, adds the learnable alpha parameter (multiplier for the residual connection).
        """
        super().__init__()

        self.cv1 = block_type(input_channels, output_channels, activation_type=activation_type)
        self.cv2 = block_type(output_channels, output_channels, activation_type=activation_type)
        self.add = shortcut and input_channels == output_channels
        self.shortcut = Residual() if self.add else None
        if use_alpha:
            self.alpha = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        else:
            self.alpha = 1.0

    def forward(self, x):
        return self.alpha * self.shortcut(x) + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class SequentialWithIntermediates(nn.Sequential):
    """
    A Sequential module that can return all intermediate values as a list of Tensors
    """

    def __init__(self, output_intermediates: bool, *args):
        super(SequentialWithIntermediates, self).__init__(*args)
        self.output_intermediates = output_intermediates

    def forward(self, input: Tensor) -> List[Tensor]:
        if self.output_intermediates:
            output = [input]
            for module in self:
                output.append(module(output[-1]))
            return output
        # For uniformity, we return a list even if we don't output intermediates
        return [super(SequentialWithIntermediates, self).forward(input)]


class YoloNASCSPLayer(nn.Module):
    """
    Cross-stage layer module for YoloNAS.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_bottlenecks: int,
        block_type: Type[nn.Module],
        activation_type: Type[nn.Module],
        shortcut: bool = True,
        use_alpha: bool = True,
        expansion: float = 0.5,
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param num_bottlenecks: Number of bottleneck blocks.
        :param block_type: Bottleneck block type.
        :param activation_type: Activation type for all blocks.
        :param shortcut: If True, adds the residual connection from input to output.
        :param use_alpha: If True, adds the learnable alpha parameter (multiplier for the residual connection).
        :param expansion: If hidden_channels is None, hidden_channels is set to in_channels * expansion.
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates:
        """
        super(YoloNASCSPLayer, self).__init__()
        if hidden_channels is None:
            hidden_channels = int(out_channels * expansion)
        self.conv1 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        self.conv2 = Conv(in_channels, hidden_channels, 1, stride=1, activation_type=activation_type)
        self.conv3 = Conv(hidden_channels * (2 + concat_intermediates * num_bottlenecks), out_channels, 1, stride=1, activation_type=activation_type)
        module_list = [YoloNASBottleneck(hidden_channels, hidden_channels, block_type, activation_type, shortcut, use_alpha) for _ in range(num_bottlenecks)]
        self.bottlenecks = SequentialWithIntermediates(concat_intermediates, *module_list)

    def forward(self, x: Tensor) -> Tensor:
        x_1 = self.conv1(x)
        x_1 = self.bottlenecks(x_1)
        x_2 = self.conv2(x)
        x = torch.cat((*x_1, x_2), dim=1)
        return self.conv3(x)


@register_detection_module()
class YoloNASStem(BaseDetectionModule):
    """
    Stem module for YoloNAS. Consists of a single QARepVGGBlock with stride of two.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initialize the YoloNASStem module
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.conv = QARepVGGBlock(in_channels, out_channels, stride=2, use_residual_connection=False)

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


@register_detection_module()
class YoloNASStage(BaseDetectionModule):
    """
    A single stage module for YoloNAS. It consists of a downsample block (QARepVGGBlock) followed by YoloNASCSPLayer.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        Initialize the YoloNASStage module
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param num_blocks: Number of bottleneck blocks in the YoloNASCSPLayer
        :param activation_type: Activation type for all blocks
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates: If True, concatenates the intermediate values from the YoloNASCSPLayer.
        """
        super().__init__(in_channels)
        self._out_channels = out_channels
        self.downsample = QARepVGGBlock(in_channels, out_channels, stride=2, activation_type=activation_type, use_residual_connection=False)
        self.blocks = YoloNASCSPLayer(
            out_channels,
            out_channels,
            num_blocks,
            QARepVGGBlock,
            activation_type,
            True,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, x):
        return self.blocks(self.downsample(x))


@register_detection_module()
class YoloNASUpStage(BaseDetectionModule):
    """
    Upsampling stage for YoloNAS.
    """

    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        width_mult: float,
        num_blocks: int,
        depth_mult: float,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
        reduce_channels: bool = False,
    ):
        """
        Initialize the YoloNASUpStage module
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param width_mult: Multiplier for the number of channels in the stage.
        :param num_blocks: Number of bottleneck blocks
        :param depth_mult: Multiplier for the number of blocks in the stage.
        :param activation_type: Activation type for all blocks
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks
        :param concat_intermediates:
        :param reduce_channels:
        """
        super().__init__(in_channels)

        num_inputs = len(in_channels)
        if num_inputs == 2:
            in_channels, skip_in_channels = in_channels
        else:
            in_channels, skip_in_channels1, skip_in_channels2 = in_channels
            skip_in_channels = skip_in_channels1 + out_channels  # skip2 downsample results in out_channels channels
        out_channels = width_multiplier(out_channels, width_mult, 8)
        num_blocks = max(round(num_blocks * depth_mult), 1) if num_blocks > 1 else num_blocks

        if num_inputs == 2:
            self.reduce_skip = Conv(skip_in_channels, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()
        else:
            self.reduce_skip1 = Conv(skip_in_channels1, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()
            self.reduce_skip2 = Conv(skip_in_channels2, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()

        self.conv = Conv(in_channels, out_channels, 1, 1, activation_type)
        self.upsample = nn.ConvTranspose2d(in_channels=out_channels, out_channels=out_channels, kernel_size=2, stride=2)
        if num_inputs == 3:
            self.downsample = Conv(out_channels if reduce_channels else skip_in_channels2, out_channels, kernel=3, stride=2, activation_type=activation_type)

        self.reduce_after_concat = Conv(num_inputs * out_channels, out_channels, 1, 1, activation_type) if reduce_channels else nn.Identity()

        after_concat_channels = out_channels if reduce_channels else out_channels + skip_in_channels

        self.blocks = YoloNASCSPLayer(
            after_concat_channels,
            out_channels,
            num_blocks,
            QARepVGGBlock,
            activation_type,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )

        self._out_channels = [out_channels, out_channels]

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs):
        if len(inputs) == 2:
            x, skip_x = inputs
            skip_x = [self.reduce_skip(skip_x)]
        else:
            x, skip_x1, skip_x2 = inputs
            skip_x1, skip_x2 = self.reduce_skip1(skip_x1), self.reduce_skip2(skip_x2)
            skip_x = [skip_x1, self.downsample(skip_x2)]
        x_inter = self.conv(x)
        x = self.upsample(x_inter)
        x = torch.cat([x, *skip_x], 1)
        x = self.reduce_after_concat(x)
        x = self.blocks(x)
        return x_inter, x


@register_detection_module()
class YoloNASDownStage(BaseDetectionModule):
    @resolve_param("activation_type", ActivationsTypeFactory())
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        width_mult: float,
        num_blocks: int,
        depth_mult: float,
        activation_type: Type[nn.Module],
        hidden_channels: int = None,
        concat_intermediates: bool = False,
    ):
        """
        Initializes a YoloNASDownStage.

        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param width_mult: Multiplier for the number of channels in the stage.
        :param num_blocks: Number of blocks in the stage.
        :param depth_mult: Multiplier for the number of blocks in the stage.
        :param activation_type: Type of activation to use inside the blocks.
        :param hidden_channels: If not None, sets the number of hidden channels used inside the bottleneck blocks.
        :param concat_intermediates:
        """
        super().__init__(in_channels)

        in_channels, skip_in_channels = in_channels
        out_channels = width_multiplier(out_channels, width_mult, 8)
        num_blocks = max(round(num_blocks * depth_mult), 1) if num_blocks > 1 else num_blocks

        self.conv = Conv(in_channels, out_channels // 2, 3, 2, activation_type)
        after_concat_channels = out_channels // 2 + skip_in_channels
        self.blocks = YoloNASCSPLayer(
            in_channels=after_concat_channels,
            out_channels=out_channels,
            num_bottlenecks=num_blocks,
            block_type=partial(Conv, kernel=3, stride=1),
            activation_type=activation_type,
            hidden_channels=hidden_channels,
            concat_intermediates=concat_intermediates,
        )

        self._out_channels = out_channels

    @property
    def out_channels(self):
        return self._out_channels

    def forward(self, inputs):
        x, skip_x = inputs
        x = self.conv(x)
        x = torch.cat([x, skip_x], 1)
        x = self.blocks(x)
        return x
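A hedged smoke test for the stage contract (assumes ActivationsTypeFactory resolves the string "relu" to nn.ReLU):

    import torch
    from super_gradients.training.models.detection_models.yolo_nas import YoloNASStage

    stage = YoloNASStage(in_channels=64, out_channels=128, num_blocks=2, activation_type="relu")
    out = stage(torch.randn(1, 64, 64, 64))  # downsample halves H and W
    assert out.shape == (1, 128, 32, 32)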
@@ -52,7 +52,7 @@ class Pipeline(ABC):
     def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], class_names: List[str], device: Optional[str] = None):
         super().__init__()
         self.device = device or next(model.parameters()).device
-        self.model = model.to(device)
+        self.model = model.to(self.device)
         self.class_names = class_names

         if isinstance(image_processor, list):
@@ -265,7 +265,12 @@ class DetectionPipeline(Pipeline):
     def _combine_image_prediction_to_images(
         self, images_predictions: Iterable[ImageDetectionPrediction], n_images: Optional[int] = None
     ) -> ImagesDetectionPrediction:
-        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
+        if n_images is not None and n_images == 1:
+            # Do not show tqdm progress bar if there is only one image
+            pass
+        else:
+            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
+
         return ImagesDetectionPrediction(_images_prediction_lst=images_predictions)

     def _combine_image_prediction_to_video(
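The one-line Pipeline fix is worth a second look: with device=None, model.to(device) is a no-op, so the model could stay somewhere other than the device the pipeline resolved. A standalone sketch of the corrected pattern (not SG code):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    device = None
    resolved = device or next(model.parameters()).device  # fall back to the model's own device
    model = model.to(resolved)  # previously model.to(device), a no-op when device is None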
@@ -59,6 +59,10 @@ MODEL_URLS = {
     "ppyoloe_m_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_ppyoloe_m.pth",
     "ppyoloe_l_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_pp_yoloe_l_best_model_21uffbb8.pth",  # 0.4948
     "ppyoloe_x_coco": "https://deci-pretrained-models.s3.amazonaws.com/ppyolo_e/coco2017_pp_yoloe_x_best_model_z03if91o.pth",  # 0.5115
+    #
+    "yolo_nas_s_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_s_coco2017.pth",
+    "yolo_nas_m_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_m_coco2017.pth",
+    "yolo_nas_l_coco": "https://deci-pretrained-models.s3.amazonaws.com/yolo_nas/yolo_nas_l_coco2017.pth",
 }

 PRETRAINED_NUM_CLASSES = {
@@ -305,8 +305,8 @@ def default_ppyoloe_coco_processing_params() -> dict:
     return params


-def default_deciyolo_coco_processing_params() -> dict:
-    """Processing parameters commonly used for training DeciYolo on COCO dataset.
+def default_yolo_nas_coco_processing_params() -> dict:
+    """Processing parameters commonly used for training YoloNAS on COCO dataset.
     TODO: remove once we load it from the checkpoint
     """

@@ -322,8 +322,8 @@ def default_deciyolo_coco_processing_params() -> dict:
     params = dict(
         class_names=COCO_DETECTION_CLASSES_LIST,
         image_processor=image_processor,
-        iou=0.65,
-        conf=0.5,
+        iou=0.7,
+        conf=0.25,
     )
     return params

@@ -337,6 +337,6 @@ def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -
             return default_yolox_coco_processing_params()
         elif "ppyoloe" in model_name:
             return default_ppyoloe_coco_processing_params()
-        elif "deciyolo" in model_name:
-            return default_deciyolo_coco_processing_params()
+        elif "yolo_nas" in model_name:
+            return default_yolo_nas_coco_processing_params()
     return dict()
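The hunk header confirms the lookup signature, so a hedged sanity check of the renamed path and retuned defaults (run in the same module that defines the function) looks like:

    params = get_pretrained_processing_params(model_name="yolo_nas_s", pretrained_weights="coco")
    assert params["iou"] == 0.7 and params["conf"] == 0.25  # defaults introduced in this hunk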
@@ -291,11 +291,21 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
     :param pretrained_weights: name for the pretrained weights (i.e imagenet)
     :return: None
     """
+    from super_gradients.common.object_names import Models
+
     model_url_key = architecture + "_" + str(pretrained_weights)
     if model_url_key not in MODEL_URLS.keys():
         raise MissingPretrainedWeightsException(model_url_key)

     url = MODEL_URLS[model_url_key]
+
+    if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
+        logger.info(
+            "License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in \n"
+            "https://github.com/Deci-AI/super-gradients/LICENSE.YOLONAS.md. \n"
+            "By downloading the pre-trained weight files you agree to comply with these terms."
+        )
+
     unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace("/", "_").replace(" ", "_")
     map_location = torch.device("cpu")
     pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)