Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

pre_launch_callbacks.py 21 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
  1. import copy
  2. from copy import deepcopy
  3. from typing import Union
  4. from omegaconf import DictConfig
  5. import torch
  6. from super_gradients.common.environment.cfg_utils import load_recipe
  7. from super_gradients.common.registry.registry import register_pre_launch_callback
  8. from super_gradients import is_distributed
  9. from super_gradients.common.abstractions.abstract_logger import get_logger
  10. from super_gradients.training import models
  11. from torch.distributed import barrier
  12. import cv2
  13. import numpy as np
  14. from super_gradients.training.utils import get_param
  15. logger = get_logger(__name__)
  class PreLaunchCallback:
      """
      PreLaunchCallback

      Common parent of all pre-launch callbacks: callables that receive the recipe config (cfg) and hand back a
      (possibly modified) config before training is launched through Trainer.train_from_config(cfg).

      Subclasses are expected to override __call__; the base implementation only raises.
      """

      def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
          # Abstract hook: concrete callbacks return the config to launch training with.
          raise NotImplementedError
  @register_pre_launch_callback()
  class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
      """
      AutoTrainBatchSizeSelectionCallback

      Modifies cfg.dataset_params.train_dataloader_params.batch_size by searching for the maximal batch size that fits
      gpu memory / the one resulting in the fastest time for the selected number of train dataloader iterations.
      Works out of the box for DDP.

      The search is done by running a few forward passes for increasing batch sizes, until CUDA OUT OF MEMORY is raised:

          For batch_size in range(min_batch_size:max_batch_size:size_step):
              if batch_size raises CUDA OUT OF MEMORY ERROR:
                  return batch_size-size_step
          return batch_size

      Example usage: Inside the main recipe .YAML file (for example super_gradients/recipes/cifar10_resnet.yaml),
      add the following:

          pre_launch_callbacks_list:
              - AutoTrainBatchSizeSelectionCallback:
                  min_batch_size: 128
                  size_step: 64
                  num_forward_passes: 10

      Then, when running super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=...
      this pre_launch_callback will modify cfg.dataset_params.train_dataloader_params.batch_size then pass cfg to
      Trainer.train_from_config(cfg) and training will continue with the selected batch size.

      :param min_batch_size: int, the first batch size to try running forward passes. Should fit memory.
      :param size_step: int, the difference between 2 consecutive batch_size trials.
      :param num_forward_passes: int, number of forward passes (i.e train_loader data iterations inside an epoch).
          Note that the more forward passes being done, the less the selected batch size is prone to fail. This is
          because other than gradients, model computations, data and other fixed gpu memory that is being used - some
          more gpu memory might be used by the metric objects and PhaseCallbacks.
      :param max_batch_size: int, optional, upper limit of the batch sizes to try. When None, the search will continue
          until the maximal batch size that does not raise CUDA OUT OF MEMORY is found (default=None).
      :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
          FOUND_BATCH_SIZE/cfg.dataset_params.train_dataloader_params.batch_size (default=True)
      :param mode: str, one of ["fastest","largest"], whether to select the largest batch size that fits memory or the
          one that resulted in the overall fastest execution.
      :raises TypeError: when mode is not one of the supported values.
      """

      def __init__(
          self,
          min_batch_size: int,
          size_step: int,
          num_forward_passes: int = 3,
          max_batch_size: Union[int, None] = None,
          scale_lr: bool = True,
          mode: str = "fastest",
      ):
          # TypeError (rather than ValueError) is kept so existing callers catching it keep working.
          if mode not in ("fastest", "largest"):
              raise TypeError(f"Expected mode to be one of: ['fastest','largest'], got {mode}")

          self.scale_lr = scale_lr
          self.min_batch_size = min_batch_size
          self.size_step = size_step
          self.max_batch_size = max_batch_size
          self.num_forward_passes = num_forward_passes
          self.mode = mode

      def __call__(self, cfg: DictConfig) -> DictConfig:
          """
          Run short trial trainings with growing batch sizes and write the selected batch size into cfg.

          :param cfg: recipe config; cfg.dataset_params.train_dataloader_params.batch_size (and, when scale_lr is set,
              cfg.training_hyperparams.initial_lr) are modified in place.
          :return: the same cfg object, after modification.
          :raises RuntimeError: re-raised when the smallest batch size runs out of memory, or for any RuntimeError
              whose message does not contain "out of memory".
          """
          # IMPORT IS HERE DUE TO CIRCULAR IMPORT PROBLEM
          from super_gradients.training.sg_trainer import Trainer

          curr_batch_size = self.min_batch_size

          # BUILD NETWORK
          model = models.get(
              model_name=cfg.architecture,
              num_classes=cfg.arch_params.num_classes,
              arch_params=cfg.arch_params,
              strict_load=cfg.checkpoint_params.strict_load,
              pretrained_weights=cfg.checkpoint_params.pretrained_weights,
              checkpoint_path=cfg.checkpoint_params.checkpoint_path,
              load_backbone=cfg.checkpoint_params.load_backbone,
          )

          # Trial runs operate on a copy so the overrides below never leak into the caller's cfg.
          tmp_cfg = deepcopy(cfg)
          tmp_cfg.training_hyperparams.batch_accumulate = 1
          tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes
          tmp_cfg.training_hyperparams.run_validation_freq = 2
          tmp_cfg.training_hyperparams.silent_mode = True
          tmp_cfg.training_hyperparams.save_model = False
          tmp_cfg.training_hyperparams.max_epochs = 1
          tmp_cfg.training_hyperparams.average_best_models = False
          tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
          # Emptied so the trial run does not trigger this callback again.
          tmp_cfg.pre_launch_callbacks_list = []

          fastest_batch_time = np.inf
          fastest_batch_size = curr_batch_size

          bs_found = False

          while not bs_found:
              tmp_cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size

              try:
                  passes_start = cv2.getTickCount()
                  Trainer.train_from_config(tmp_cfg)
                  curr_batch_time = (cv2.getTickCount() - passes_start) / cv2.getTickFrequency()
                  logger.info(f"Batch size = {curr_batch_size} time for {self.num_forward_passes} forward passes: {curr_batch_time} seconds.")
                  if curr_batch_time < fastest_batch_time:
                      fastest_batch_size = curr_batch_size
                      fastest_batch_time = curr_batch_time

              except RuntimeError as e:
                  # Only out-of-memory failures end the search; anything else is a genuine error.
                  if "out of memory" not in str(e):
                      raise
                  if curr_batch_size == self.min_batch_size:
                      logger.error("Ran out of memory for the smallest batch, try setting smaller min_batch_size.")
                      raise
                  # The previous trial (curr_batch_size - size_step) was the last one that fit memory.
                  selected_batch_size = curr_batch_size - self.size_step if self.mode == "largest" else fastest_batch_size
                  msg = f"Ran out of memory for {curr_batch_size}, setting batch size to {selected_batch_size}."
                  bs_found = True

              else:
                  if self.max_batch_size is not None and curr_batch_size >= self.max_batch_size:
                      selected_batch_size = self.max_batch_size if self.mode == "largest" else fastest_batch_size
                      msg = (
                          f"Did not run out of memory for {curr_batch_size} >= max_batch_size={self.max_batch_size}, " f"setting batch to {selected_batch_size}."
                      )
                      bs_found = True
                  else:
                      logger.info(f"Did not run out of memory for {curr_batch_size}, retrying batch {curr_batch_size + self.size_step}.")
                      curr_batch_size += self.size_step
                      self._clear_model_gpu_mem(model)

          return self._inject_selected_batch_size_to_config(cfg, model, msg, selected_batch_size)

      def _inject_selected_batch_size_to_config(self, cfg, model, msg, selected_batch_size):
          """Log msg, optionally rescale the learning rate, store the selected batch size in cfg and free gpu memory."""
          logger.info(msg)
          # The lr must be adapted first: the scale factor is computed against the batch size still stored in cfg.
          self._adapt_lr_if_needed(cfg, found_batch_size=selected_batch_size)
          cfg.dataset_params.train_dataloader_params.batch_size = selected_batch_size
          self._clear_model_gpu_mem(model)
          return cfg

      def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
          """
          Linearly scale cfg.training_hyperparams.initial_lr by found_batch_size / <batch size currently in cfg>.

          No-op when self.scale_lr is False. Returns the same cfg object.
          """
          if self.scale_lr:
              scale_factor = found_batch_size / cfg.dataset_params.train_dataloader_params.batch_size
              cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor
          return cfg

      @classmethod
      def _clear_model_gpu_mem(cls, model):
          """Drop the gradients held by model's parameters and empty the CUDA cache."""
          for p in model.parameters():
              if p.grad is not None:
                  del p.grad  # free some memory
          torch.cuda.empty_cache()

          # WAIT FOR ALL PROCESSES TO CLEAR THEIR MEMORY BEFORE MOVING ON
          if is_distributed():
              barrier()
  def modify_params_for_qat(
      training_hyperparams,
      train_dataset_params,
      val_dataset_params,
      train_dataloader_params,
      val_dataloader_params,
      quantization_params=None,
      batch_size_divisor: int = 2,
      max_epochs_divisor: int = 10,
      lr_decay_factor: float = 0.01,
      warmup_epochs_divisor: int = 10,
      cosine_final_lr_ratio: float = 0.01,
      disable_phase_callbacks: bool = True,
      disable_augmentations: bool = False,
  ):
      """
      This method modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe.
      It does so by manipulating the training_hyperparams, train_dataloader_params, val_dataloader_params,
      train_dataset_params, val_dataset_params. All inputs are deep-copied; the caller's objects are left untouched.

      Usage:

          trainer = Trainer("test_launch_qat_with_minimal_changes")
          net = ResNet18(num_classes=10, arch_params={})
          train_params = {...}

          train_dataset_params = {
              "transforms": [...
              ]
          }

          train_dataloader_params = {"batch_size": 256}
          val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]}
          val_dataloader_params = {"batch_size": 256}

          train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
          valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

          trainer.train(
              model=net,
              training_params=train_params,
              train_loader=train_loader,
              valid_loader=valid_loader,
          )

          train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat(
              train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
          )

          train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
          valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

          trainer.qat(
              model=net,
              training_params=train_params,
              train_loader=train_loader,
              valid_loader=valid_loader,
              calib_loader=train_loader,
          )

      :param val_dataset_params: Dict, validation dataset_params to be passed to dataloaders.get(...) when instantiating the validation dataloader.
      :param train_dataset_params: Dict, train dataset_params to be passed to dataloaders.get(...) when instantiating the train dataloader.
      :param val_dataloader_params: Dict, validation dataloader_params to be passed to dataloaders.get(...) when instantiating the validation dataloader.
      :param train_dataloader_params: Dict, train dataloader_params to be passed to dataloaders.get(...) when instantiating the train dataloader.
      :param training_hyperparams: Dict, train parameters passed to Trainer.qat(...)
      :param quantization_params: Dict, quantization parameters as passed to Trainer.qat(...). When None, will use the
          default parameters in super_gradients/recipes/quantization_params/default_quantization_params.yaml
      :param int batch_size_divisor: Divisor used to calculate the batch size. Default value is 2.
      :param int max_epochs_divisor: Divisor used to calculate the maximum number of epochs. Default value is 10.
      :param float lr_decay_factor: Factor used to decay the learning rate, weight decay and warmup. Default value is 0.01.
      :param int warmup_epochs_divisor: Divisor used to calculate the number of warm-up epochs. Default value is 10.
      :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01.
      :param bool disable_phase_callbacks: Flag to control to disable phase callbacks, which can interfere with QAT. Default value is True.
      :param bool disable_augmentations: Flag to control to disable phase augmentations, which can interfere with QAT. Default value is False.
      :return: modified (copy) training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
      :raises ValueError: when max_epochs, initial_lr, optimizer_params or optimizer_params.weight_decay is missing
          from training_hyperparams.
      """
      if quantization_params is None:
          quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params

      # Work on private copies so none of the caller's objects is mutated.
      quantization_params = deepcopy(quantization_params)
      training_hyperparams = deepcopy(training_hyperparams)
      train_dataloader_params = deepcopy(train_dataloader_params)
      val_dataloader_params = deepcopy(val_dataloader_params)
      train_dataset_params = deepcopy(train_dataset_params)
      val_dataset_params = deepcopy(val_dataset_params)

      if "max_epochs" not in training_hyperparams:
          raise ValueError("max_epochs is a required field in training_hyperparams for QAT modification.")

      if "initial_lr" not in training_hyperparams:
          raise ValueError("initial_lr is a required field in training_hyperparams for QAT modification.")

      if "optimizer_params" not in training_hyperparams:
          raise ValueError("optimizer_params is a required field in training_hyperparams for QAT modification.")

      if "weight_decay" not in training_hyperparams["optimizer_params"]:
          raise ValueError("weight_decay is a required field in training_hyperparams['optimizer_params'] for QAT modification.")

      # Q/DQ Layers take a lot of space for activations in training mode
      if get_param(quantization_params, "selective_quantizer_params") and get_param(quantization_params["selective_quantizer_params"], "learn_amax"):
          train_dataloader_params["batch_size"] //= batch_size_divisor
          val_dataloader_params["batch_size"] //= batch_size_divisor

          logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {train_dataloader_params['batch_size']}")
          logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {val_dataloader_params['batch_size']}")

      training_hyperparams["max_epochs"] //= max_epochs_divisor
      logger.warning(f"New number of epochs: {training_hyperparams['max_epochs']}")

      training_hyperparams["initial_lr"] *= lr_decay_factor
      if get_param(training_hyperparams, "warmup_initial_lr") is not None:
          training_hyperparams["warmup_initial_lr"] *= lr_decay_factor
      else:
          # No explicit warmup lr in the recipe: start the warmup at 1% of the (already decayed) initial lr.
          training_hyperparams["warmup_initial_lr"] = training_hyperparams["initial_lr"] * 0.01

      training_hyperparams["optimizer_params"]["weight_decay"] *= lr_decay_factor

      logger.warning(f"New learning rate: {training_hyperparams['initial_lr']}")
      logger.warning(f"New weight decay: {training_hyperparams['optimizer_params']['weight_decay']}")

      # as recommended by pytorch-quantization docs
      # The previous mode is captured BEFORE it is overwritten, so the warning reports the recipe's original value.
      previous_lr_mode = get_param(training_hyperparams, "lr_mode")
      if previous_lr_mode != "cosine":
          training_hyperparams["lr_mode"] = "cosine"
      training_hyperparams["cosine_final_lr_ratio"] = cosine_final_lr_ratio
      logger.warning(
          f"lr_mode will be set to cosine for QAT run instead of {previous_lr_mode} with "
          f"cosine_final_lr_ratio={cosine_final_lr_ratio}"
      )

      # `or 1` guarantees at least one warmup epoch when the integer division yields 0.
      training_hyperparams["lr_warmup_epochs"] = (training_hyperparams["max_epochs"] // warmup_epochs_divisor) or 1
      logger.warning(f"New lr_warmup_epochs: {training_hyperparams['lr_warmup_epochs']}")

      # these do mess with Q/DQ
      if get_param(training_hyperparams, "ema"):
          logger.warning("EMA will be disabled for QAT run.")
          training_hyperparams["ema"] = False

      if get_param(training_hyperparams, "sync_bn"):
          logger.warning("SyncBatchNorm will be disabled for QAT run.")
          training_hyperparams["sync_bn"] = False

      if disable_phase_callbacks and get_param(training_hyperparams, "phase_callbacks") is not None and len(training_hyperparams["phase_callbacks"]) > 0:
          logger.warning(f"Recipe contains {len(training_hyperparams['phase_callbacks'])} phase callbacks. All of them will be disabled.")
          training_hyperparams["phase_callbacks"] = []

      # no augmentations
      if disable_augmentations and "transforms" in val_dataset_params:
          logger.warning("Augmentations will be disabled for QAT run. Using validation transforms instead.")
          train_dataset_params["transforms"] = val_dataset_params["transforms"]

      return training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
  @register_pre_launch_callback()
  class QATRecipeModificationCallback(PreLaunchCallback):
      """
      QATRecipeModificationCallback(PreLaunchCallback)

      This callback modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe.

      :param int batch_size_divisor: Divisor used to calculate the batch size. Default value is 2.
      :param int max_epochs_divisor: Divisor used to calculate the maximum number of epochs. Default value is 10.
      :param float lr_decay_factor: Factor used to decay the learning rate, weight decay and warmup. Default value is 0.01.
      :param int warmup_epochs_divisor: Divisor used to calculate the number of warm-up epochs. Default value is 10.
      :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01.
      :param bool disable_phase_callbacks: Flag to control to disable phase callbacks, which can interfere with QAT. Default value is True.
      :param bool disable_augmentations: Flag to control to disable phase augmentations, which can interfere with QAT. Default value is False.

      Example usage:

      Inside the main recipe .YAML file (for example super_gradients/recipes/cifar10_resnet.yaml), add the following:

          pre_launch_callbacks_list:
              - QATRecipeModificationCallback:
                  batch_size_divisor: 2
                  max_epochs_divisor: 10
                  lr_decay_factor: 0.01
                  warmup_epochs_divisor: 10
                  cosine_final_lr_ratio: 0.01
                  disable_phase_callbacks: True
                  disable_augmentations: False

      USE THIS CALLBACK ONLY WITH Trainer.quantize_from_config
      """

      def __init__(
          self,
          batch_size_divisor: int = 2,
          max_epochs_divisor: int = 10,
          lr_decay_factor: float = 0.01,
          warmup_epochs_divisor: int = 10,
          cosine_final_lr_ratio: float = 0.01,
          disable_phase_callbacks: bool = True,
          disable_augmentations: bool = False,
      ):
          self.disable_augmentations = disable_augmentations
          self.disable_phase_callbacks = disable_phase_callbacks
          self.cosine_final_lr_ratio = cosine_final_lr_ratio
          self.warmup_epochs_divisor = warmup_epochs_divisor
          self.lr_decay_factor = lr_decay_factor
          self.max_epochs_divisor = max_epochs_divisor
          self.batch_size_divisor = batch_size_divisor

      def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
          """
          Return a deep copy of cfg with its training / dataset sections rewritten by modify_params_for_qat and with
          multi-gpu execution switched off.

          :param cfg: recipe config; the object passed in is not modified.
          :return: the modified copy of cfg.
          """
          logger.info("Modifying recipe to suit QAT rules of thumb. Remove QATRecipeModificationCallback to disable.")

          cfg = copy.deepcopy(cfg)

          (
              cfg.training_hyperparams,
              cfg.dataset_params.train_dataset_params,
              cfg.dataset_params.val_dataset_params,
              cfg.dataset_params.train_dataloader_params,
              cfg.dataset_params.val_dataloader_params,
          ) = modify_params_for_qat(
              training_hyperparams=cfg.training_hyperparams,
              train_dataset_params=cfg.dataset_params.train_dataset_params,
              train_dataloader_params=cfg.dataset_params.train_dataloader_params,
              val_dataset_params=cfg.dataset_params.val_dataset_params,
              # Each split passes its own dataloader params, so the validation dataloader configuration
              # is never replaced by the train one.
              val_dataloader_params=cfg.dataset_params.val_dataloader_params,
              quantization_params=cfg.quantization_params,
              batch_size_divisor=self.batch_size_divisor,
              disable_phase_callbacks=self.disable_phase_callbacks,
              cosine_final_lr_ratio=self.cosine_final_lr_ratio,
              warmup_epochs_divisor=self.warmup_epochs_divisor,
              lr_decay_factor=self.lr_decay_factor,
              max_epochs_divisor=self.max_epochs_divisor,
              disable_augmentations=self.disable_augmentations,
          )

          # Force a single-process, single-gpu run, warning when that differs from what the recipe requested.
          if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1:
              logger.warning(f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. Changing to multi_gpu=OFF and num_gpus=1")
              cfg.multi_gpu = "OFF"
              cfg.num_gpus = 1

          return cfg
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...