Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#437 Feature/SG-195 Detection output adapter

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-195-detection-output-adapter
@@ -1,3 +1,6 @@
+import itertools
+import os
+import tempfile
 import unittest
 import unittest
 
 
 import numpy as np
 import numpy as np
@@ -16,10 +19,26 @@ from super_gradients.training.utils.bbox_formats import (
     BBOX_FORMATS,
     BBOX_FORMATS,
     BoundingBoxFormat,
     BoundingBoxFormat,
 )
 )
+from super_gradients.training.utils.bbox_formats.normalized_cxcywh import (
+    normalized_cxcywh_to_xyxy_inplace,
+    xyxy_to_normalized_cxcywh_inplace,
+    xyxy_to_normalized_cxcywh,
+    normalized_cxcywh_to_xyxy,
+)
+from super_gradients.training.utils.bbox_formats.normalized_xywh import (
+    xyxy_to_normalized_xywh_inplace,
+    xyxy_to_normalized_xywh,
+    normalized_xywh_to_xyxy_inplace,
+    normalized_xywh_to_xyxy,
+)
+from super_gradients.training.utils.bbox_formats.xywh import xyxy_to_xywh, xywh_to_xyxy, xywh_to_xyxy_inplace, xyxy_to_xywh_inplace
+from super_gradients.training.utils.bbox_formats.yxyx import xyxy_to_yxyx, xyxy_to_yxyx_inplace
+from super_gradients.training.utils.output_adapters.detection_adapter import ConvertBoundingBoxes
 
 
 
 
 class BBoxFormatsTest(unittest.TestCase):
 class BBoxFormatsTest(unittest.TestCase):
     def setUp(self):
     def setUp(self):
+
         # contains all formats
         # contains all formats
         self.formats = [
         self.formats = [
             XYWHCoordinateFormat(),
             XYWHCoordinateFormat(),
@@ -70,6 +89,26 @@ class BBoxFormatsTest(unittest.TestCase):
             },
             },
         ]
         ]
 
 
+    def test_inplace_vs_normal_conversion(self):
+        gt_bboxes = torch.randint(low=0, high=512, size=(8192, 4)).float()
+
+        conversion_functions = [
+            (xyxy_to_xywh_inplace, xyxy_to_xywh),
+            (xywh_to_xyxy_inplace, xywh_to_xyxy),
+            (xyxy_to_normalized_xywh_inplace, xyxy_to_normalized_xywh),
+            (normalized_xywh_to_xyxy_inplace, normalized_xywh_to_xyxy),
+            (normalized_cxcywh_to_xyxy_inplace, normalized_cxcywh_to_xyxy),
+            (xyxy_to_normalized_cxcywh_inplace, xyxy_to_normalized_cxcywh),
+            (xyxy_to_yxyx_inplace, xyxy_to_yxyx),
+        ]
+
+        for inplace_op, copy_op in conversion_functions:
+            inplace_pred = inplace_op(gt_bboxes.clone(), self.image_shape)
+            copy_pred = copy_op(gt_bboxes.clone(), self.image_shape)
+            self.assertTrue(
+                copy_pred.eq(inplace_pred).all(), msg=f"Inplace conversion operator {inplace_op} produced different results than non-inplace operator {copy_op}"
+            )
+
     def test_conversion_to_from_is_correct_2d_input_tensor(self):
     def test_conversion_to_from_is_correct_2d_input_tensor(self):
         """
         """
         Check whether bbox format supports 3D input shape as input: [L, 4]
         Check whether bbox format supports 3D input shape as input: [L, 4]
@@ -158,7 +197,7 @@ class BBoxFormatsTest(unittest.TestCase):
 
 
     def test_bbox_conversion_regression(self):
     def test_bbox_conversion_regression(self):
         # Convert bounding boxes to a dictionary of bboxes
         # Convert bounding boxes to a dictionary of bboxes
-        bounding_bboxes = {k: np.array([dic[k] for dic in self.bounding_bboxes]) for k in self.bounding_bboxes[0]}
+        bounding_bboxes = {k: np.array([dic[k] for dic in self.bounding_bboxes], dtype=np.float32) for k in self.bounding_bboxes[0]}
         gt_bboxes = bounding_bboxes["xyxy"]
         gt_bboxes = bounding_bboxes["xyxy"]
 
 
         image_shape = self.image_shape
         image_shape = self.image_shape
@@ -166,12 +205,24 @@ class BBoxFormatsTest(unittest.TestCase):
         for src_fmt in self.formats:
         for src_fmt in self.formats:
             input_bboxes = src_fmt.from_xyxy(gt_bboxes, image_shape, inplace=False)
             input_bboxes = src_fmt.from_xyxy(gt_bboxes, image_shape, inplace=False)
             if src_fmt.format in bounding_bboxes:
             if src_fmt.format in bounding_bboxes:
+                gt_bboxes_actual = src_fmt.to_xyxy(input_bboxes, image_shape, inplace=False)
+
+                np.testing.assert_allclose(gt_bboxes_actual, gt_bboxes, rtol=1e-4, atol=1e-4)
                 np.testing.assert_allclose(input_bboxes, bounding_bboxes[src_fmt.format], rtol=1e-4, atol=1e-4)
                 np.testing.assert_allclose(input_bboxes, bounding_bboxes[src_fmt.format], rtol=1e-4, atol=1e-4)
 
 
             for dst_fmt in self.formats:
             for dst_fmt in self.formats:
-                intermediate_format = convert_bboxes(input_bboxes, image_shape, src_fmt, dst_fmt, inplace=False)
+                intermediate_format = convert_bboxes(input_bboxes.copy(), image_shape, src_fmt, dst_fmt, inplace=False)
                 actual_bboxes = dst_fmt.to_xyxy(intermediate_format, image_shape, inplace=False)
                 actual_bboxes = dst_fmt.to_xyxy(intermediate_format, image_shape, inplace=False)
-                np.testing.assert_allclose(actual_bboxes, gt_bboxes, rtol=1e-4, atol=1e-4)
+                np.testing.assert_allclose(
+                    actual_bboxes, gt_bboxes, rtol=1e-4, atol=1e-4, err_msg=f"Conversion via copy from {src_fmt.format} to {dst_fmt.format} failed"
+                )
+
+                # In-place
+                intermediate_format = convert_bboxes(input_bboxes.copy(), image_shape, src_fmt, dst_fmt, inplace=True)
+                actual_bboxes = dst_fmt.to_xyxy(intermediate_format, image_shape, inplace=True)
+                np.testing.assert_allclose(
+                    actual_bboxes, gt_bboxes, rtol=1e-4, atol=1e-4, err_msg=f"Inplace conversion from {src_fmt.format} to {dst_fmt.format} failed"
+                )
 
 
     def test_bbox_formats_factory_test(self):
     def test_bbox_formats_factory_test(self):
         factory = BBoxFormatFactory()
         factory = BBoxFormatFactory()
@@ -180,6 +231,32 @@ class BBoxFormatsTest(unittest.TestCase):
             format: BoundingBoxFormat = factory.get(format_key)
             format: BoundingBoxFormat = factory.get(format_key)
             self.assertEqual(format_key, format.format)
             self.assertEqual(format_key, format.format)
 
 
+    def test_bbox_formats_converter_can_be_exported(self):
+        factory = BBoxFormatFactory()
+
+        src_format: BoundingBoxFormat = factory.get("xyxy")
+
+        gt_bboxes = torch.randint(low=0, high=512, size=(8192, 4)).float()
+
+        for format_key in BBOX_FORMATS.keys():
+            dst_format: BoundingBoxFormat = factory.get(format_key)
+
+            # Try all combinations of implace flags to ensure all functions are tested for exportability
+            for inp1, inp2 in itertools.product([True, False], [True, False]):
+                module = ConvertBoundingBoxes(
+                    location=(0, 4),
+                    to_xyxy=src_format.get_from_xyxy(inplace=inp1),
+                    from_xyxy=dst_format.get_to_xyxy(inplace=inp2),
+                    image_shape=self.image_shape,
+                )
+
+                torch.jit.script(module, example_inputs=[gt_bboxes.clone()])
+                torch.jit.trace(module, example_inputs=(gt_bboxes.clone(),))
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    adapter_fname = os.path.join(tmpdirname, "adapter.onnx")
+                    # Just test that export works, we test the correctness in the detection_output_adapter_test.py
+                    torch.onnx.export(module, gt_bboxes.clone(), adapter_fname)
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file