Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_yolo_nas_pose.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. import unittest
  2. import torch
  3. from super_gradients.common import StrictLoad
  4. from super_gradients.common.object_names import Models
  5. from super_gradients.training import models
  6. from super_gradients.training.datasets.pose_estimation_datasets.yolo_nas_pose_collate_fn import (
  7. flat_collate_tensors_with_batch_index,
  8. undo_flat_collate_tensors_with_batch_index,
  9. )
  10. from super_gradients.training.losses import YoloNASPoseLoss
  11. class YoloNASPoseTests(unittest.TestCase):
  12. def test_yolo_nas_pose_forward(self):
  13. num_joints = 33
  14. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=num_joints).eval()
  15. input = torch.randn((1, 3, 640, 640))
  16. decoded_predictions, _ = model(input)
  17. pred_bboxes, pred_scores, pred_pose_coords, pred_pose_scores = decoded_predictions
  18. self.assertEquals(pred_bboxes.shape[2], 4)
  19. self.assertEquals(pred_scores.shape[2], 1)
  20. self.assertEquals(pred_pose_coords.shape[2], num_joints)
  21. self.assertEquals(pred_pose_coords.shape[3], 2)
  22. self.assertEquals(pred_pose_scores.shape[2], num_joints)
  23. def test_yolo_nas_pose_loss_function(self):
  24. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=17)
  25. input = torch.randn((3, 3, 640, 640))
  26. outputs = model(input)
  27. criterion = YoloNASPoseLoss(
  28. oks_sigmas=[0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089],
  29. )
  30. # A single tensor of shape (N, 1 + 4 + Num Joints * 3) (batch_index, x1, y1, x2, y2, [x, y, visibility] * Num Joints)
  31. # First image has 1 object, second image has 2 objects, third image has no objects
  32. target_boxes = flat_collate_tensors_with_batch_index(
  33. [
  34. torch.tensor([[10, 10, 100, 200]]),
  35. torch.tensor([[300, 500, 400, 550], [200, 200, 400, 400]]),
  36. torch.zeros((0, 4)),
  37. ]
  38. ).float()
  39. target_poses = flat_collate_tensors_with_batch_index(
  40. [
  41. torch.randn((1, 17, 3)), # First image has 1 object
  42. torch.randn((2, 17, 3)), # Second image has 2 objects
  43. torch.zeros((0, 17, 3)), # Third image has no objects
  44. ]
  45. ).float()
  46. target_poses[..., 3] = 2.0 # Mark all joints as visible
  47. target_crowds = flat_collate_tensors_with_batch_index([torch.zeros((1, 1)), torch.zeros((2, 1)), torch.zeros((0, 1))]).float()
  48. targets = (target_boxes, target_poses, target_crowds)
  49. loss = criterion(outputs=outputs, targets=targets)
  50. loss[0].backward()
  51. def test_flat_collate_2d(self):
  52. values = [
  53. torch.randn([1, 4]),
  54. torch.randn([2, 4]),
  55. torch.randn([0, 4]),
  56. torch.randn([3, 4]),
  57. ]
  58. flat_tensor = flat_collate_tensors_with_batch_index(values)
  59. undo_values = undo_flat_collate_tensors_with_batch_index(flat_tensor, 4)
  60. assert len(undo_values) == len(values)
  61. assert (undo_values[0] == values[0]).all()
  62. assert (undo_values[1] == values[1]).all()
  63. assert (undo_values[2] == values[2]).all()
  64. assert (undo_values[3] == values[3]).all()
  65. def test_flat_collate_3d(self):
  66. values = [
  67. torch.randn([1, 17, 3]),
  68. torch.randn([2, 17, 3]),
  69. torch.randn([0, 17, 3]),
  70. torch.randn([3, 17, 3]),
  71. ]
  72. flat_tensor = flat_collate_tensors_with_batch_index(values)
  73. undo_values = undo_flat_collate_tensors_with_batch_index(flat_tensor, 4)
  74. assert len(undo_values) == len(values)
  75. assert (undo_values[0] == values[0]).all()
  76. assert (undo_values[1] == values[1]).all()
  77. assert (undo_values[2] == values[2]).all()
  78. assert (undo_values[3] == values[3]).all()
  79. def test_yolo_nas_pose_replace_classes(self):
  80. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=17)
  81. model.replace_head(new_num_classes=20)
  82. input = torch.randn((1, 3, 640, 640))
  83. decoded_predictions, _ = model(input)
  84. pred_bboxes, pred_scores, pred_pose_coords, pred_pose_scores = decoded_predictions
  85. self.assertEqual(pred_pose_coords.shape[2], 20)
  86. self.assertEqual(pred_pose_scores.shape[2], 20)
  87. def test_pose_former_b2(self):
  88. model = models.get(
  89. "PoseFormer_B2", num_classes=17, checkpoint_path="https://sghub.deci.ai/models/segformer_b2_cityscapes.pth", strict_load=StrictLoad.KEY_MATCHING
  90. )
  91. x = torch.rand((1, 3, 640, 640))
  92. y = model(x)
  93. print(y)
  94. pass
  95. def test_pose_former(self):
  96. model = models.get(
  97. "PoseFormer_B5", num_classes=17, checkpoint_path="https://sghub.deci.ai/models/segformer_b5_cityscapes.pth", strict_load=StrictLoad.KEY_MATCHING
  98. )
  99. x = torch.rand((1, 3, 640, 640))
  100. y = model(x)
  101. print(y)
  102. pass
# Allow running this test module directly (e.g. `python test_yolo_nas_pose.py`)
# instead of via a test runner.
if __name__ == "__main__":
    unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...