Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

transforms_test.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  1. import unittest
  2. import numpy as np
  3. from super_gradients.training.transforms.keypoint_transforms import (
  4. KeypointsRandomHorizontalFlip,
  5. KeypointsRandomVerticalFlip,
  6. KeypointsRandomAffineTransform,
  7. KeypointsPadIfNeeded,
  8. KeypointsLongestMaxSize,
  9. )
  10. from super_gradients.training.transforms.transforms import DetectionImagePermute, DetectionPadToSize
  11. class TestTransforms(unittest.TestCase):
  12. def test_keypoints_random_affine(self):
  13. image = np.random.rand(640, 480, 3)
  14. mask = np.random.rand(640, 480)
  15. joints = np.random.randint(0, 480, size=(1, 17, 3))
  16. joints[..., 2] = 2 # all visible
  17. aug = KeypointsRandomAffineTransform(min_scale=0.8, max_scale=1.2, max_rotation=30, max_translate=0.5, prob=1, image_pad_value=0, mask_pad_value=0)
  18. aug_image, aug_mask, aug_joints, _, _ = aug(image, mask, joints, None, None)
  19. joints_outside_image = (
  20. (aug_joints[:, :, 0] < 0) | (aug_joints[:, :, 1] < 0) | (aug_joints[:, :, 0] >= aug_image.shape[1]) | (aug_joints[:, :, 1] >= aug_image.shape[0])
  21. )
  22. # Ensure that keypoints outside the image are not visible
  23. self.assertTrue((aug_joints[joints_outside_image, 2] == 0).all())
  24. self.assertTrue((aug_joints[~joints_outside_image, 2] != 0).all())
  25. def test_keypoints_horizontal_flip(self):
  26. image = np.random.rand(640, 480, 3)
  27. mask = np.random.rand(640, 480)
  28. joints = np.random.randint(0, 100, size=(1, 17, 3))
  29. aug = KeypointsRandomHorizontalFlip(flip_index=[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], prob=1)
  30. aug_image, aug_mask, aug_joints, _, _ = aug(image, mask, joints, None, None)
  31. np.testing.assert_array_equal(aug_image, image[:, ::-1, :])
  32. np.testing.assert_array_equal(aug_mask, mask[:, ::-1])
  33. np.testing.assert_array_equal(image.shape[1] - aug_joints[:, ::-1, 0] - 1, joints[..., 0])
  34. np.testing.assert_array_equal(aug_joints[:, ::-1, 1], joints[..., 1])
  35. np.testing.assert_array_equal(aug_joints[:, ::-1, 2], joints[..., 2])
  36. def test_keypoints_vertical_flip(self):
  37. image = np.random.rand(640, 480, 3)
  38. mask = np.random.rand(640, 480)
  39. joints = np.random.randint(0, 100, size=(1, 17, 3))
  40. aug = KeypointsRandomVerticalFlip(prob=1)
  41. aug_image, aug_mask, aug_joints, _, _ = aug(image, mask, joints, None, None)
  42. np.testing.assert_array_equal(aug_image, image[::-1, :, :])
  43. np.testing.assert_array_equal(aug_mask, mask[::-1, :])
  44. np.testing.assert_array_equal(aug_joints[..., 0], joints[..., 0])
  45. np.testing.assert_array_equal(image.shape[0] - aug_joints[..., 1] - 1, joints[..., 1])
  46. np.testing.assert_array_equal(aug_joints[..., 2], joints[..., 2])
  47. def test_keypoints_pad_if_needed(self):
  48. image = np.random.rand(640, 480, 3)
  49. mask = np.random.rand(640, 480)
  50. joints = np.random.randint(0, 100, size=(1, 17, 3))
  51. aug = KeypointsPadIfNeeded(min_width=768, min_height=768, image_pad_value=0, mask_pad_value=0)
  52. aug_image, aug_mask, aug_joints, _, _ = aug(image, mask, joints, None, None)
  53. self.assertEqual(aug_image.shape, (768, 768, 3))
  54. self.assertEqual(aug_mask.shape, (768, 768))
  55. np.testing.assert_array_equal(aug_joints, joints)
  56. def test_keypoints_longest_max_size(self):
  57. image = np.random.rand(640, 480, 3)
  58. mask = np.random.rand(640, 480)
  59. joints = np.random.randint(0, 480, size=(1, 17, 3))
  60. aug = KeypointsLongestMaxSize(max_height=512, max_width=512)
  61. aug_image, aug_mask, aug_joints, _, _ = aug(image, mask, joints, None, None)
  62. self.assertEqual(aug_image.shape[:2], aug_mask.shape[:2])
  63. self.assertLessEqual(aug_image.shape[0], 512)
  64. self.assertLessEqual(aug_image.shape[1], 512)
  65. self.assertTrue((aug_joints[..., 0] < aug_image.shape[1]).all())
  66. self.assertTrue((aug_joints[..., 1] < aug_image.shape[0]).all())
  67. def test_detection_image_permute(self):
  68. aug = DetectionImagePermute(dims=(2, 1, 0))
  69. image = np.random.rand(640, 480, 3)
  70. sample = {"image": image}
  71. output = aug(sample)
  72. self.assertEqual(output["image"].shape, (3, 480, 640))
  73. def test_detection_pad_to_size(self):
  74. aug = DetectionPadToSize(output_size=(640, 640))
  75. image = np.ones((512, 480, 3))
  76. # Boxes in format (x1, y1, x2, y2, class_id)
  77. boxes = np.array([[0, 0, 100, 100, 0], [100, 100, 200, 200, 1]])
  78. sample = {"image": image, "target": boxes}
  79. output = aug(sample)
  80. shift_x = (640 - 480) // 2
  81. shift_y = (640 - 512) // 2
  82. expected_boxes = np.array(
  83. [[0 + shift_x, 0 + shift_y, 100 + shift_x, 100 + shift_y, 0], [100 + shift_x, 100 + shift_y, 200 + shift_x, 200 + shift_y, 1]]
  84. )
  85. self.assertEqual(output["image"].shape, (640, 640, 3))
  86. np.testing.assert_array_equal(output["target"], expected_boxes)
  87. if __name__ == "__main__":
  88. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...