Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ppyoloe_unit_test.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  1. import unittest
  2. import torch
  3. from super_gradients.common.object_names import Models
  4. from super_gradients.training import models
  5. from super_gradients.training.losses import PPYoloELoss
  6. from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE_X, PPYoloE_S, PPYoloE_M, PPYoloE_L
  7. class TestPPYOLOE(unittest.TestCase):
  8. def _test_ppyoloe_from_name(self, model_name, pretrained_weights):
  9. ppyoloe = models.get(model_name, pretrained_weights=pretrained_weights, num_classes=80 if pretrained_weights is None else None).eval()
  10. dummy_input = torch.randn(1, 3, 640, 480)
  11. with torch.no_grad():
  12. feature_maps = ppyoloe(dummy_input)
  13. self.assertIsNotNone(feature_maps)
  14. def _test_ppyoloe_from_cls(self, model_cls):
  15. ppyoloe = model_cls(arch_params={}).eval()
  16. dummy_input = torch.randn(1, 3, 640, 480)
  17. with torch.no_grad():
  18. feature_maps = ppyoloe(dummy_input)
  19. self.assertIsNotNone(feature_maps)
  20. def test_ppyoloe_s(self):
  21. self._test_ppyoloe_from_name("ppyoloe_s", pretrained_weights="coco")
  22. self._test_ppyoloe_from_cls(PPYoloE_S)
  23. def test_ppyoloe_m(self):
  24. self._test_ppyoloe_from_name("ppyoloe_m", pretrained_weights="coco")
  25. self._test_ppyoloe_from_cls(PPYoloE_M)
  26. def test_ppyoloe_l(self):
  27. self._test_ppyoloe_from_name("ppyoloe_l", pretrained_weights=None)
  28. self._test_ppyoloe_from_cls(PPYoloE_L)
  29. def test_ppyoloe_x(self):
  30. self._test_ppyoloe_from_name("ppyoloe_x", pretrained_weights=None)
  31. self._test_ppyoloe_from_cls(PPYoloE_X)
  32. def test_ppyoloe_batched_vs_sequential_loss(self):
  33. for use_static_assigner in [True, False]:
  34. with self.subTest(use_static_assigner=use_static_assigner):
  35. torch.random.manual_seed(0)
  36. batched_loss = PPYoloELoss(
  37. num_classes=80, use_varifocal_loss=True, use_static_assigner=use_static_assigner, reg_max=16, use_batched_assignment=True
  38. )
  39. sequential_loss = PPYoloELoss(
  40. num_classes=80, use_varifocal_loss=True, use_static_assigner=use_static_assigner, reg_max=16, use_batched_assignment=False
  41. )
  42. model = models.get(Models.PP_YOLOE_S, num_classes=80)
  43. random_input = torch.randn(4, 3, 640, 480)
  44. output = model(random_input)
  45. # (N, 6) (batch_index, class_index, cx, cy, w, h)
  46. # Five objects in the first image, three objects in the second image, two objects in the third image, no objects in the fourth image
  47. targets = torch.tensor(
  48. [
  49. [0, 2, 40, 60, 100, 200],
  50. [0, 3, 100, 200, 100, 200],
  51. [0, 4, 200, 300, 100, 200],
  52. [0, 5, 300, 400, 100, 200],
  53. [0, 6, 400, 500, 100, 200],
  54. [1, 2, 40, 60, 100, 200],
  55. [1, 3, 100, 200, 100, 200],
  56. [1, 4, 200, 300, 100, 200],
  57. [2, 2, 40, 60, 100, 200],
  58. [2, 3, 100, 200, 100, 200],
  59. ]
  60. ).float()
  61. batched_loss_output = batched_loss(output, targets)
  62. sequential_loss_output = sequential_loss(output, targets)
  63. self.assertAlmostEqual(batched_loss_output[0].item(), sequential_loss_output[0].item(), places=4)
  64. self.assertAlmostEqual(batched_loss_output[1][0].item(), sequential_loss_output[1][0].item(), places=4)
  65. self.assertAlmostEqual(batched_loss_output[1][1].item(), sequential_loss_output[1][1].item(), places=4)
  66. self.assertAlmostEqual(batched_loss_output[1][2].item(), sequential_loss_output[1][2].item(), places=4)
  67. self.assertAlmostEqual(batched_loss_output[1][3].item(), sequential_loss_output[1][3].item(), places=4)
  68. if __name__ == "__main__":
  69. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...