Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#859 Fix non-registered (and crashing) unit test for DEKR Target generator

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000-fix-import-error-in-test
@@ -30,6 +30,7 @@ from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest
 from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
 from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
 from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
 from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
+from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationDataset
 from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
 from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
@@ -129,6 +130,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPPYOLOE))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPPYOLOE))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DEKRLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DEKRLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationMetrics))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationMetrics))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationDataset))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
 
 
Discard
@@ -1,29 +1,30 @@
-from super_gradients.training.datasets.pose_estimation_datasets.coco_keypoints import COCOKeypointsDataset
-from super_gradients.common.registry.registry import DEKRTargetsGenerator
-from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointsRandomVerticalFlip
+import unittest
+import numpy as np
+import torch
 
 
+from super_gradients.training.datasets.pose_estimation_datasets import DEKRTargetsGenerator
 
 
-def test_dataset():
-    target_generator = DEKRTargetsGenerator(
-        output_stride=4,
-        sigma=2,
-        center_sigma=4,
-        bg_weight=0.1,
-        offset_radius=4,
-    )
 
 
-    dataset = COCOKeypointsDataset(
-        data_dir="e:/coco2017",
-        images_dir="images/train2017",
-        json_file="annotations/person_keypoints_train2017.json",
-        include_empty_samples=False,
-        transforms=KeypointsCompose(
-            [
-                KeypointsRandomVerticalFlip(),
-            ]
-        ),
-        target_generator=target_generator,
-    )
+class TestPoseEstimationDataset(unittest.TestCase):
+    def test_dekr_target_generator(self):
+        target_generator = DEKRTargetsGenerator(
+            output_stride=4,
+            sigma=2,
+            center_sigma=4,
+            bg_weight=0.1,
+            offset_radius=4,
+        )
 
 
-    assert dataset is not None
-    assert dataset[0] is not None
+        joints = np.random.randint(0, 255, (4, 17, 3))
+        joints[:, :, 2] = 1
+
+        heatmaps, mask, offset_map, offset_weight = target_generator(
+            image=torch.zeros((3, 256, 256)),
+            joints=joints,
+            mask=np.ones((256, 256)),
+        )
+
+        self.assertEqual(heatmaps.shape, (18, 64, 64))
+        self.assertEqual(mask.shape, (18, 64, 64))
+        self.assertEqual(offset_map.shape, (34, 64, 64))
+        self.assertEqual(offset_weight.shape, (34, 64, 64))
Discard