#604 fix master installation

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-000_fix_master_inastallation
import hydra
import torch.nn
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from super_gradients.common import MultiGPUMode
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_trainer import Trainer
from typing import Union
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils, models
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import get_param, HpmStruct
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
from super_gradients.training.exceptions.kd_trainer_exceptions import (
    ArchitectureKwargsException,
    UnsupportedKDArchitectureException,
    InconsistentParamsException,
    UnsupportedKDModelArgException,
    TeacherKnowledgeException,
    UndefinedNumClassesException,
)
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.ema import KDModelEMA
from super_gradients.training.utils.sg_trainer_utils import parse_args

logger = get_logger(__name__)


class KDTrainer(Trainer):
    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = None, ckpt_root_dir: str = None):
        super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
        self.student_architecture = None
        self.teacher_architecture = None
        self.student_arch_params = None
        self.teacher_arch_params = None

    @classmethod
    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
        """
        Trains according to cfg recipe configuration.

        @param cfg: The parsed DictConfig from yaml recipe files
        @return: None
        """
        # INSTANTIATE ALL OBJECTS IN CFG
        cfg = hydra.utils.instantiate(cfg)
        kwargs = parse_args(cfg, cls.__init__)
        trainer = KDTrainer(**kwargs)

        # INSTANTIATE DATA LOADERS
        train_dataloader = dataloaders.get(
            name=cfg.train_dataloader, dataset_params=cfg.dataset_params.train_dataset_params, dataloader_params=cfg.dataset_params.train_dataloader_params
        )
        val_dataloader = dataloaders.get(
            name=cfg.val_dataloader, dataset_params=cfg.dataset_params.val_dataset_params, dataloader_params=cfg.dataset_params.val_dataloader_params
        )

        student = models.get(
            cfg.student_architecture,
            arch_params=cfg.student_arch_params,
            strict_load=cfg.student_checkpoint_params.strict_load,
            pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
            checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
            load_backbone=cfg.student_checkpoint_params.load_backbone,
        )
        teacher = models.get(
            cfg.teacher_architecture,
            arch_params=cfg.teacher_arch_params,
            strict_load=cfg.teacher_checkpoint_params.strict_load,
            pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
            checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
            load_backbone=cfg.teacher_checkpoint_params.load_backbone,
        )

        # TRAIN
        trainer.train(
            training_params=cfg.training_hyperparams,
            student=student,
            teacher=teacher,
            kd_architecture=cfg.architecture,
            kd_arch_params=cfg.arch_params,
            run_teacher_on_eval=cfg.run_teacher_on_eval,
            train_loader=train_dataloader,
            valid_loader=val_dataloader,
        )

    def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        if get_param(checkpoint_params, "pretrained_weights") is not None:
            raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")

        if not isinstance(architecture, KDModule):
            if student_architecture is None or teacher_architecture is None:
                raise ArchitectureKwargsException()
            if architecture not in KD_ARCHITECTURES.keys():
                raise UnsupportedKDArchitectureException(architecture)

        # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED IN ARCH PARAMS FOR TEACHER AND STUDENT
        self._validate_num_classes(student_arch_params, teacher_arch_params)
        arch_params["num_classes"] = student_arch_params["num_classes"]

        # MAKE SURE THE TEACHER'S PRETRAINED NUM CLASSES EQUALS THE STUDENT'S, AS WE CAN'T REPLACE THE TEACHER'S HEAD
        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, "teacher_pretrained_weights", default_val=None)
        if teacher_pretrained_weights is not None:
            teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
            if teacher_pretrained_num_classes != teacher_arch_params["num_classes"]:
                raise InconsistentParamsException(
                    "Pretrained dataset number of classes", "teacher's arch params", "number of classes", "student's number of classes"
                )

        teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
        load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")

        # CHECK THAT THE TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM, OR THAT WE ARE LOADING AN ENTIRE KD CHECKPOINT
        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(teacher_architecture, torch.nn.Module)):
            raise TeacherKnowledgeException()

    def _validate_num_classes(self, student_arch_params, teacher_arch_params):
        """
        Checks validity of num_classes (i.e. existence and consistency between the student and teacher subnets).

        :param student_arch_params: (dict) Architecture hyperparameters (e.g. block, num_blocks, num_classes) for the student
        :param teacher_arch_params: (dict) Architecture hyperparameters (e.g. block, num_blocks, num_classes) for the teacher
        """
        self._validate_subnet_num_classes(student_arch_params)
        self._validate_subnet_num_classes(teacher_arch_params)
        if teacher_arch_params["num_classes"] != student_arch_params["num_classes"]:
            raise InconsistentParamsException("num_classes", "student_arch_params", "num_classes", "teacher_arch_params")

    def _validate_subnet_num_classes(self, subnet_arch_params):
        """
        Derives num_classes in student_arch_params/teacher_arch_params from the dataset interface, or raises an error
        when none is given.

        :param subnet_arch_params: Arch params for the student/teacher
        """
        if "num_classes" not in subnet_arch_params.keys():
            if self.dataset_interface is None:
                raise UndefinedNumClassesException()
            else:
                subnet_arch_params["num_classes"] = len(self.classes)

    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict, checkpoint_params: dict, *args, **kwargs) -> tuple:
        """
        Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
        and teacher networks, and the required module manipulation (i.e. head replacement) for the teacher network.

        :param architecture: String, KDModule or uninstantiated KDModule class describing the network's architecture.
        :param arch_params: Architecture's parameters passed to the networks' constructors.
        :param checkpoint_params: Checkpoint-loading related parameters dictionary with a 'pretrained_weights' key,
            such that its value is a string describing the dataset of the pretrained weights (for example "imagenet").
        :return: instantiated network, i.e. KDModule, and architecture_class (None when architecture is not a str)
        """
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")
        student_arch_params = core_utils.HpmStruct(**student_arch_params)
        teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)

        student_pretrained_weights = get_param(checkpoint_params, "student_pretrained_weights")
        teacher_pretrained_weights = get_param(checkpoint_params, "teacher_pretrained_weights")

        student = super()._instantiate_net(student_architecture, student_arch_params, {"pretrained_weights": student_pretrained_weights})
        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params, {"pretrained_weights": teacher_pretrained_weights})

        run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)

        return self._instantiate_kd_net(arch_params, architecture, run_teacher_on_eval, student, teacher)

    def _instantiate_kd_net(self, arch_params, architecture, run_teacher_on_eval, student, teacher):
        if isinstance(architecture, str):
            architecture_cls = KD_ARCHITECTURES[architecture]
            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher, run_teacher_on_eval=run_teacher_on_eval)
        elif isinstance(architecture, KDModule.__class__):
            net = architecture(arch_params=arch_params, student=student, teacher=teacher, run_teacher_on_eval=run_teacher_on_eval)
        else:
            net = architecture
        return net

    def _load_checkpoint_to_model(self):
        """
        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for
        the entire KD network following the same logic as in Trainer.
        """
        teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
        teacher_net = self.net.module.teacher

        if teacher_checkpoint_path is not None:
            # WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
            teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
            if teacher_pretrained_weights:
                logger.warning(f"{teacher_checkpoint_path} checkpoint is overriding {teacher_pretrained_weights} for teacher model")

            # ALWAYS LOAD ITS EMA IF IT EXISTS
            load_teachers_ema = "ema_net" in read_ckpt_state_dict(teacher_checkpoint_path).keys()
            load_checkpoint_to_model(
                ckpt_local_path=teacher_checkpoint_path,
                load_backbone=False,
                net=teacher_net,
                strict="no_key_matching",
                load_weights_only=True,
                load_ema_as_net=load_teachers_ema,
            )

        super(KDTrainer, self)._load_checkpoint_to_model()

    def _add_metrics_update_callback(self, phase):
        """
        Adds KDModelMetricsUpdateCallback to be fired at phase.

        :param phase: Phase for the metrics callback to be fired at
        """
        self.phase_callbacks.append(KDModelMetricsUpdateCallback(phase))

    def _get_hyper_param_config(self):
        """
        Creates a training hyperparameter config for logging, with additional KD-related hyperparameters.
        """
        hyper_param_config = super()._get_hyper_param_config()
        hyper_param_config.update(
            {
                "student_architecture": self.student_architecture,
                "teacher_architecture": self.teacher_architecture,
                "student_arch_params": self.student_arch_params,
                "teacher_arch_params": self.teacher_arch_params,
            }
        )
        return hyper_param_config

    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
        """
        Instantiate a KD EMA model for the KDModule.

        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.

        :param decay: the maximum decay value. As the training process advances, the decay climbs towards this
            value; the update rule is EMA_{t+1} = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param beta: the exponent coefficient. The higher the beta, the sooner in training the decay saturates to
            its final value. beta=15 saturates at ~40% of the training process.
        :param exp_activation:
        """
        return KDModelEMA(self.net, decay, beta, exp_activation)

    def _save_best_checkpoint(self, epoch, state):
        """
        Overrides the parent best-checkpoint saving to modify the state dict so that we only save the student.
        """
        if self.ema:
            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
            state.pop("ema_net")
        else:
            best_net = core_utils.WrappedModel(self.net.module.student)

        state["net"] = best_net.state_dict()
        self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

    def train(
        self,
        model: KDModule = None,
        training_params: dict = dict(),
        student: SgModule = None,
        teacher: torch.nn.Module = None,
        kd_architecture: Union[KDModule.__class__, str] = "kd_module",
        kd_arch_params: dict = dict(),
        run_teacher_on_eval=False,
        train_loader: DataLoader = None,
        valid_loader: DataLoader = None,
        *args,
        **kwargs,
    ):
        """
        Trains the student network (wrapped in a KDModule network).

        :param model: KDModule, network to train. When none is given, a KDModule will be initialized according to
            kd_architecture, student and teacher (default=None)
        :param training_params: dict, same as in Trainer.train()
        :param student: SgModule - the student network
        :param teacher: torch.nn.Module - the teacher network
        :param kd_architecture: KDModule architecture to use; currently only 'kd_module' is supported (default='kd_module').
        :param kd_arch_params: architecture params to pass to the kd_architecture constructor.
        :param run_teacher_on_eval: bool - whether to run self.teacher in eval mode regardless of self.train(mode)
        :param train_loader: DataLoader for the train set.
        :param valid_loader: DataLoader for the validation set.
        """
        kd_net = self.net or model
        if kd_net is None:
            if student is None or teacher is None:
                raise ValueError("Must pass student and teacher models or net (KDModule).")
            kd_net = self._instantiate_kd_net(
                arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
            )
        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
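
For context, below is a minimal sketch of driving KDTrainer directly, mirroring what train_from_config does in the file above. It is illustrative only: the architecture names ("resnet18_cifar", "resnet50"), the dataloader names ("cifar10_train", "cifar10_val"), the teacher checkpoint path and the hyperparameter values are placeholder assumptions, not part of this PR.

from super_gradients.training import models
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
from super_gradients.training.losses import KDLogitsLoss, LabelSmoothingCrossEntropyLoss
from super_gradients.training.metrics import Accuracy

trainer = KDTrainer(experiment_name="kd_resnet50_to_resnet18_cifar10")

# Registered classification dataloader names are assumed here; any pair of
# train/valid loaders with matching classes would work.
train_loader = dataloaders.get(name="cifar10_train", dataloader_params={"batch_size": 64})
valid_loader = dataloaders.get(name="cifar10_val", dataloader_params={"batch_size": 64})

# The student starts untrained; the teacher must hold knowledge (pretrained
# weights or a checkpoint), otherwise _validate_args raises TeacherKnowledgeException.
student = models.get("resnet18_cifar", arch_params={"num_classes": 10})
teacher = models.get("resnet50", arch_params={"num_classes": 10}, checkpoint_path="/path/to/teacher_ckpt.pth")

trainer.train(
    training_params={
        "max_epochs": 5,
        "lr_mode": "cosine",
        "initial_lr": 0.01,
        "optimizer": "SGD",
        # KDLogitsLoss combines the task loss with a distillation term on the logits.
        "loss": KDLogitsLoss(LabelSmoothingCrossEntropyLoss()),
        "train_metrics_list": [Accuracy()],
        "valid_metrics_list": [Accuracy()],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    },
    student=student,
    teacher=teacher,
    kd_architecture="kd_module",  # per the train() docstring, the only supported value
    run_teacher_on_eval=True,
    train_loader=train_loader,
    valid_loader=valid_loader,
)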
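
The recipe-driven path goes through train_from_config instead. Below is a hypothetical launcher, assuming a recipe YAML (here called my_kd_recipe.yaml under a local recipes/ directory; both names are made up) that defines every key train_from_config reads: train_dataloader, val_dataloader, dataset_params, the student_* and teacher_* architecture/arch_params/checkpoint_params groups, architecture, arch_params, run_teacher_on_eval and training_hyperparams.

import hydra
from omegaconf import DictConfig
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer

@hydra.main(config_path="recipes", config_name="my_kd_recipe")
def main(cfg: DictConfig) -> None:
    # Hydra composes the recipe into a DictConfig; train_from_config then
    # instantiates the trainer, both dataloaders, the student and the teacher.
    KDTrainer.train_from_config(cfg)

if __name__ == "__main__":
    main()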