Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

module_utils_test.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  1. from typing import List
  2. import unittest
  3. import torch.nn as nn
  4. from super_gradients.training.models.detection_models.yolov5 import YoLoV5X
  5. from super_gradients.training.utils.module_utils import replace_activations
  6. from super_gradients.training.utils.utils import HpmStruct
  7. class TestModuleUtils(unittest.TestCase):
  8. def test_activation_replacement(self):
  9. arch_params = HpmStruct()
  10. yolov5x = YoLoV5X(arch_params=arch_params)
  11. new_activation = nn.ReLU()
  12. activations_to_replace = [nn.SiLU]
  13. yolov5x_relu = YoLoV5X(arch_params=arch_params)
  14. replace_activations(yolov5x_relu, new_activation, activations_to_replace)
  15. self._assert_activations_replaced(yolov5x_relu, yolov5x, new_activation, activations_to_replace)
  16. def _assert_activations_replaced(self, new_module: nn.Module, orig_module: nn.Module,
  17. new_activation: nn.Module, replaced_activations: List[type]):
  18. """
  19. Assert:
  20. * that new_module doesn't contain any of activations of replaced types
  21. * that in places where original module has an activation of any of replaced_activations types
  22. new_module has a new activation
  23. * that new activations are unique objects and don't share new_activation's address
  24. Runs recursively on all submodules.
  25. :param new_module: A module with replaced activations
  26. :param orig_module: A module of the same architecture, but with activations of an original type
  27. :param new_activation: A new activation
  28. :param replaced_activations: A list of types of activations that should have been replaced;
  29. each should be a subclass of nn.Module
  30. """
  31. for new_m, orig_m in zip(new_module.children(), orig_module.children()):
  32. self.assertTrue(type(new_m) not in replaced_activations)
  33. if type(orig_m) in replaced_activations:
  34. self.assertTrue(type(new_m) == type(new_activation))
  35. self.assertTrue(id(new_m) != id(new_activation))
  36. self._assert_activations_replaced(new_m, orig_m, new_activation, replaced_activations)
  if "__main__" == __name__:
      # Run this module's test cases when the file is executed directly as a script.
      unittest.main()
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...