Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#923 Fix the bug with YOLONAS of not supporting overriding in_channels

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-848-fix-in-channels
@@ -68,3 +68,13 @@ python -m super_gradients.train_from_recipe --config-name=roboflow_yolo_nas_s  d
 ```
 ```
 
 
 Replace <DATASET_NAME> with any of the [RF100 datasets](https://github.com/roboflow/roboflow-100-benchmark/blob/8587f81ef282d529fe5707c0eede74fe91d472d0/metadata/datasets_stats.csv) that you wish to train on.
 Replace <DATASET_NAME> with any of the [RF100 datasets](https://github.com/roboflow/roboflow-100-benchmark/blob/8587f81ef282d529fe5707c0eede74fe91d472d0/metadata/datasets_stats.csv) that you wish to train on.
+
+
+## Creating a model for a non-RGB image
+
+You can create a model taking arbitrary number of channels by passing the number of channels to the arch_params argument.
+Important thing to keep in mind that in this case you cannot use the available pretrained weights and have to provde `num_classes` parameter explicitly.
+
+```python
+model = models.get(Models.YOLO_NAS_S, arch_params=dict(in_channels=2), num_classes=15)
+```
Discard
@@ -1,3 +1,5 @@
+in_channels: 3
+
 backbone:
 backbone:
   NStageBackbone:
   NStageBackbone:
 
 
Discard
@@ -1,3 +1,5 @@
+in_channels: 3
+
 backbone:
 backbone:
   NStageBackbone:
   NStageBackbone:
 
 
Discard
@@ -1,3 +1,5 @@
+in_channels: 3
+
 backbone:
 backbone:
   NStageBackbone:
   NStageBackbone:
 
 
Discard
@@ -14,7 +14,7 @@ from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPo
 
 
 @register_model(Models.YOLO_NAS_S)
 @register_model(Models.YOLO_NAS_S)
 class YoloNAS_S(CustomizableDetector):
 class YoloNAS_S(CustomizableDetector):
-    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
+    def __init__(self, arch_params: Union[HpmStruct, DictConfig]):
         default_arch_params = get_arch_params("yolo_nas_s_arch_params")
         default_arch_params = get_arch_params("yolo_nas_s_arch_params")
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params.override(**arch_params.to_dict())
         merged_arch_params.override(**arch_params.to_dict())
@@ -23,7 +23,7 @@ class YoloNAS_S(CustomizableDetector):
             neck=merged_arch_params.neck,
             neck=merged_arch_params.neck,
             heads=merged_arch_params.heads,
             heads=merged_arch_params.heads,
             num_classes=get_param(merged_arch_params, "num_classes", None),
             num_classes=get_param(merged_arch_params, "num_classes", None),
-            in_channels=in_channels,
+            in_channels=get_param(merged_arch_params, "in_channels", 3),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
@@ -40,7 +40,7 @@ class YoloNAS_S(CustomizableDetector):
 
 
 @register_model(Models.YOLO_NAS_M)
 @register_model(Models.YOLO_NAS_M)
 class YoloNAS_M(CustomizableDetector):
 class YoloNAS_M(CustomizableDetector):
-    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
+    def __init__(self, arch_params: Union[HpmStruct, DictConfig]):
         default_arch_params = get_arch_params("yolo_nas_m_arch_params")
         default_arch_params = get_arch_params("yolo_nas_m_arch_params")
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params.override(**arch_params.to_dict())
         merged_arch_params.override(**arch_params.to_dict())
@@ -49,7 +49,7 @@ class YoloNAS_M(CustomizableDetector):
             neck=merged_arch_params.neck,
             neck=merged_arch_params.neck,
             heads=merged_arch_params.heads,
             heads=merged_arch_params.heads,
             num_classes=get_param(merged_arch_params, "num_classes", None),
             num_classes=get_param(merged_arch_params, "num_classes", None),
-            in_channels=in_channels,
+            in_channels=get_param(merged_arch_params, "in_channels", 3),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
@@ -66,7 +66,7 @@ class YoloNAS_M(CustomizableDetector):
 
 
 @register_model(Models.YOLO_NAS_L)
 @register_model(Models.YOLO_NAS_L)
 class YoloNAS_L(CustomizableDetector):
 class YoloNAS_L(CustomizableDetector):
-    def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
+    def __init__(self, arch_params: Union[HpmStruct, DictConfig]):
         default_arch_params = get_arch_params("yolo_nas_l_arch_params")
         default_arch_params = get_arch_params("yolo_nas_l_arch_params")
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params))
         merged_arch_params.override(**arch_params.to_dict())
         merged_arch_params.override(**arch_params.to_dict())
@@ -75,7 +75,7 @@ class YoloNAS_L(CustomizableDetector):
             neck=merged_arch_params.neck,
             neck=merged_arch_params.neck,
             heads=merged_arch_params.heads,
             heads=merged_arch_params.heads,
             num_classes=get_param(merged_arch_params, "num_classes", None),
             num_classes=get_param(merged_arch_params, "num_classes", None),
-            in_channels=in_channels,
+            in_channels=get_param(merged_arch_params, "in_channels", 3),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_momentum=get_param(merged_arch_params, "bn_momentum", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             bn_eps=get_param(merged_arch_params, "bn_eps", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
             inplace_act=get_param(merged_arch_params, "inplace_act", None),
Discard
@@ -45,6 +45,7 @@ from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.iou_loss_test import IoULossTest
 from tests.unit_tests.iou_loss_test import IoULossTest
 from tests.unit_tests.update_param_groups_unit_test import UpdateParamGroupsTest
 from tests.unit_tests.update_param_groups_unit_test import UpdateParamGroupsTest
 from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.vit_unit_test import TestViT
+from tests.unit_tests.yolo_nas_tests import TestYOLONAS
 from tests.unit_tests.yolox_unit_test import TestYOLOX
 from tests.unit_tests.yolox_unit_test import TestYOLOX
 from tests.unit_tests.lr_cooldown_test import LRCooldownTest
 from tests.unit_tests.lr_cooldown_test import LRCooldownTest
 from tests.unit_tests.detection_targets_format_transform_test import DetectionTargetsTransformTest
 from tests.unit_tests.detection_targets_format_transform_test import DetectionTargetsTransformTest
@@ -135,6 +136,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ReplaceHeadUnitTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ReplaceHeadUnitTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS))
 
 
     def _add_modules_to_end_to_end_tests_suite(self):
     def _add_modules_to_end_to_end_tests_suite(self):
         """
         """
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
  1. import unittest
  2. import torch
  3. from super_gradients.common.object_names import Models
  4. from super_gradients.training import models
  5. class TestYOLONAS(unittest.TestCase):
  6. def setUp(self):
  7. pass
  8. def test_yolo_nas_custom_in_channels(self):
  9. """
  10. Validate that we can create a YOLO-NAS model with custom in_channels.
  11. """
  12. model = models.get(Models.YOLO_NAS_S, arch_params=dict(in_channels=2), num_classes=17)
  13. model(torch.rand(1, 2, 640, 640))
  14. if __name__ == "__main__":
  15. unittest.main()
Discard