Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

strictload_enum_test.py 5.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  1. import shutil
  2. import tempfile
  3. import unittest
  4. import os
  5. from super_gradients.training import SgModel
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from super_gradients.training.sg_model.sg_model import StrictLoad
  10. class Net(nn.Module):
  11. def __init__(self):
  12. super(Net, self).__init__()
  13. self.conv1 = nn.Conv2d(3, 6, 3)
  14. self.pool = nn.MaxPool2d(2, 2)
  15. self.conv2 = nn.Conv2d(6, 16, 3)
  16. self.fc1 = nn.Linear(16 * 3 * 3, 120)
  17. self.fc2 = nn.Linear(120, 84)
  18. self.fc3 = nn.Linear(84, 10)
  19. def forward(self, x):
  20. x = self.pool(F.relu(self.conv1(x)))
  21. x = self.pool(F.relu(self.conv2(x)))
  22. x = x.view(-1, 16 * 3 * 3)
  23. x = F.relu(self.fc1(x))
  24. x = F.relu(self.fc2(x))
  25. x = self.fc3(x)
  26. return x
  27. class StrictLoadEnumTest(unittest.TestCase):
  28. @classmethod
  29. def setUpClass(cls):
  30. cls.temp_working_file_dir = tempfile.TemporaryDirectory(prefix='strict_load_test').name
  31. if not os.path.isdir(cls.temp_working_file_dir):
  32. os.mkdir(cls.temp_working_file_dir)
  33. cls.experiment_name = 'load_checkpoint_test'
  34. cls.checkpoint_diff_keys_name = 'strict_load_test_diff_keys.pth'
  35. cls.checkpoint_diff_keys_path = cls.temp_working_file_dir + '/' + cls.checkpoint_diff_keys_name
  36. # Setup the model
  37. cls.original_torch_net = Net()
  38. # Save the model's state_dict checkpoint with different keys
  39. torch.save(cls.change_state_dict_keys(cls.original_torch_net.state_dict()), cls.checkpoint_diff_keys_path)
  40. # Save the model's state_dict checkpoint in SgModel format
  41. cls.sg_model = SgModel("load_checkpoint_test", model_checkpoints_location='local') # Saves in /checkpoints
  42. cls.sg_model.build_model(cls.original_torch_net, arch_params={'num_classes': 10})
  43. cls.sg_model.save_checkpoint()
  44. @classmethod
  45. def tearDownClass(cls):
  46. if os.path.isdir(cls.temp_working_file_dir):
  47. shutil.rmtree(cls.temp_working_file_dir)
  48. @classmethod
  49. def change_state_dict_keys(self, state_dict):
  50. new_ckpt_dict = {}
  51. for i, (ckpt_key, ckpt_val) in enumerate(state_dict.items()):
  52. new_ckpt_dict[i] = ckpt_val
  53. return new_ckpt_dict
  54. def check_models_have_same_weights(self, model_1, model_2):
  55. model_1, model_2 = model_1.to('cpu'), model_2.to('cpu')
  56. models_differ = 0
  57. for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
  58. if torch.equal(key_item_1[1], key_item_2[1]):
  59. pass
  60. else:
  61. models_differ += 1
  62. if (key_item_1[0] == key_item_2[0]):
  63. print('Mismtach found at', key_item_1[0])
  64. else:
  65. raise Exception
  66. if models_differ == 0:
  67. return True
  68. else:
  69. return False
  70. def test_strict_load_on(self):
  71. # Define Model
  72. new_torch_net = Net()
  73. # Make sure we initialized a model with different weights
  74. assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
  75. # Build the SgModel and load the checkpoint
  76. model = SgModel(self.experiment_name, model_checkpoints_location='local',
  77. ckpt_name='ckpt_latest_weights_only.pth')
  78. model.build_model(new_torch_net, arch_params={'num_classes': 10}, strict_load=StrictLoad.ON,
  79. load_checkpoint=True)
  80. # Assert the weights were loaded correctly
  81. assert self.check_models_have_same_weights(model.net, self.original_torch_net)
  82. def test_strict_load_off(self):
  83. # Define Model
  84. new_torch_net = Net()
  85. # Make sure we initialized a model with different weights
  86. assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
  87. # Build the SgModel and load the checkpoint
  88. model = SgModel(self.experiment_name, model_checkpoints_location='local',
  89. ckpt_name='ckpt_latest_weights_only.pth')
  90. model.build_model(new_torch_net, arch_params={'num_classes': 10}, strict_load=StrictLoad.OFF,
  91. load_checkpoint=True)
  92. # Assert the weights were loaded correctly
  93. assert self.check_models_have_same_weights(model.net, self.original_torch_net)
  94. def test_strict_load_no_key_matching_external_checkpoint(self):
  95. # Define Model
  96. new_torch_net = Net()
  97. # Make sure we initialized a model with different weights
  98. assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
  99. # Build the SgModel and load the checkpoint
  100. model = SgModel(self.experiment_name, model_checkpoints_location='local')
  101. model.build_model(new_torch_net, arch_params={'num_classes': 10}, strict_load=StrictLoad.NO_KEY_MATCHING,
  102. external_checkpoint_path=self.checkpoint_diff_keys_path, load_checkpoint=True)
  103. # Assert the weights were loaded correctly
  104. assert self.check_models_have_same_weights(model.net, self.original_torch_net)
  105. def test_strict_load_no_key_matching_sg_checkpoint(self):
  106. # Define Model
  107. new_torch_net = Net()
  108. # Make sure we initialized a model with different weights
  109. assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
  110. # Build the SgModel and load the checkpoint
  111. model = SgModel(self.experiment_name, model_checkpoints_location='local',
  112. ckpt_name='ckpt_latest_weights_only.pth')
  113. model.build_model(new_torch_net, arch_params={'num_classes': 10}, strict_load=StrictLoad.NO_KEY_MATCHING,
  114. load_checkpoint=True)
  115. # Assert the weights were loaded correctly
  116. assert self.check_models_have_same_weights(model.net, self.original_torch_net)
  117. if __name__ == '__main__':
  118. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...