Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_supports_check_input_shape.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  1. import unittest
  2. import torch
  3. from super_gradients.common.object_names import Models
  4. from super_gradients.training import models
  5. class TestSupportsInputShapeCheck(unittest.TestCase):
  6. def setUp(self):
  7. self.models_to_check = [Models.YOLO_NAS_S, Models.YOLO_NAS_POSE_S, Models.PP_LITE_T_SEG50, Models.STDC1_SEG50, Models.DDRNET_23]
  8. @torch.no_grad()
  9. def test_can_run_inference_with_min_size(self):
  10. for model in self.models_to_check:
  11. with self.subTest(model=model):
  12. model = models.get(model, num_classes=20).eval()
  13. min_shape = model.get_minimum_input_shape_size()
  14. if min_shape is not None:
  15. dummy_input = torch.randn(1, 3, *min_shape)
  16. model.validate_input_shape(dummy_input.size())
  17. model(dummy_input)
  18. @torch.no_grad()
  19. def test_validate_invalid_size(self):
  20. for model in self.models_to_check:
  21. with self.subTest(model=model):
  22. model = models.get(model, num_classes=20).eval()
  23. steps = model.get_input_shape_steps()
  24. invalid_shape = [x * 4 + 1 for x in steps]
  25. dummy_input = torch.randn(1, 3, *invalid_shape)
  26. with self.assertRaises(ValueError):
  27. model.validate_input_shape(dummy_input.size())
  28. if __name__ == "__main__":
  29. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...