Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

repvgg_unit_test.py 5.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  1. import unittest
  2. from super_gradients.common.registry.registry import ARCHITECTURES
  3. from super_gradients.training.models.classification_models.repvgg import RepVggA1
  4. from super_gradients.training.utils.utils import HpmStruct
  5. import torch
  6. import copy
  7. import numpy as np
  class BackboneBasedModel(torch.nn.Module):
      """
      Auxiliary model which will use repvgg as backbone.

      The backbone output is passed through a 1x1 convolution, a batch norm, a global
      average pool and a linear classification head.

      :param backbone: Module whose output is a feature map with ``backbone_output_channel`` channels.
      :param backbone_output_channel: Number of channels produced by ``backbone``.
      :param num_classes: Number of outputs of the linear head.
      """

      def __init__(self, backbone, backbone_output_channel, num_classes=1000):
          super().__init__()
          self.backbone = backbone
          self.conv = torch.nn.Conv2d(
              in_channels=backbone_output_channel,
              out_channels=backbone_output_channel,
              kernel_size=1,
              stride=1,
              padding=0,
          )
          # Adding a bn layer that should NOT be fused
          self.bn = torch.nn.BatchNorm2d(backbone_output_channel)
          self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
          self.linear = torch.nn.Linear(backbone_output_channel, num_classes)

      def forward(self, x):
          """Run the backbone and the head; returns a tensor of shape (batch, num_classes)."""
          x = self.backbone(x)
          x = self.conv(x)
          x = self.bn(x)
          x = self.avgpool(x)
          # Collapse everything except the batch dimension; unlike ``view`` this
          # does not require the tensor to be contiguous.
          x = torch.flatten(x, 1)
          return self.linear(x)

      def prep_model_for_conversion(self):
          """Delegate conversion to the backbone when it exposes ``prep_model_for_conversion``."""
          if hasattr(self.backbone, "prep_model_for_conversion"):
              self.backbone.prep_model_for_conversion()
  class TestRepVgg(unittest.TestCase):
      """Checks that RepVGG models behave the same before and after ``prep_model_for_conversion``."""

      def setUp(self):
          # contains all arch_params needed for initialization of all architectures
          self.all_arch_params = HpmStruct(**{"num_classes": 10, "width_mult": 1, "build_residual_branches": True})
          self.backbone_arch_params = copy.deepcopy(self.all_arch_params)
          self.backbone_arch_params.override(backbone_mode=True)

      def test_deployment_architecture(self):
          """
          Validate that all models that have a deployment mode are in fact different after deployment,
          while still producing (almost) the same output.
          """
          image_size = 224
          in_channels = 3
          for arch_name in ARCHITECTURES:
              # skip custom constructors to keep all_arch_params as general as possible
              if "repvgg" not in arch_name or "custom" in arch_name:
                  continue
              with self.subTest(arch_name=arch_name):
                  # Set the seed to 0 to ensure that the model is initialized with the same weights
                  torch.manual_seed(0)
                  model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)

                  # check single layer for training mode
                  self.assertTrue(hasattr(model.stem, "branch_3x3"))
                  self.assertTrue(model.build_residual_branches)

                  # a deployment block must not be present in training mode
                  for key in model.state_dict():
                      self.assertNotIn("reparam", key)

                  # Initializing input with 0.1 instead of 1.0 to move mean of input closer to 0
                  test_input = torch.ones((1, in_channels, image_size, image_size)) * 0.1
                  model.eval()
                  training_mode_output = model(test_input)

                  model.prep_model_for_conversion()

                  # check single layer for deployment mode
                  self.assertTrue(hasattr(model.stem, "rbr_reparam"))
                  self.assertFalse(model.build_residual_branches)

                  for key in model.state_dict():
                      # a remaining running_mean would mean a BN was not fused
                      self.assertNotIn("running_mean", key)
                      # a remaining branch would mean the branches were not joined
                      self.assertNotIn("branch", key)

                  deployment_mode_output = model(test_input)

                  # difference is of very low magnitude
                  np.testing.assert_array_almost_equal(
                      training_mode_output.detach().numpy(),
                      deployment_mode_output.detach().numpy(),
                      decimal=4,
                  )

      def test_backbone_mode(self):
          """
          Validate repvgg models (A1) as backbone.
          """
          image_size = 224
          in_channels = 3
          # Set the seed to 0 to ensure that the model is initialized with the same weights
          torch.manual_seed(0)
          test_input = torch.rand((1, in_channels, image_size, image_size))

          backbone_model = RepVggA1(self.backbone_arch_params)
          model = BackboneBasedModel(backbone_model, backbone_output_channel=1280, num_classes=self.backbone_arch_params.num_classes)

          backbone_model.eval()
          model.eval()

          backbone_training_mode_output = backbone_model(test_input)
          model_training_mode_output = model(test_input)

          # check that the linear head was dropped
          self.assertNotEqual(backbone_training_mode_output.shape[1], self.backbone_arch_params.num_classes)

          # a deployment block must not be present in training mode
          for key in model.state_dict():
              self.assertNotIn("reparam", key)

          model.prep_model_for_conversion()
          deployment_mode_sd_list = list(model.state_dict().keys())

          # Verify non backbone batch norm wasn't fused
          self.assertIn("bn.running_mean", deployment_mode_sd_list)

          for key in deployment_mode_sd_list:
              if not key.startswith("backbone"):
                  continue
              # a remaining running_mean in the backbone would mean a BN was not fused
              self.assertNotIn("running_mean", key)
              # a remaining branch in the backbone would mean the branches were not joined
              self.assertNotIn("branch", key)

          model_deployment_mode_output = model(test_input)
          self.assertTrue(
              torch.allclose(model_deployment_mode_output, model_training_mode_output, atol=1e-5),
              "deployment mode output differs from training mode output",
          )
  # Run the whole test suite when this file is executed directly as a script.
  if __name__ == "__main__":
      unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...