Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

checkpoint_utils.py 12 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
  1. import os
  2. import tempfile
  3. import pkg_resources
  4. import torch
  5. from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
  6. from super_gradients.training.pretrained_models import MODEL_URLS
  7. try:
  8. from torch.hub import download_url_to_file, load_state_dict_from_url
  9. except (ModuleNotFoundError, ImportError, NameError):
  10. from torch.hub import _download_url_to_file as download_url_to_file
  11. def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str, overwrite_local_checkpoint: bool, load_weights_only: bool):
  12. """
  13. Gets the local path to the checkpoint file, which will be:
  14. - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
  15. - if the checkpoint file is remotely located:
  16. when overwrite_local_checkpoint=True then it will be saved in a temporary path which will be returned,
  17. otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite
  18. YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such file exists.
  19. - external_checkpoint_path when external_checkpoint_path != None
  20. @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
  21. @param experiment_name: experiment name attr in sg_model
  22. @param ckpt_name: checkpoint filename
  23. @param model_checkpoints_location: S3, local ot URL
  24. @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
  25. @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.
  26. @param load_weights_only: whether to load the network's state dict only.
  27. @return:
  28. """
  29. source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
  30. if model_checkpoints_location == 'local':
  31. default_local_ckpt_path = pkg_resources.resource_filename('checkpoints',
  32. source_ckpt_folder_name + os.path.sep + ckpt_name)
  33. ckpt_local_path = external_checkpoint_path or default_local_ckpt_path
  34. # COPY THE DATA FROM 'S3'/'URL' INTO A LOCAL DIRECTORY
  35. elif model_checkpoints_location.startswith('s3') or model_checkpoints_location == 'url':
  36. # COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME
  37. ckpt_local_path = copy_ckpt_to_local_folder(local_ckpt_destination_dir=experiment_name,
  38. ckpt_filename=ckpt_name,
  39. remote_ckpt_source_dir=source_ckpt_folder_name,
  40. path_src=model_checkpoints_location,
  41. overwrite_local_ckpt=overwrite_local_checkpoint,
  42. load_weights_only=load_weights_only)
  43. else:
  44. # ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION
  45. raise NotImplementedError(
  46. 'model_checkpoints_data_source: ' + str(model_checkpoints_location) + 'not supported')
  47. return ckpt_local_path
  48. def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str):
  49. """
  50. Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
  51. @param net: (nn.Module) to load state_dict to
  52. @param state_dict: (dict) Chekpoint state_dict
  53. @param strict: (str) key matching strictness
  54. @return:
  55. """
  56. try:
  57. net.load_state_dict(state_dict['net'], strict=strict)
  58. except (RuntimeError, ValueError, KeyError) as ex:
  59. if strict == 'no_key_matching':
  60. adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
  61. net.load_state_dict(adapted_state_dict['net'], strict=True)
  62. else:
  63. raise_informative_runtime_error(net.state_dict(), state_dict, ex)
  64. @explicit_params_validation(validation_type='None')
  65. def copy_ckpt_to_local_folder(local_ckpt_destination_dir: str, ckpt_filename: str, remote_ckpt_source_dir: str = None,
  66. path_src: str = 'local', overwrite_local_ckpt: bool = False,
  67. load_weights_only: bool = False):
  68. """
  69. Copy the checkpoint from any supported source to a local destination path
  70. :param local_ckpt_destination_dir: destination where the checkpoint will be saved to
  71. :param ckpt_filename: ckpt_best.pth Or ckpt_latest.pth
  72. :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 Model\full URL)
  73. :param path_src: S3 / url
  74. :param overwrite_local_ckpt: determines if checkpoint will be saved in destination dir or in a temp folder
  75. :return: Path to checkpoint
  76. """
  77. ckpt_file_full_local_path = None
  78. # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
  79. remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir
  80. if not overwrite_local_ckpt:
  81. # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
  82. download_ckpt_destination_dir = tempfile.gettempdir()
  83. print('PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False '
  84. '-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART')
  85. else:
  86. # SAVE THE CHECKPOINT TO MODEL's FOLDER
  87. download_ckpt_destination_dir = pkg_resources.resource_filename('checkpoints', local_ckpt_destination_dir)
  88. if path_src.startswith('s3'):
  89. model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
  90. # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
  91. ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
  92. ckpt_source_remote_dir=remote_ckpt_source_dir,
  93. ckpt_destination_local_dir=download_ckpt_destination_dir,
  94. ckpt_file_name=ckpt_filename,
  95. overwrite_local_checkpoints_file=overwrite_local_ckpt)
  96. if not load_weights_only:
  97. # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
  98. model_checkpoints_data_interface.load_all_remote_log_files(model_name=remote_ckpt_source_dir,
  99. model_checkpoint_local_dir=download_ckpt_destination_dir)
  100. if path_src == 'url':
  101. ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
  102. # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
  103. download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)
  104. return ckpt_file_full_local_path
  105. def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
  106. if not os.path.exists(ckpt_path):
  107. raise ValueError('Incorrect Checkpoint path')
  108. if device == "cuda":
  109. state_dict = torch.load(ckpt_path)
  110. else:
  111. state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
  112. return state_dict
  113. def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = []):
  114. """
  115. Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
  116. the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
  117. :param model_state_dict: the model state_dict
  118. :param source_ckpt: checkpoint dict
  119. :exclude optional list for excluded layers
  120. :return: renamed checkpoint dict (if possible)
  121. """
  122. if 'net' in source_ckpt.keys():
  123. source_ckpt = source_ckpt['net']
  124. model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
  125. new_ckpt_dict = {}
  126. for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
  127. if ckpt_val.shape != model_val.shape:
  128. raise ValueError(f'ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}'
  129. f' with shape {model_val.shape} in the model')
  130. new_ckpt_dict[model_key] = ckpt_val
  131. return {'net': new_ckpt_dict}
  132. def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
  133. """
  134. Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
  135. and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
  136. """
  137. try:
  138. new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
  139. temp_file = tempfile.NamedTemporaryFile().name + '.pt'
  140. torch.save(new_ckpt_dict, temp_file)
  141. exception_msg = f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_" \
  142. f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
  143. except ValueError as ex: # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
  144. exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
  145. finally:
  146. raise RuntimeError(exception_msg)
  147. def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str, load_weights_only: bool, load_ema_as_net: bool = False):
  148. """
  149. Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
  150. @param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set
  151. @param ckpt_local_path: local path to the checkpoint file
  152. @param load_backbone: whether to load the checkpoint as a backbone
  153. @param net: network to load the checkpoint to
  154. @param strict:
  155. @param load_weights_only:
  156. @return:
  157. """
  158. if ckpt_local_path is None or not os.path.exists(ckpt_local_path):
  159. error_msg = 'Error - loading Model Checkpoint: Path {} does not exist'.format(ckpt_local_path)
  160. raise RuntimeError(error_msg)
  161. if load_backbone and not hasattr(net.module, 'backbone'):
  162. raise ValueError("No backbone attribute in net - Can't load backbone weights")
  163. # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
  164. checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)
  165. if load_ema_as_net:
  166. if 'ema_net' not in checkpoint.keys():
  167. raise ValueError("Can't load ema network- no EMA network stored in checkpoint file")
  168. else:
  169. checkpoint['net'] = checkpoint['ema_net']
  170. # LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
  171. if load_backbone:
  172. adaptive_load_state_dict(net.module.backbone, checkpoint, strict)
  173. else:
  174. adaptive_load_state_dict(net, checkpoint, strict)
  175. if load_weights_only or load_backbone:
  176. # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
  177. [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != 'net']
  178. return checkpoint
  179. class MissingPretrainedWeightsException(Exception):
  180. """Exception raised by unsupported pretrianed model.
  181. Attributes:
  182. message -- explanation of the error
  183. """
  184. def __init__(self, desc):
  185. self.message = "Missing pretrained wights: " + desc
  186. super().__init__(self.message)
  187. def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
  188. """
  189. Loads pretrained weights from the MODEL_URLS dictionary to model
  190. @param architecture: name of the model's architecture
  191. @param model: model to load pretrinaed weights for
  192. @param pretrained_weights: name for the pretrianed weights (i.e imagenet)
  193. @return: None
  194. """
  195. model_url_key = architecture + '_' + str(pretrained_weights)
  196. if model_url_key not in MODEL_URLS.keys():
  197. raise MissingPretrainedWeightsException(model_url_key)
  198. url = MODEL_URLS[model_url_key]
  199. map_location = torch.device('cpu') if not torch.cuda.is_available() else None
  200. pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location)
  201. adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(), source_ckpt=pretrained_state_dict)
  202. model.load_state_dict(adapted_pretrained_state_dict['net'])
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...