|
@@ -2,6 +2,7 @@ import os
|
|
import unittest
|
|
import unittest
|
|
from copy import deepcopy
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
+from super_gradients.common.object_names import Models
|
|
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
|
|
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
|
|
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
|
|
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
|
|
import torch
|
|
import torch
|
|
@@ -57,8 +58,8 @@ class KDTrainerTest(unittest.TestCase):
|
|
}
|
|
}
|
|
|
|
|
|
def test_teacher_sg_module_methods(self):
|
|
def test_teacher_sg_module_methods(self):
|
|
- student = models.get("resnet18", arch_params={"num_classes": 1000})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
|
|
kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
|
|
kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
|
|
|
|
|
|
initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
|
|
initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
|
|
@@ -87,8 +88,8 @@ class KDTrainerTest(unittest.TestCase):
|
|
|
|
|
|
def test_train_model_with_input_adapter(self):
|
|
def test_train_model_with_input_adapter(self):
|
|
kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
|
|
kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
|
|
- student = models.get("resnet18", arch_params={"num_classes": 5})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
adapter = NormalizationAdapter(
|
|
adapter = NormalizationAdapter(
|
|
mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
|
|
mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
|
|
@@ -108,8 +109,8 @@ class KDTrainerTest(unittest.TestCase):
|
|
|
|
|
|
def test_load_ckpt_best_for_student(self):
|
|
def test_load_ckpt_best_for_student(self):
|
|
kd_trainer = KDTrainer("test_load_ckpt_best")
|
|
kd_trainer = KDTrainer("test_load_ckpt_best")
|
|
- student = models.get("resnet18", arch_params={"num_classes": 5})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
train_params = self.kd_train_params.copy()
|
|
train_params = self.kd_train_params.copy()
|
|
train_params["max_epochs"] = 1
|
|
train_params["max_epochs"] = 1
|
|
kd_trainer.train(
|
|
kd_trainer.train(
|
|
@@ -121,14 +122,14 @@ class KDTrainerTest(unittest.TestCase):
|
|
)
|
|
)
|
|
best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
|
|
best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
|
|
|
|
|
|
- student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
|
|
|
|
|
|
+ student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
|
|
|
|
|
|
self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
|
|
self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
|
|
|
|
|
|
def test_load_ckpt_best_for_student_with_ema(self):
|
|
def test_load_ckpt_best_for_student_with_ema(self):
|
|
kd_trainer = KDTrainer("test_load_ckpt_best")
|
|
kd_trainer = KDTrainer("test_load_ckpt_best")
|
|
- student = models.get("resnet18", arch_params={"num_classes": 5})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
train_params = self.kd_train_params.copy()
|
|
train_params = self.kd_train_params.copy()
|
|
train_params["max_epochs"] = 1
|
|
train_params["max_epochs"] = 1
|
|
train_params["ema"] = True
|
|
train_params["ema"] = True
|
|
@@ -141,14 +142,14 @@ class KDTrainerTest(unittest.TestCase):
|
|
)
|
|
)
|
|
best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
|
|
best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
|
|
|
|
|
|
- student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
|
|
|
|
|
|
+ student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
|
|
|
|
|
|
self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
|
|
self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
|
|
|
|
|
|
def test_resume_kd_training(self):
|
|
def test_resume_kd_training(self):
|
|
kd_trainer = KDTrainer("test_resume_training_start")
|
|
kd_trainer = KDTrainer("test_resume_training_start")
|
|
- student = models.get("resnet18", arch_params={"num_classes": 5})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
train_params = self.kd_train_params.copy()
|
|
train_params = self.kd_train_params.copy()
|
|
train_params["max_epochs"] = 1
|
|
train_params["max_epochs"] = 1
|
|
kd_trainer.train(
|
|
kd_trainer.train(
|
|
@@ -161,8 +162,8 @@ class KDTrainerTest(unittest.TestCase):
|
|
latest_net = deepcopy(kd_trainer.net)
|
|
latest_net = deepcopy(kd_trainer.net)
|
|
|
|
|
|
kd_trainer = KDTrainer("test_resume_training_start")
|
|
kd_trainer = KDTrainer("test_resume_training_start")
|
|
- student = models.get("resnet18", arch_params={"num_classes": 5})
|
|
|
|
- teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
+ student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
|
|
|
|
+ teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
|
|
|
|
|
|
train_params["max_epochs"] = 2
|
|
train_params["max_epochs"] = 2
|
|
train_params["resume"] = True
|
|
train_params["resume"] = True
|