|
@@ -218,6 +218,56 @@ class KDModelTest(unittest.TestCase):
|
|
std_required=[0.5, 0.5, 0.5])})
|
|
std_required=[0.5, 0.5, 0.5])})
|
|
self.assertTrue(isinstance(kd_model.net.module.teacher[0], NormalizationAdapter))
|
|
self.assertTrue(isinstance(kd_model.net.module.teacher[0], NormalizationAdapter))
|
|
|
|
|
|
|
|
+ def test_load_ckpt_best_for_student(self):
|
|
|
|
+ sg_model = KDModel("test_load_ckpt_best", device='cpu')
|
|
|
|
+ teacher_model = ResNet50(arch_params={}, num_classes=5)
|
|
|
|
+ teacher_path = '/tmp/teacher.pth'
|
|
|
|
+ torch.save(teacher_model.state_dict(), teacher_path)
|
|
|
|
+ sg_model.connect_dataset_interface(self.dataset)
|
|
|
|
+
|
|
|
|
+ sg_model.build_model(student_arch_params={'num_classes': 5},
|
|
|
|
+ teacher_arch_params={'num_classes': 5},
|
|
|
|
+ student_architecture='resnet18',
|
|
|
|
+ teacher_architecture='resnet50',
|
|
|
|
+ checkpoint_params={"teacher_checkpoint_path": teacher_path}
|
|
|
|
+ )
|
|
|
|
+ train_params = self.kd_train_params.copy()
|
|
|
|
+ train_params["max_epochs"] = 1
|
|
|
|
+ sg_model.train(train_params)
|
|
|
|
+ best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
|
|
|
|
+
|
|
|
|
+ student_sg_model = SgModel("studnet_sg_model")
|
|
|
|
+ student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
|
|
|
|
+ checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
|
|
|
|
+
|
|
|
|
+ self.assertTrue(
|
|
|
|
+ check_models_have_same_weights(student_sg_model.net.module, sg_model.net.module.student))
|
|
|
|
+
|
|
|
|
+ def test_load_ckpt_best_for_student_with_ema(self):
|
|
|
|
+ sg_model = KDModel("test_load_ckpt_best_for_student_with_ema", device='cpu')
|
|
|
|
+ teacher_model = ResNet50(arch_params={}, num_classes=5)
|
|
|
|
+ teacher_path = '/tmp/teacher.pth'
|
|
|
|
+ torch.save(teacher_model.state_dict(), teacher_path)
|
|
|
|
+ sg_model.connect_dataset_interface(self.dataset)
|
|
|
|
+
|
|
|
|
+ sg_model.build_model(student_arch_params={'num_classes': 5},
|
|
|
|
+ teacher_arch_params={'num_classes': 5},
|
|
|
|
+ student_architecture='resnet18',
|
|
|
|
+ teacher_architecture='resnet50',
|
|
|
|
+ checkpoint_params={"teacher_checkpoint_path": teacher_path}
|
|
|
|
+ )
|
|
|
|
+ train_params = self.kd_train_params.copy()
|
|
|
|
+ train_params["max_epochs"] = 1
|
|
|
|
+ train_params["ema"] = True
|
|
|
|
+ sg_model.train(train_params)
|
|
|
|
+ best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
|
|
|
|
+
|
|
|
|
+ student_sg_model = SgModel("studnet_sg_model")
|
|
|
|
+ student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
|
|
|
|
+ checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
|
|
|
|
+ self.assertTrue(
|
|
|
|
+ check_models_have_same_weights(student_sg_model.net.module, sg_model.ema_model.ema.module.student))
|
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
unittest.main()
|
|
unittest.main()
|