Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#864 Feature/sg 736 deci yolo rf100 yolo nas

Merged
Ghost merged 1 commits into Deci-AI:feature/SG-736_deci_yolo_rf100 from deci-ai:feature/SG-736_deci_yolo_rf100_yolo_nas
@@ -22,9 +22,9 @@ class ReplaceHeadUnitTest(unittest.TestCase):
             (_, pred_scores), _ = model.forward(input)
             (_, pred_scores), _ = model.forward(input)
             self.assertEqual(pred_scores.size(2), 100)
             self.assertEqual(pred_scores.size(2), 100)
 
 
-    def test_yolo_sg_replace_head(self):
+    def test_yolo_nas_replace_head(self):
         input = torch.randn(1, 3, 640, 640).to(self.device)
         input = torch.randn(1, 3, 640, 640).to(self.device)
-        for model in [Models.YoloSG_S, Models.YoloSG_M, Models.YoloSG_L]:
+        for model in [Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L]:
             model = models.get(model, pretrained_weights="coco").to(self.device).eval()
             model = models.get(model, pretrained_weights="coco").to(self.device).eval()
             model.replace_head(new_num_classes=100)
             model.replace_head(new_num_classes=100)
             (_, pred_scores), _ = model.forward(input)
             (_, pred_scores), _ = model.forward(input)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file