Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#251 Save student only in ckpt_best for KD

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:feature/SG-153_save_student_only_ckpt_best
1 changed file with 50 additions and 0 deletions
  1. 50
    0
      tests/unit_tests/kd_model_test.py
@@ -218,6 +218,56 @@ class KDModelTest(unittest.TestCase):
                                                                                std_required=[0.5, 0.5, 0.5])})
                                                                                std_required=[0.5, 0.5, 0.5])})
         self.assertTrue(isinstance(kd_model.net.module.teacher[0], NormalizationAdapter))
         self.assertTrue(isinstance(kd_model.net.module.teacher[0], NormalizationAdapter))
 
 
+    def test_load_ckpt_best_for_student(self):
+        """Check that ckpt_best.pth saved by KD training loads into a standalone resnet18 SgModel."""
+        import tempfile  # local import: the module's import block is not part of this hunk
+        sg_model = KDModel("test_load_ckpt_best", device='cpu')
+        teacher_dir = tempfile.TemporaryDirectory()  # private dir instead of a shared, hard-coded /tmp file
+        self.addCleanup(teacher_dir.cleanup)  # remove the teacher weights even when the test fails
+        teacher_path = os.path.join(teacher_dir.name, 'teacher.pth')
+        torch.save(ResNet50(arch_params={}, num_classes=5).state_dict(), teacher_path)
+        sg_model.connect_dataset_interface(self.dataset)
+
+        sg_model.build_model(student_arch_params={'num_classes': 5},
+                             teacher_arch_params={'num_classes': 5},
+                             student_architecture='resnet18',
+                             teacher_architecture='resnet50',
+                             checkpoint_params={"teacher_checkpoint_path": teacher_path})
+        train_params = self.kd_train_params.copy()  # copy so the shared fixture params are not mutated
+        train_params["max_epochs"] = 1
+        sg_model.train(train_params)
+        best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
+
+        student_sg_model = SgModel("studnet_sg_model")
+        student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
+                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
+        self.assertTrue(check_models_have_same_weights(student_sg_model.net.module, sg_model.net.module.student))
+
+    def test_load_ckpt_best_for_student_with_ema(self):
+        """Check that, with EMA enabled, ckpt_best.pth loads into a resnet18 SgModel matching the EMA student."""
+        import tempfile  # local import: the module's import block is not part of this hunk
+        sg_model = KDModel("test_load_ckpt_best_for_student_with_ema", device='cpu')
+        teacher_dir = tempfile.TemporaryDirectory()  # private dir instead of a shared, hard-coded /tmp file
+        self.addCleanup(teacher_dir.cleanup)  # remove the teacher weights even when the test fails
+        teacher_path = os.path.join(teacher_dir.name, 'teacher.pth')
+        torch.save(ResNet50(arch_params={}, num_classes=5).state_dict(), teacher_path)
+        sg_model.connect_dataset_interface(self.dataset)
+        sg_model.build_model(student_arch_params={'num_classes': 5},
+                             teacher_arch_params={'num_classes': 5},
+                             student_architecture='resnet18',
+                             teacher_architecture='resnet50',
+                             checkpoint_params={"teacher_checkpoint_path": teacher_path})
+        train_params = self.kd_train_params.copy()  # copy so the shared fixture params are not mutated
+        train_params["max_epochs"] = 1
+        train_params["ema"] = True
+        sg_model.train(train_params)
+        best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
+
+        student_sg_model = SgModel("studnet_sg_model")
+        student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
+                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
+        self.assertTrue(check_models_have_same_weights(student_sg_model.net.module, sg_model.ema_model.ema.module.student))
+
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     unittest.main()
     unittest.main()
Discard