Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#875 Feature/sg 761 yolo nas

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-761-yolo-nas
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
  1. import copy
  2. from copy import deepcopy
  3. from typing import Union
  4. from omegaconf import DictConfig
  5. import torch
  6. from super_gradients.common.registry.registry import register_pre_launch_callback
  7. from super_gradients import is_distributed
  8. from super_gradients.common.abstractions.abstract_logger import get_logger
  9. from super_gradients.training import models
  10. from torch.distributed import barrier
  11. import cv2
  12. import numpy as np
  13. logger = get_logger(__name__)
  14. class PreLaunchCallback:
  15. """
  16. PreLaunchCallback
  17. Base class for callbacks to be triggered, manipulating the config (cfg) prior to launching training,
  18. when calling Trainer.train_from_config(cfg).
  19. """
  20. def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
  21. raise NotImplementedError
  22. @register_pre_launch_callback()
  23. class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
  24. """
  25. AutoTrainBatchSizeSelectionCallback
  26. Modifies cfg.dataset_params.train_dataloader_params.batch_size by searching for the maximal batch size that fits
  27. gpu memory/ the one resulting in fastest time for the selected number of train datalaoder iterations. Works out of the box for DDP.
  28. The search is done by running a few forward passes for increasing batch sizes, until CUDA OUT OF MEMORY is raised:
  29. For batch_size in range(min_batch_size:max_batch_size:size_step):
  30. if batch_size raises CUDA OUT OF MEMORY ERROR:
  31. return batch_size-size_step
  32. return batch_size
  33. Example usage: Inside the main recipe .YAML file (for example super_gradients/recipes/cifar10_resnet.yaml),
  34. add the following:
  35. pre_launch_callbacks_list:
  36. - AutoTrainBatchSizeSelectionCallback:
  37. min_batch_size: 128
  38. size_step: 64
  39. num_forward_passes: 10
  40. Then, when running super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=...
  41. this pre_launch_callback will modify cfg.dataset_params.train_dataloader_params.batch_size then pass cfg to
  42. Trainer.train_from_config(cfg) and training will continue with the selected batch size.
  43. :param min_batch_size: int, the first batch size to try running forward passes. Should fit memory.
  44. :param size_step: int, the difference between 2 consecutive batch_size trials.
  45. :param num_forward_passes: int, number of forward passes (i.e train_loader data iterations inside an epoch).
  46. Note that the more forward passes being done, the less the selected batch size is prawn to fail. This is because
  47. other then gradients, model computations, data and other fixed gpu memory that is being used- some more gpu memory
  48. might be used by the metric objects and PhaseCallbacks.
  49. :param max_batch_size: int, optional, upper limit of the batch sizes to try. When None, the search will continue until
  50. the maximal batch size that does not raise CUDA OUT OF MEMORY is found (deafult=None).
  51. :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
  52. FOUND_BATCH_SIZE/cfg.dataset_params.train_datalaoder_params.batch_size (default=True)
  53. :param mode: str, one of ["fastest","largest"], whether to select the largest batch size that fits memory or the one
  54. that the resulted in overall fastest execution.
  55. """
  56. def __init__(self, min_batch_size: int, size_step: int, num_forward_passes: int = 3, max_batch_size=None, scale_lr: bool = True, mode: str = "fastest"):
  57. if mode not in ["fastest", "largest"]:
  58. raise TypeError(f"Expected mode to be one of: ['fastest','largest'], got {mode}")
  59. self.scale_lr = scale_lr
  60. self.min_batch_size = min_batch_size
  61. self.size_step = size_step
  62. self.max_batch_size = max_batch_size
  63. self.num_forward_passes = num_forward_passes
  64. self.mode = mode
  65. def __call__(self, cfg: DictConfig) -> DictConfig:
  66. # IMPORT IS HERE DUE TO CIRCULAR IMPORT PROBLEM
  67. from super_gradients.training.sg_trainer import Trainer
  68. curr_batch_size = self.min_batch_size
  69. # BUILD NETWORK
  70. model = models.get(
  71. model_name=cfg.architecture,
  72. num_classes=cfg.arch_params.num_classes,
  73. arch_params=cfg.arch_params,
  74. strict_load=cfg.checkpoint_params.strict_load,
  75. pretrained_weights=cfg.checkpoint_params.pretrained_weights,
  76. checkpoint_path=cfg.checkpoint_params.checkpoint_path,
  77. load_backbone=cfg.checkpoint_params.load_backbone,
  78. )
  79. tmp_cfg = deepcopy(cfg)
  80. tmp_cfg.training_hyperparams.batch_accumulate = 1
  81. tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes
  82. tmp_cfg.training_hyperparams.run_validation_freq = 2
  83. tmp_cfg.training_hyperparams.silent_mode = True
  84. tmp_cfg.training_hyperparams.save_model = False
  85. tmp_cfg.training_hyperparams.max_epochs = 1
  86. tmp_cfg.training_hyperparams.average_best_models = False
  87. tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
  88. tmp_cfg.pre_launch_callbacks_list = []
  89. fastest_batch_time = np.inf
  90. fastest_batch_size = curr_batch_size
  91. bs_found = False
  92. while not bs_found:
  93. tmp_cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size
  94. try:
  95. passes_start = cv2.getTickCount()
  96. Trainer.train_from_config(tmp_cfg)
  97. curr_batch_time = (cv2.getTickCount() - passes_start) / cv2.getTickFrequency()
  98. logger.info(f"Batch size = {curr_batch_size} time for {self.num_forward_passes} forward passes: {curr_batch_time} seconds.")
  99. if curr_batch_time < fastest_batch_time:
  100. fastest_batch_size = curr_batch_size
  101. fastest_batch_time = curr_batch_time
  102. except RuntimeError as e:
  103. if "out of memory" in str(e):
  104. if curr_batch_size == self.min_batch_size:
  105. logger.error("Ran out of memory for the smallest batch, try setting smaller min_batch_size.")
  106. raise e
  107. else:
  108. selected_batch_size = curr_batch_size - self.size_step if self.mode == "largest" else fastest_batch_size
  109. msg = f"Ran out of memory for {curr_batch_size}, setting batch size to {selected_batch_size}."
  110. bs_found = True
  111. else:
  112. raise e
  113. else:
  114. if self.max_batch_size is not None and curr_batch_size >= self.max_batch_size:
  115. selected_batch_size = self.max_batch_size if self.mode == "largest" else fastest_batch_size
  116. msg = (
  117. f"Did not run out of memory for {curr_batch_size} >= max_batch_size={self.max_batch_size}, " f"setting batch to {selected_batch_size}."
  118. )
  119. bs_found = True
  120. else:
  121. logger.info(f"Did not run out of memory for {curr_batch_size}, retrying batch {curr_batch_size + self.size_step}.")
  122. curr_batch_size += self.size_step
  123. self._clear_model_gpu_mem(model)
  124. return self._inject_selected_batch_size_to_config(cfg, model, msg, selected_batch_size)
  125. def _inject_selected_batch_size_to_config(self, cfg, model, msg, selected_batch_size):
  126. logger.info(msg)
  127. self._adapt_lr_if_needed(cfg, found_batch_size=selected_batch_size)
  128. cfg.dataset_params.train_dataloader_params.batch_size = selected_batch_size
  129. self._clear_model_gpu_mem(model)
  130. return cfg
  131. def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
  132. if self.scale_lr:
  133. scale_factor = found_batch_size / cfg.dataset_params.train_dataloader_params.batch_size
  134. cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor
  135. return cfg
  136. @classmethod
  137. def _clear_model_gpu_mem(cls, model):
  138. for p in model.parameters():
  139. if p.grad is not None:
  140. del p.grad # free some memory
  141. torch.cuda.empty_cache()
  142. # WAIT FOR ALL PROCESSES TO CLEAR THEIR MEMORY BEFORE MOVING ON
  143. if is_distributed():
  144. barrier()
  145. @register_pre_launch_callback()
  146. class QATRecipeModificationCallback(PreLaunchCallback):
  147. """
  148. QATRecipeModificationCallback(PreLaunchCallback)
  149. This callback modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe.
  150. :param int batch_size_divisor: Divisor used to calculate the batch size. Default value is 2.
  151. :param int max_epochs_divisor: Divisor used to calculate the maximum number of epochs. Default value is 10.
  152. :param float lr_decay_factor: Factor used to decay the learning rate, weight decay and warmup. Default value is 0.01.
  153. :param int warmup_epochs_divisor: Divisor used to calculate the number of warm-up epochs. Default value is 10.
  154. :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01.
  155. :param bool disable_phase_callbacks: Flag to control to disable phase callbacks, which can interfere with QAT. Default value is True.
  156. :param bool disable_augmentations: Flag to control to disable phase augmentations, which can interfere with QAT. Default value is False.
  157. Example usage:
  158. Inside the main recipe .YAML file (for example super_gradients/recipes/cifar10_resnet.yaml), add the following:
  159. pre_launch_callbacks_list:
  160. - QATRecipeModificationCallback:
  161. batch_size_divisor: 2
  162. max_epochs_divisor: 10
  163. lr_decay_factor: 0.01
  164. warmup_epochs_divisor: 10
  165. cosine_final_lr_ratio: 0.01
  166. disable_phase_callbacks: True
  167. disable_augmentations: False
  168. USE THIS CALLBACK ONLY WITH QATTrainer!
  169. """
  170. def __init__(
  171. self,
  172. batch_size_divisor: int = 2,
  173. max_epochs_divisor: int = 10,
  174. lr_decay_factor: float = 0.01,
  175. warmup_epochs_divisor: int = 10,
  176. cosine_final_lr_ratio: float = 0.01,
  177. disable_phase_callbacks: bool = True,
  178. disable_augmentations: bool = False,
  179. ):
  180. self.disable_augmentations = disable_augmentations
  181. self.disable_phase_callbacks = disable_phase_callbacks
  182. self.cosine_final_lr_ratio = cosine_final_lr_ratio
  183. self.warmup_epochs_divisor = warmup_epochs_divisor
  184. self.lr_decay_factor = lr_decay_factor
  185. self.max_epochs_divisor = max_epochs_divisor
  186. self.batch_size_divisor = batch_size_divisor
  187. def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
  188. logger.info("Modifying recipe to suit QAT rules of thumb. Remove QATRecipeModificationCallback to disable.")
  189. cfg = copy.deepcopy(cfg)
  190. # Q/DQ Layers take a lot of space for activations in training mode
  191. if cfg.quantization_params.selective_quantizer_params.learn_amax:
  192. cfg.dataset_params.train_dataloader_params.batch_size //= self.batch_size_divisor
  193. cfg.dataset_params.val_dataloader_params.batch_size //= self.batch_size_divisor
  194. logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {cfg.dataset_params.train_dataloader_params.batch_size}")
  195. logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {cfg.dataset_params.val_dataloader_params.batch_size}")
  196. cfg.training_hyperparams.max_epochs //= self.max_epochs_divisor
  197. logger.warning(f"New number of epochs: {cfg.training_hyperparams.max_epochs}")
  198. cfg.training_hyperparams.initial_lr *= self.lr_decay_factor
  199. if cfg.training_hyperparams.warmup_initial_lr is not None:
  200. cfg.training_hyperparams.warmup_initial_lr *= self.lr_decay_factor
  201. else:
  202. cfg.training_hyperparams.warmup_initial_lr = cfg.training_hyperparams.initial_lr * 0.01
  203. cfg.training_hyperparams.optimizer_params.weight_decay *= self.lr_decay_factor
  204. logger.warning(f"New learning rate: {cfg.training_hyperparams.initial_lr}")
  205. logger.warning(f"New weight decay: {cfg.training_hyperparams.optimizer_params.weight_decay}")
  206. # as recommended by pytorch-quantization docs
  207. cfg.training_hyperparams.lr_mode = "cosine"
  208. cfg.training_hyperparams.lr_warmup_epochs = (cfg.training_hyperparams.max_epochs // self.warmup_epochs_divisor) or 1
  209. cfg.training_hyperparams.cosine_final_lr_ratio = self.cosine_final_lr_ratio
  210. # do mess with Q/DQ
  211. if cfg.training_hyperparams.ema:
  212. logger.warning("EMA will be disabled for QAT run.")
  213. cfg.training_hyperparams.ema = False
  214. if cfg.training_hyperparams.sync_bn:
  215. logger.warning("SyncBatchNorm will be disabled for QAT run.")
  216. cfg.training_hyperparams.sync_bn = False
  217. if self.disable_phase_callbacks and len(cfg.training_hyperparams.phase_callbacks) > 0:
  218. logger.warning(f"Recipe contains {len(cfg.training_hyperparams.phase_callbacks)} phase callbacks. All of them will be disabled.")
  219. cfg.training_hyperparams.phase_callbacks = []
  220. if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1:
  221. logger.warning(f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. Changing to multi_gpu=OFF and num_gpus=1")
  222. cfg.multi_gpu = "OFF"
  223. cfg.num_gpus = 1
  224. # no augmentations
  225. if self.disable_augmentations and "transforms" in cfg.dataset_params.val_dataset_params:
  226. logger.warning("Augmentations will be disabled for QAT run.")
  227. cfg.dataset_params.train_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
  228. return cfg
Discard
Tip!

Press p or to see the previous file or, n or to see the next file