#609 Ci fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
import os
import tempfile
import pkg_resources
import torch

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.common.environment import environment_config

try:
    from torch.hub import download_url_to_file, load_state_dict_from_url
except (ModuleNotFoundError, ImportError, NameError):
    from torch.hub import _download_url_to_file as download_url_to_file

logger = get_logger(__name__)

def get_checkpoints_dir_path(experiment_name: str, ckpt_root_dir: str = None):
    """Creates the checkpoint directory of a given experiment.

    :param experiment_name: Name of the experiment.
    :param ckpt_root_dir: Local root directory path where all experiment logging directories will
                          reside. When none is given, it is assumed that pkg_resources.resource_filename('checkpoints', "")
                          exists and will be used.
    :return: checkpoints_dir_path
    """
    if ckpt_root_dir:
        return os.path.join(ckpt_root_dir, experiment_name)
    elif os.path.exists(environment_config.PKG_CHECKPOINTS_DIR):
        return os.path.join(environment_config.PKG_CHECKPOINTS_DIR, experiment_name)
    else:
        raise ValueError("Illegal checkpoints directory: pass ckpt_root_dir that exists, or add 'checkpoints' to resources.")

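# A minimal usage sketch (the root path below is hypothetical):
#   ckpt_dir = get_checkpoints_dir_path("my_experiment", ckpt_root_dir="/data/checkpoints")
#   # -> "/data/checkpoints/my_experiment"
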
def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
    """
    Gets the local path to the checkpoint file, which will be:
        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
        - If the checkpoint file is remotely located:
            when overwrite_local_checkpoint=True it will be saved in a temporary path which will be returned,
            otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite
            YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such a file exists.
        - external_checkpoint_path when external_checkpoint_path != None.

    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None, uses the experiment_name.
    @param experiment_name: Experiment name attribute in trainer.
    @param ckpt_name: Checkpoint filename.
    @param external_checkpoint_path: Full path to a checkpoint file (that might be located outside of the super_gradients/checkpoints directory).
    @return: Local path to the checkpoint file.
    """
    if external_checkpoint_path:
        return external_checkpoint_path
    else:
        checkpoints_folder_name = source_ckpt_folder_name or experiment_name
        checkpoints_dir_path = get_checkpoints_dir_path(checkpoints_folder_name)
        return os.path.join(checkpoints_dir_path, ckpt_name)

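# Usage sketch, assuming a checkpoints directory already exists (names are illustrative):
#   local_path = get_ckpt_local_path(
#       source_ckpt_folder_name=None,
#       experiment_name="my_experiment",
#       ckpt_name="ckpt_best.pth",
#       external_checkpoint_path=None,
#   )
#   # -> "<checkpoints dir>/my_experiment/ckpt_best.pth"
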
def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str):
    """
    Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.

    @param net: (nn.Module) Network to load state_dict to.
    @param state_dict: (dict) Checkpoint state_dict.
    @param strict: (str) Key-matching strictness.
    @return: None
    """
    try:
        net.load_state_dict(state_dict["net"] if "net" in state_dict.keys() else state_dict, strict=strict)
    except (RuntimeError, ValueError, KeyError) as ex:
        if strict == "no_key_matching":
            adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
            net.load_state_dict(adapted_state_dict["net"], strict=True)
        else:
            raise_informative_runtime_error(net.state_dict(), state_dict, ex)

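# Usage sketch (assumes `net` is any torch.nn.Module; the checkpoint path is hypothetical):
#   ckpt = read_ckpt_state_dict("/path/to/ckpt_best.pth")
#   adaptive_load_state_dict(net, ckpt, strict="no_key_matching")
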
@explicit_params_validation(validation_type="None")
def copy_ckpt_to_local_folder(
    local_ckpt_destination_dir: str,
    ckpt_filename: str,
    remote_ckpt_source_dir: str = None,
    path_src: str = "local",
    overwrite_local_ckpt: bool = False,
    load_weights_only: bool = False,
):
    """
    Copy the checkpoint from any supported source to a local destination path.

    :param local_ckpt_destination_dir: Destination where the checkpoint will be saved to.
    :param ckpt_filename: ckpt_best.pth or ckpt_latest.pth.
    :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model / full URL).
    :param path_src: S3 / url.
    :param overwrite_local_ckpt: Determines whether the checkpoint will be saved in the destination dir or in a temp folder.
    :param load_weights_only: When True, skips copying the remote log files alongside the checkpoint.
    :return: Path to the checkpoint.
    """
    ckpt_file_full_local_path = None

    # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
    remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir

    if not overwrite_local_ckpt:
        # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
        download_ckpt_destination_dir = tempfile.gettempdir()
        print(
            "PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False "
            "-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART"
        )
    else:
        # SAVE THE CHECKPOINT TO THE MODEL's FOLDER
        download_ckpt_destination_dir = pkg_resources.resource_filename("checkpoints", local_ckpt_destination_dir)

    if path_src.startswith("s3"):
        model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
        # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
        ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
            ckpt_source_remote_dir=remote_ckpt_source_dir,
            ckpt_destination_local_dir=download_ckpt_destination_dir,
            ckpt_file_name=ckpt_filename,
            overwrite_local_checkpoints_file=overwrite_local_ckpt,
        )

        if not load_weights_only:
            # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODEL's CKPT
            model_checkpoints_data_interface.load_all_remote_log_files(
                model_name=remote_ckpt_source_dir, model_checkpoint_local_dir=download_ckpt_destination_dir
            )

    if path_src == "url":
        ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
        # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
        download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)

    return ckpt_file_full_local_path

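# Usage sketch for a URL source (the URL is illustrative). With the default
# overwrite_local_ckpt=False, the returned path points at a file in a temp folder:
#   local_ckpt = copy_ckpt_to_local_folder(
#       local_ckpt_destination_dir="my_experiment",
#       ckpt_filename="ckpt_best.pth",
#       remote_ckpt_source_dir="https://example.com/ckpt_best.pth",
#       path_src="url",
#   )
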
def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Incorrect Checkpoint path: {ckpt_path} (This should be an absolute path)")

    if device == "cuda":
        state_dict = torch.load(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    return state_dict

def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable = None):
    """
    Given a model state dict and a source checkpoint, tries to correct the keys in the model_state_dict to fit
    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None.

    :param model_state_dict: The model state_dict.
    :param source_ckpt: Checkpoint dict.
    :param exclude: Optional list of excluded layers.
    :param solver: Callable with signature (ckpt_key, ckpt_val, model_key, model_val)
                   that returns a desired weight for ckpt_val.
    :return: Renamed checkpoint dict (if possible).
    """
    if "net" in source_ckpt.keys():
        source_ckpt = source_ckpt["net"]
    model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
    new_ckpt_dict = {}
    for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
        if solver is not None:
            ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
        if ckpt_val.shape != model_val.shape:
            raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
        new_ckpt_dict[model_key] = ckpt_val
    return {"net": new_ckpt_dict}

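# Usage sketch: rename a checkpoint's keys positionally to match the model's layer names.
# Both state dicts are assumed to hold same-shaped tensors in the same order:
#   renamed_ckpt = adapt_state_dict_to_fit_model_layer_names(model.state_dict(), ckpt)
#   model.load_state_dict(renamed_ckpt["net"])
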
def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
    """
    Given a model state dict and a source checkpoint, calls "adapt_state_dict_to_fit_model_layer_names"
    and enhances the exception_msg if loading the checkpoint dict via the conversion method is possible.
    """
    try:
        new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
        temp_file = tempfile.NamedTemporaryFile().name + ".pt"
        torch.save(new_ckpt_dict, temp_file)
        exception_msg = (
            f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_"
            f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
        )
    except ValueError as ex:  # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
        exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do not fit, e.g.: {ex}\n{'=' * 200}"
    finally:
        raise RuntimeError(exception_msg)

def load_checkpoint_to_model(
    ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str, load_weights_only: bool, load_ema_as_net: bool = False
):
    """
    Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.

    @param load_ema_as_net: Whether to load the EMA weights inside the checkpoint file to the network.
    @param ckpt_local_path: Local path to the checkpoint file.
    @param load_backbone: Whether to load the checkpoint as a backbone.
    @param net: Network to load the checkpoint to.
    @param strict: Key-matching strictness (see adaptive_load_state_dict).
    @param load_weights_only: Whether to discard all checkpoint entries other than the weights.
    @return: The checkpoint's state dict.
    """
    if ckpt_local_path is None or not os.path.exists(ckpt_local_path):
        error_msg = "Error - loading Model Checkpoint: Path {} does not exist".format(ckpt_local_path)
        raise RuntimeError(error_msg)

    if load_backbone and not hasattr(net, "backbone"):
        raise ValueError("No backbone attribute in net - Can't load backbone weights")

    # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
    checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)

    if load_ema_as_net:
        if "ema_net" not in checkpoint.keys():
            raise ValueError("Can't load ema network - no EMA network stored in checkpoint file")
        else:
            checkpoint["net"] = checkpoint["ema_net"]

    # LOAD THE CHECKPOINT's WEIGHTS TO THE MODEL
    if load_backbone:
        adaptive_load_state_dict(net.backbone, checkpoint, strict)
    else:
        adaptive_load_state_dict(net, checkpoint, strict)

    message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
    message_model = "model" if not load_backbone else "model's backbone"
    logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

    if load_weights_only or load_backbone:
        # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
        [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != "net"]

    return checkpoint

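# Usage sketch (hypothetical path): load only the weights, dropping optimizer/EMA/epoch data:
#   checkpoint = load_checkpoint_to_model(
#       ckpt_local_path="/path/to/ckpt_best.pth",
#       load_backbone=False,
#       net=net,
#       strict="no_key_matching",
#       load_weights_only=True,
#   )
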
class MissingPretrainedWeightsException(Exception):
    """Exception raised for an unsupported pretrained model.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, desc):
        self.message = "Missing pretrained weights: " + desc
        super().__init__(self.message)

def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
    """
    Helper method for reshaping an old pretrained checkpoint's focus weights to 6x6 conv weights.
    """
    if (
        ckpt_val.shape != model_val.shape
        and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
        and model_key == "_backbone._modules_list.0.conv.weight"
    ):
        model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
        model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
        model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
        model_val.data[:, :, 1::2, 1::2] = ckpt_val.data[:, 9:12]
        replacement = model_val
    else:
        replacement = ckpt_val

    return replacement

def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from the MODEL_URLS dictionary to model.

    @param architecture: Name of the model's architecture.
    @param model: Model to load pretrained weights for.
    @param pretrained_weights: Name of the pretrained weights (i.e. imagenet).
    @return: None
    """
    model_url_key = architecture + "_" + str(pretrained_weights)
    if model_url_key not in MODEL_URLS.keys():
        raise MissingPretrainedWeightsException(model_url_key)

    url = MODEL_URLS[model_url_key]
    unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace("/", "_").replace(" ", "_")
    map_location = torch.device("cpu")
    pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

    _load_weights(architecture, model, pretrained_state_dict)

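# Usage sketch; "resnet50"/"imagenet" are illustrative - valid combinations are the keys of MODEL_URLS:
#   load_pretrained_weights(model, architecture="resnet50", pretrained_weights="imagenet")
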
def _load_weights(architecture, model, pretrained_state_dict):
    if "ema_net" in pretrained_state_dict.keys():
        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
    solver = _yolox_ckpt_solver if "yolox" in architecture else None
    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(
        model_state_dict=model.state_dict(), source_ckpt=pretrained_state_dict, solver=solver
    )
    model.load_state_dict(adapted_pretrained_state_dict["net"], strict=False)

def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from a local path to model.

    @param architecture: Name of the model's architecture.
    @param model: Model to load pretrained weights for.
    @param pretrained_weights: Path to the pretrained weights.
    @return: None
    """
    map_location = torch.device("cpu")
    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
    _load_weights(architecture, model, pretrained_state_dict)

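# Usage sketch (the local weights path is hypothetical):
#   load_pretrained_weights_local(model, architecture="resnet50", pretrained_weights="/path/to/weights.pth")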