|
@@ -22,9 +22,9 @@ class ReplaceHeadUnitTest(unittest.TestCase):
|
|
(_, pred_scores), _ = model.forward(input)
|
|
(_, pred_scores), _ = model.forward(input)
|
|
self.assertEqual(pred_scores.size(2), 100)
|
|
self.assertEqual(pred_scores.size(2), 100)
|
|
|
|
|
|
- def test_yolo_sg_replace_head(self):
|
|
|
|
|
|
+ def test_yolo_nas_replace_head(self):
|
|
input = torch.randn(1, 3, 640, 640).to(self.device)
|
|
input = torch.randn(1, 3, 640, 640).to(self.device)
|
|
- for model in [Models.YoloSG_S, Models.YoloSG_M, Models.YoloSG_L]:
|
|
|
|
|
|
+ for model in [Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L]:
|
|
model = models.get(model, pretrained_weights="coco").to(self.device).eval()
|
|
model = models.get(model, pretrained_weights="coco").to(self.device).eval()
|
|
model.replace_head(new_num_classes=100)
|
|
model.replace_head(new_num_classes=100)
|
|
(_, pred_scores), _ = model.forward(input)
|
|
(_, pred_scores), _ = model.forward(input)
|