Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

kd_ema_test.py 7.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  1. import unittest
  2. from super_gradients.training.sg_trainer import Trainer
  3. from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
  4. import torch
  5. from super_gradients.training.utils.utils import check_models_have_same_weights
  6. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
  7. from super_gradients.training.metrics import Accuracy
  8. from super_gradients.training.losses.kd_losses import KDLogitsLoss
  class KDEMATest(unittest.TestCase):
      """Tests for how KDTrainer handles EMA weights during training and checkpoint reload.

      Each test trains a resnet18 student against a resnet50 teacher on CPU with
      ``"ema": True`` in the training params, then inspects ``kd_trainer.net`` and
      ``kd_trainer.ema_model.ema``.
      """

      def setUp(self):
          """Create the dataset interface and KD training params used by every test.

          unittest calls this before each test method, so every test gets its own
          loss and metric objects.
          """
          self.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
          self.dataset_params = {"batch_size": 5}
          self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
          self.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                                  "lr_warmup_epochs": 0, "initial_lr": 0.1,
                                  "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()),
                                  "optimizer": "SGD",
                                  "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                                  "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                                  "metric_to_watch": "Accuracy",
                                  'loss_logging_items_names': ["Loss", "Task Loss", "Distillation Loss"],
                                  "greater_metric_to_watch_is_better": True, "average_best_models": False,
                                  "ema": True}

      def _build_kd_trainer(self, experiment_name, checkpoint_params):
          """Return a CPU KDTrainer with the dataset connected and the student/teacher model built.

          :param experiment_name: name passed to the KDTrainer constructor.
          :param checkpoint_params: dict forwarded unchanged to ``build_model``.
          """
          kd_trainer = KDTrainer(experiment_name, device='cpu')
          kd_trainer.connect_dataset_interface(self.dataset)
          kd_trainer.build_model(student_architecture='resnet18',
                                 teacher_architecture='resnet50',
                                 student_arch_params={'num_classes': 1000},
                                 teacher_arch_params={'num_classes': 1000},
                                 checkpoint_params=checkpoint_params,
                                 run_teacher_on_eval=True)
          return kd_trainer

      def _train_then_reload(self, experiment_name, load_ema_as_net):
          """Train a KD model, then rebuild it from its checkpoint and call train() again.

          :param experiment_name: name used for both the trained and the reloaded trainer.
          :param load_ema_as_net: value of the ``load_ema_as_net`` checkpoint param on reload.
          :return: tuple ``(ema_model, net, reloaded_ema_model, reloaded_net)``.
          """
          # Create a KD model and train it
          kd_trainer = self._build_kd_trainer(experiment_name, {'teacher_pretrained_weights': "imagenet"})
          kd_trainer.train(self.kd_train_params)
          ema_model = kd_trainer.ema_model.ema
          net = kd_trainer.net

          # Load the trained KD model
          kd_trainer = self._build_kd_trainer(experiment_name,
                                              {"load_checkpoint": True, "load_ema_as_net": load_ema_as_net})
          # Original author's note: train for 0 epochs just to see that, when continuing
          # training, the EMA model has been saved correctly.
          # NOTE(review): presumably no epochs run because the checkpoint already reached
          # max_epochs - confirm against KDTrainer.train.
          kd_trainer.train(self.kd_train_params)
          return ema_model, net, kd_trainer.ema_model.ema, kd_trainer.net

      def test_teacher_ema_not_duplicated(self):
          """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
          kd_trainer = self._build_kd_trainer("test_teacher_ema_not_duplicated",
                                              {'teacher_pretrained_weights': "imagenet"})
          kd_trainer.train(self.kd_train_params)

          # assertIs / assertIsNot report both objects on failure, unlike assertTrue(a is b).
          self.assertIs(kd_trainer.ema_model.ema.module.teacher, kd_trainer.net.module.teacher)
          self.assertIsNot(kd_trainer.ema_model.ema.module.student, kd_trainer.net.module.student)

      def test_kd_ckpt_reload_ema(self):
          """Check that the KD model loads correctly from checkpoint when "load_ema_as_net=True"."""
          ema_model, net, reloaded_ema_model, reloaded_net = self._train_then_reload(
              "test_kd_ema_ckpt_reload", load_ema_as_net=True)

          # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
          self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
          # loaded net != trained net (since load_ema_as_net = True)
          self.assertFalse(check_models_have_same_weights(reloaded_net, net))
          # loaded net == trained ema (since load_ema_as_net = True)
          self.assertTrue(check_models_have_same_weights(reloaded_net, ema_model))
          # loaded student ema == loaded student net (since load_ema_as_net = True)
          self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.student,
                                                         reloaded_net.module.student))
          # loaded teacher ema == loaded teacher net (teacher always loads ema)
          self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher,
                                                         reloaded_net.module.teacher))

      def test_kd_ckpt_reload_net(self):
          """Check that the KD model loads correctly from checkpoint when "load_ema_as_net=False"."""
          # NOTE(review): this reuses the experiment name of test_kd_ckpt_reload_ema, so both
          # tests presumably write to the same checkpoint location - confirm this is intended.
          ema_model, net, reloaded_ema_model, reloaded_net = self._train_then_reload(
              "test_kd_ema_ckpt_reload", load_ema_as_net=False)

          # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
          self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
          # loaded net == trained net (since load_ema_as_net = False)
          self.assertTrue(check_models_have_same_weights(reloaded_net, net))
          # loaded net != trained ema (since load_ema_as_net = False)
          self.assertFalse(check_models_have_same_weights(reloaded_net, ema_model))
          # loaded student ema != loaded student net (since load_ema_as_net = False)
          self.assertFalse(check_models_have_same_weights(reloaded_ema_model.module.student,
                                                          reloaded_net.module.student))
          # loaded teacher ema == loaded teacher net (teacher always loads ema)
          self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher,
                                                         reloaded_net.module.teacher))
  # Run the test suite when this module is executed directly as a script.
  if __name__ == "__main__":
      unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...