Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#643 PPYolo-E

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-344-PP-Yolo-E-Training-Replicate-Recipe
@@ -1,47 +1,41 @@
-import os.path
 import unittest
 import unittest
 
 
-import hydra
-import pkg_resources
 import torch
 import torch
-from hydra import initialize_config_dir, compose
-from hydra.core.global_hydra import GlobalHydra
 
 
-from super_gradients.training.models.detection_models.csp_resnet import CSPResNet
-from super_gradients.common.environment.path_utils import normalize_path
+from super_gradients.training import models
+from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE_X, PPYoloE_S, PPYoloE_M, PPYoloE_L
 
 
 
 
-class PPYoloETests(unittest.TestCase):
-    def get_model_arch_params(self, config_name):
-        GlobalHydra.instance().clear()
-        sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
-        with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
-            cfg = compose(config_name=normalize_path(config_name))
-            cfg = hydra.utils.instantiate(cfg)
-            arch_params = cfg.arch_params
-
-        return arch_params
-
-    def _test_csp_resnet_variant(self, variant):
-        arch_params = self.get_model_arch_params(os.path.join("arch_params", variant))
+class TestPPYOLOE(unittest.TestCase):
+    def _test_ppyoloe_from_name(self, model_name, pretrained_weights):
+        ppyoloe = models.get(model_name, pretrained_weights=pretrained_weights, num_classes=80 if pretrained_weights is None else None).eval()
+        dummy_input = torch.randn(1, 3, 640, 480)
+        with torch.no_grad():
+            feature_maps = ppyoloe(dummy_input)
+            self.assertIsNotNone(feature_maps)
 
 
-        ppyoloe = CSPResNet(**arch_params)
-        dummy_input = torch.randn(1, 3, 320, 320)
+    def _test_ppyoloe_from_cls(self, model_cls):
+        ppyoloe = model_cls(arch_params={}).eval()
+        dummy_input = torch.randn(1, 3, 640, 480)
         with torch.no_grad():
         with torch.no_grad():
             feature_maps = ppyoloe(dummy_input)
             feature_maps = ppyoloe(dummy_input)
-            self.assertEqual(len(feature_maps), 3)
+            self.assertIsNotNone(feature_maps)
 
 
-    def test_csp_resnet_s(self):
-        self._test_csp_resnet_variant("csp_resnet_l_arch_params")
+    def test_ppyoloe_s(self):
+        self._test_ppyoloe_from_name("ppyoloe_s", pretrained_weights="coco")
+        self._test_ppyoloe_from_cls(PPYoloE_S)
 
 
-    def test_csp_resnet_m(self):
-        self._test_csp_resnet_variant("csp_resnet_m_arch_params")
+    def test_ppyoloe_m(self):
+        self._test_ppyoloe_from_name("ppyoloe_m", pretrained_weights="coco")
+        self._test_ppyoloe_from_cls(PPYoloE_M)
 
 
-    def test_csp_resnet_l(self):
-        self._test_csp_resnet_variant("csp_resnet_l_arch_params")
+    def test_ppyoloe_l(self):
+        self._test_ppyoloe_from_name("ppyoloe_l", pretrained_weights=None)
+        self._test_ppyoloe_from_cls(PPYoloE_L)
 
 
-    def test_csp_resnet_x(self):
-        self._test_csp_resnet_variant("csp_resnet_x_arch_params")
+    def test_ppyoloe_x(self):
+        self._test_ppyoloe_from_name("ppyoloe_x", pretrained_weights=None)
+        self._test_ppyoloe_from_cls(PPYoloE_X)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
Tip!

Press p or to see the previous file or, n or to see the next file