Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#614 Feature/sg 493 modelnames instead of strings

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-493_modelnames_instead_of_strings
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 import super_gradients
 import super_gradients
@@ -18,7 +19,7 @@ class TestCifarTrainer(unittest.TestCase):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
         trainer = Trainer("test")
         trainer = Trainer("test")
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 10})
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
@@ -37,7 +38,7 @@ class TestCifarTrainer(unittest.TestCase):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
         trainer = Trainer("test")
         trainer = Trainer("test")
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 100})
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
Discard
@@ -1,6 +1,7 @@
 import shutil
 import shutil
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 import super_gradients
 import super_gradients
@@ -41,7 +42,7 @@ class TestTrainer(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     def test_train(self):
     def test_train(self):
Discard
@@ -1,3 +1,4 @@
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -22,7 +23,7 @@ class CallWrapper:
 class EMAIntegrationTest(unittest.TestCase):
 class EMAIntegrationTest(unittest.TestCase):
     def _init_model(self) -> None:
     def _init_model(self) -> None:
         self.trainer = Trainer("resnet18_cifar_ema_test")
         self.trainer = Trainer("resnet18_cifar_ema_test")
-        self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
+        self.model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 5})
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls) -> None:
     def tearDownClass(cls) -> None:
Discard
@@ -2,6 +2,7 @@ import shutil
 import unittest
 import unittest
 import os
 import os
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -34,7 +35,7 @@ class LRTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_trainer(name=""):
     def get_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18_cifar", num_classes=5)
+        model = models.get(Models.RESNET18_CIFAR, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     def test_function_lr(self):
     def test_function_lr(self):
Discard
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.dataloaders.dataloaders import (
@@ -35,19 +36,19 @@ class PretrainedModelsTest(unittest.TestCase):
         self.imagenet_pretrained_arch_params = {
         self.imagenet_pretrained_arch_params = {
             "resnet": {},
             "resnet": {},
             "regnet": {},
             "regnet": {},
-            "repvgg_a0": {"build_residual_branches": True},
-            "efficientnet_b0": {},
+            Models.REPVGG_A0: {"build_residual_branches": True},
+            Models.EFFICIENTNET_B0: {},
             "mobilenet": {},
             "mobilenet": {},
-            "vit_base": {"image_size": (224, 224), "patch_size": (16, 16)},
+            Models.VIT_BASE: {"image_size": (224, 224), "patch_size": (16, 16)},
         }
         }
 
 
         self.imagenet_pretrained_trainsfer_learning_arch_params = {
         self.imagenet_pretrained_trainsfer_learning_arch_params = {
             "resnet": {},
             "resnet": {},
             "regnet": {},
             "regnet": {},
-            "repvgg_a0": {"build_residual_branches": True},
-            "efficientnet_b0": {},
+            Models.REPVGG_A0: {"build_residual_branches": True},
+            Models.EFFICIENTNET_B0: {},
             "mobilenet": {},
             "mobilenet": {},
-            "vit_base": {"image_size": (224, 224), "patch_size": (16, 16)},
+            Models.VIT_BASE: {"image_size": (224, 224), "patch_size": (16, 16)},
         }
         }
 
 
         self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
         self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
@@ -55,21 +56,21 @@ class PretrainedModelsTest(unittest.TestCase):
         self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
         self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
 
 
         self.imagenet_pretrained_accuracies = {
         self.imagenet_pretrained_accuracies = {
-            "resnet50": 0.8191,
-            "resnet34": 0.7413,
-            "resnet18": 0.706,
-            "repvgg_a0": 0.7205,
-            "regnetY800": 0.7707,
-            "regnetY600": 0.7618,
-            "regnetY400": 0.7474,
-            "regnetY200": 0.7088,
-            "efficientnet_b0": 0.7762,
-            "mobilenet_v3_large": 0.7452,
-            "mobilenet_v3_small": 0.6745,
-            "mobilenet_v2": 0.7308,
-            "vit_base": 0.8415,
-            "vit_large": 0.8564,
-            "beit_base_patch16_224": 0.85,
+            Models.RESNET50: 0.8191,
+            Models.RESNET34: 0.7413,
+            Models.RESNET18: 0.706,
+            Models.REPVGG_A0: 0.7205,
+            Models.REGNETY800: 0.7707,
+            Models.REGNETY600: 0.7618,
+            Models.REGNETY400: 0.7474,
+            Models.REGNETY200: 0.7088,
+            Models.EFFICIENTNET_B0: 0.7762,
+            Models.MOBILENET_V3_LARGE: 0.7452,
+            Models.MOBILENET_V3_SMALL: 0.6745,
+            Models.MOBILENET_V2: 0.7308,
+            Models.VIT_BASE: 0.8415,
+            Models.VIT_LARGE: 0.8564,
+            Models.BEIT_BASE_PATCH16_224: 0.85,
         }
         }
         self.imagenet_dataset = imagenet_val(dataloader_params={"batch_size": 128})
         self.imagenet_dataset = imagenet_val(dataloader_params={"batch_size": 128})
 
 
@@ -91,7 +92,7 @@ class PretrainedModelsTest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
         }
         }
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
-        self.coco_pretrained_arch_params = {"ssd_lite_mobilenet_v2": {"num_classes": 80}, "coco_ssd_mobilenet_v1": {"num_classes": 80}}
+        self.coco_pretrained_arch_params = {Models.SSD_LITE_MOBILENET_V2: {"num_classes": 80}, Models.SSD_MOBILENET_V1: {"num_classes": 80}}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
 
 
         self.coco_dataset = {
         self.coco_dataset = {
@@ -102,13 +103,13 @@ class PretrainedModelsTest(unittest.TestCase):
         }
         }
 
 
         self.coco_pretrained_maps = {
         self.coco_pretrained_maps = {
-            "ssd_lite_mobilenet_v2": 0.2041,
-            "coco_ssd_mobilenet_v1": 0.243,
-            "yolox_s": 0.4047,
-            "yolox_m": 0.4640,
-            "yolox_l": 0.4925,
-            "yolox_n": 0.2677,
-            "yolox_t": 0.3718,
+            Models.SSD_LITE_MOBILENET_V2: 0.2041,
+            Models.SSD_MOBILENET_V1: 0.243,
+            Models.YOLOX_S: 0.4047,
+            Models.YOLOX_M: 0.4640,
+            Models.YOLOX_L: 0.4925,
+            Models.YOLOX_N: 0.2677,
+            Models.YOLOX_T: 0.3718,
         }
         }
 
 
         self.transfer_detection_dataset = detection_test_dataloader()
         self.transfer_detection_dataset = detection_test_dataloader()
@@ -151,27 +152,27 @@ class PretrainedModelsTest(unittest.TestCase):
         self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
         self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
         self.coco_segmentation_dataset = coco_segmentation_val()
         self.coco_segmentation_dataset = coco_segmentation_val()
 
 
-        self.cityscapes_pretrained_models = ["ddrnet_23", "ddrnet_23_slim", "stdc1_seg50", "regseg48"]
+        self.cityscapes_pretrained_models = [Models.DDRNET_23, Models.DDRNET_23_SLIM, Models.STDC1_SEG50, Models.REGSEG48]
         self.cityscapes_pretrained_arch_params = {
         self.cityscapes_pretrained_arch_params = {
-            "ddrnet_23": {"aux_head": True},
-            "regseg48": {},
+            Models.DDRNET_23: {"aux_head": True},
+            Models.REGSEG48: {},
             "stdc": {"use_aux_heads": True, "aux_head": True},
             "stdc": {"use_aux_heads": True, "aux_head": True},
             "pplite_seg": {"use_aux_heads": True},
             "pplite_seg": {"use_aux_heads": True},
         }
         }
 
 
         self.cityscapes_pretrained_ckpt_params = {"pretrained_weights": "cityscapes"}
         self.cityscapes_pretrained_ckpt_params = {"pretrained_weights": "cityscapes"}
         self.cityscapes_pretrained_mious = {
         self.cityscapes_pretrained_mious = {
-            "ddrnet_23": 0.8026,
-            "ddrnet_23_slim": 0.7801,
-            "stdc1_seg50": 0.7511,
-            "stdc1_seg75": 0.7687,
-            "stdc2_seg50": 0.7644,
-            "stdc2_seg75": 0.7893,
-            "regseg48": 0.7815,
-            "pp_lite_t_seg50": 0.7492,
-            "pp_lite_t_seg75": 0.7756,
-            "pp_lite_b_seg50": 0.7648,
-            "pp_lite_b_seg75": 0.7852,
+            Models.DDRNET_23: 0.8026,
+            Models.DDRNET_23_SLIM: 0.7801,
+            Models.STDC1_SEG50: 0.7511,
+            Models.STDC1_SEG75: 0.7687,
+            Models.STDC2_SEG50: 0.7644,
+            Models.STDC2_SEG75: 0.7893,
+            Models.REGSEG48: 0.7815,
+            Models.PP_LITE_T_SEG50: 0.7492,
+            Models.PP_LITE_T_SEG75: 0.7756,
+            Models.PP_LITE_B_SEG50: 0.7648,
+            Models.PP_LITE_B_SEG75: 0.7852,
         }
         }
 
 
         self.cityscapes_dataset = cityscapes_val()
         self.cityscapes_dataset = cityscapes_val()
@@ -229,13 +230,13 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50")
         trainer = Trainer("imagenet_pretrained_resnet50")
-        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET50, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET50], delta=0.001)
 
 
     def test_transfer_learning_resnet50_imagenet(self):
     def test_transfer_learning_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning")
-        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET50, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -246,13 +247,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_resnet34_imagenet(self):
     def test_pretrained_resnet34_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet34")
         trainer = Trainer("imagenet_pretrained_resnet34")
 
 
-        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET34, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET34], delta=0.001)
 
 
     def test_transfer_learning_resnet34_imagenet(self):
     def test_transfer_learning_resnet34_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning")
-        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET34, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -263,13 +264,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_resnet18_imagenet(self):
     def test_pretrained_resnet18_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet18")
         trainer = Trainer("imagenet_pretrained_resnet18")
 
 
-        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET18, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET18], delta=0.001)
 
 
     def test_transfer_learning_resnet18_imagenet(self):
     def test_transfer_learning_resnet18_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning")
-        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET18, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -280,13 +281,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800")
         trainer = Trainer("imagenet_pretrained_regnetY800")
 
 
-        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY800, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY800], delta=0.001)
 
 
     def test_transfer_learning_regnetY800_imagenet(self):
     def test_transfer_learning_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning")
-        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY800, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -297,13 +298,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY600_imagenet(self):
     def test_pretrained_regnetY600_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY600")
         trainer = Trainer("imagenet_pretrained_regnetY600")
 
 
-        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY600, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY600], delta=0.001)
 
 
     def test_transfer_learning_regnetY600_imagenet(self):
     def test_transfer_learning_regnetY600_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning")
-        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY600, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -314,13 +315,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY400_imagenet(self):
     def test_pretrained_regnetY400_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY400")
         trainer = Trainer("imagenet_pretrained_regnetY400")
 
 
-        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY400, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY400], delta=0.001)
 
 
     def test_transfer_learning_regnetY400_imagenet(self):
     def test_transfer_learning_regnetY400_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning")
-        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY400, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -331,13 +332,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY200_imagenet(self):
     def test_pretrained_regnetY200_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY200")
         trainer = Trainer("imagenet_pretrained_regnetY200")
 
 
-        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY200, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY200], delta=0.001)
 
 
     def test_transfer_learning_regnetY200_imagenet(self):
     def test_transfer_learning_regnetY200_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning")
-        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY200, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -348,13 +349,15 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0")
         trainer = Trainer("imagenet_pretrained_repvgg_a0")
 
 
-        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REPVGG_A0, arch_params=self.imagenet_pretrained_arch_params[Models.REPVGG_A0], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REPVGG_A0], delta=0.001)
 
 
     def test_transfer_learning_repvgg_a0_imagenet(self):
     def test_transfer_learning_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning")
         trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning")
-        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.REPVGG_A0, arch_params=self.imagenet_pretrained_arch_params[Models.REPVGG_A0], **self.imagenet_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -364,7 +367,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_regseg48_cityscapes(self):
     def test_pretrained_regseg48_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_regseg48")
         trainer = Trainer("cityscapes_pretrained_regseg48")
-        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.REGSEG48, arch_params=self.cityscapes_pretrained_arch_params[Models.REGSEG48], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -372,11 +375,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.REGSEG48], delta=0.001)
 
 
     def test_transfer_learning_regseg48_cityscapes(self):
     def test_transfer_learning_regseg48_cityscapes(self):
         trainer = Trainer("regseg48_cityscapes_transfer_learning")
         trainer = Trainer("regseg48_cityscapes_transfer_learning")
-        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.REGSEG48, arch_params=self.cityscapes_pretrained_arch_params[Models.REGSEG48], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             train_loader=self.transfer_segmentation_dataset,
             train_loader=self.transfer_segmentation_dataset,
@@ -386,7 +389,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_ddrnet23_cityscapes(self):
     def test_pretrained_ddrnet23_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23")
         trainer = Trainer("cityscapes_pretrained_ddrnet23")
-        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.DDRNET_23, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -394,11 +397,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.DDRNET_23], delta=0.001)
 
 
     def test_pretrained_ddrnet23_slim_cityscapes(self):
     def test_pretrained_ddrnet23_slim_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim")
-        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(
+            Models.DDRNET_23_SLIM, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params
+        )
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -406,11 +411,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.DDRNET_23_SLIM], delta=0.001)
 
 
     def test_transfer_learning_ddrnet23_cityscapes(self):
     def test_transfer_learning_ddrnet23_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning")
-        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.DDRNET_23, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.ddrnet_transfer_segmentation_train_params,
             training_params=self.ddrnet_transfer_segmentation_train_params,
@@ -420,7 +425,9 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning")
-        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(
+            Models.DDRNET_23_SLIM, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.ddrnet_transfer_segmentation_train_params,
             training_params=self.ddrnet_transfer_segmentation_train_params,
@@ -441,15 +448,20 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_efficientnet_b0_imagenet(self):
     def test_pretrained_efficientnet_b0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_efficientnet_b0")
         trainer = Trainer("imagenet_pretrained_efficientnet_b0")
 
 
-        model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(
+            Models.EFFICIENTNET_B0, arch_params=self.imagenet_pretrained_arch_params[Models.EFFICIENTNET_B0], **self.imagenet_pretrained_ckpt_params
+        )
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.EFFICIENTNET_B0], delta=0.001)
 
 
     def test_transfer_learning_efficientnet_b0_imagenet(self):
     def test_transfer_learning_efficientnet_b0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning")
         trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.EFFICIENTNET_B0,
+            arch_params=self.imagenet_pretrained_arch_params[Models.EFFICIENTNET_B0],
+            **self.imagenet_pretrained_ckpt_params,
+            num_classes=5,
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -460,7 +472,9 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer("coco_ssd_lite_mobilenet_v2")
         trainer = Trainer("coco_ssd_lite_mobilenet_v2")
-        model = models.get("ssd_lite_mobilenet_v2", arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"], **self.coco_pretrained_ckpt_params)
+        model = models.get(
+            Models.SSD_LITE_MOBILENET_V2, arch_params=self.coco_pretrained_arch_params[Models.SSD_LITE_MOBILENET_V2], **self.coco_pretrained_ckpt_params
+        )
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -468,13 +482,13 @@ class PretrainedModelsTest(unittest.TestCase):
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             metrics_progress_verbose=True,
             metrics_progress_verbose=True,
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.SSD_LITE_MOBILENET_V2], delta=0.001)
 
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning")
         trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning")
-        transfer_arch_params = self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"].copy()
+        transfer_arch_params = self.coco_pretrained_arch_params[Models.SSD_LITE_MOBILENET_V2].copy()
         transfer_arch_params["num_classes"] = 5
         transfer_arch_params["num_classes"] = 5
-        model = models.get("ssd_lite_mobilenet_v2", arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.SSD_LITE_MOBILENET_V2, arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_detection_train_params_ssd,
             training_params=self.transfer_detection_train_params_ssd,
@@ -483,8 +497,8 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = Trainer("coco_ssd_mobilenet_v1")
-        model = models.get("ssd_mobilenet_v1", arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"], **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.SSD_MOBILENET_V1)
+        model = models.get(Models.SSD_MOBILENET_V1, arch_params=self.coco_pretrained_arch_params[Models.SSD_MOBILENET_V1], **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -492,63 +506,63 @@ class PretrainedModelsTest(unittest.TestCase):
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             metrics_progress_verbose=True,
             metrics_progress_verbose=True,
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.SSD_MOBILENET_V1], delta=0.001)
 
 
     def test_pretrained_yolox_s_coco(self):
     def test_pretrained_yolox_s_coco(self):
-        trainer = Trainer("yolox_s")
+        trainer = Trainer(Models.YOLOX_S)
 
 
-        model = models.get("yolox_s", **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.YOLOX_S, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_S], delta=0.001)
 
 
     def test_pretrained_yolox_m_coco(self):
     def test_pretrained_yolox_m_coco(self):
-        trainer = Trainer("yolox_m")
-        model = models.get("yolox_m", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_M)
+        model = models.get(Models.YOLOX_M, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_M], delta=0.001)
 
 
     def test_pretrained_yolox_l_coco(self):
     def test_pretrained_yolox_l_coco(self):
-        trainer = Trainer("yolox_l")
-        model = models.get("yolox_l", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_L)
+        model = models.get(Models.YOLOX_L, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_L], delta=0.001)
 
 
     def test_pretrained_yolox_n_coco(self):
     def test_pretrained_yolox_n_coco(self):
-        trainer = Trainer("yolox_n")
+        trainer = Trainer(Models.YOLOX_N)
 
 
-        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.YOLOX_N, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_N], delta=0.001)
 
 
     def test_pretrained_yolox_t_coco(self):
     def test_pretrained_yolox_t_coco(self):
-        trainer = Trainer("yolox_t")
-        model = models.get("yolox_t", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_T)
+        model = models.get(Models.YOLOX_T, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_t"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_T], delta=0.001)
 
 
     def test_transfer_learning_yolox_n_coco(self):
     def test_transfer_learning_yolox_n_coco(self):
         trainer = Trainer("test_transfer_learning_yolox_n_coco")
         trainer = Trainer("test_transfer_learning_yolox_n_coco")
-        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.YOLOX_N, **self.coco_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_detection_train_params_yolox,
             training_params=self.transfer_detection_train_params_yolox,
@@ -560,7 +574,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.MOBILENET_V3_LARGE, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -572,15 +586,15 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v3_large")
         trainer = Trainer("imagenet_mobilenet_v3_large")
 
 
-        model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V3_LARGE, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V3_LARGE], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.MOBILENET_V3_SMALL, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -592,14 +606,16 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v3_small")
         trainer = Trainer("imagenet_mobilenet_v3_small")
 
 
-        model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V3_SMALL, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V3_SMALL], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v2_imagenet(self):
     def test_transfer_learning_mobilenet_v2_imagenet(self):
         trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning")
 
 
-        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.MOBILENET_V2, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -610,13 +626,14 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v2_imagenet(self):
     def test_pretrained_mobilenet_v2_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v2")
         trainer = Trainer("imagenet_mobilenet_v2")
 
 
-        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V2, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V2], delta=0.001)
 
 
     def test_pretrained_stdc1_seg50_cityscapes(self):
     def test_pretrained_stdc1_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50")
-        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+
+        model = models.get(Models.STDC1_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -627,11 +644,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC1_SEG50], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning")
-        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC1_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -641,7 +660,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
     def test_pretrained_stdc1_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75")
-        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC1_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -652,11 +671,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC1_SEG75], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning")
-        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC1_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -666,7 +687,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
     def test_pretrained_stdc2_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50")
-        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC2_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -677,11 +698,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC2_SEG50], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning")
-        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC2_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -691,7 +714,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
     def test_pretrained_stdc2_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75")
-        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC2_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -702,11 +725,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC2_SEG75], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning")
-        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC2_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -717,7 +742,9 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_transfer_learning_vit_base_imagenet21k(self):
     def test_transfer_learning_vit_base_imagenet21k(self):
         trainer = Trainer("imagenet21k_pretrained_vit_base")
         trainer = Trainer("imagenet21k_pretrained_vit_base")
 
 
-        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.VIT_BASE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet21k_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -728,7 +755,9 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_transfer_learning_vit_large_imagenet21k(self):
     def test_transfer_learning_vit_large_imagenet21k(self):
         trainer = Trainer("imagenet21k_pretrained_vit_large")
         trainer = Trainer("imagenet21k_pretrained_vit_large")
 
 
-        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.VIT_LARGE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet21k_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -738,39 +767,44 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_vit_base_imagenet(self):
     def test_pretrained_vit_base_imagenet(self):
         trainer = Trainer("imagenet_pretrained_vit_base")
         trainer = Trainer("imagenet_pretrained_vit_base")
-        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.VIT_BASE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.VIT_BASE], delta=0.001)
 
 
     def test_pretrained_vit_large_imagenet(self):
     def test_pretrained_vit_large_imagenet(self):
         trainer = Trainer("imagenet_pretrained_vit_large")
         trainer = Trainer("imagenet_pretrained_vit_large")
-        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.VIT_LARGE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.VIT_LARGE], delta=0.001)
 
 
     def test_pretrained_beit_base_imagenet(self):
     def test_pretrained_beit_base_imagenet(self):
         trainer = Trainer("imagenet_pretrained_beit_base")
         trainer = Trainer("imagenet_pretrained_beit_base")
-        model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(
+            Models.BEIT_BASE_PATCH16_224, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params
+        )
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.BEIT_BASE_PATCH16_224], delta=0.001)
 
 
     def test_transfer_learning_beit_base_imagenet(self):
     def test_transfer_learning_beit_base_imagenet(self):
         trainer = Trainer("test_transfer_learning_beit_base_imagenet")
         trainer = Trainer("test_transfer_learning_beit_base_imagenet")
 
 
         model = models.get(
         model = models.get(
-            "beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.BEIT_BASE_PATCH16_224,
+            arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE],
+            **self.imagenet_pretrained_ckpt_params,
+            num_classes=5,
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -781,7 +815,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_pplite_t_seg50_cityscapes(self):
     def test_pretrained_pplite_t_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg50")
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg50")
-        model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_T_SEG50, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -793,11 +827,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_T_SEG50], delta=0.001)
 
 
     def test_pretrained_pplite_t_seg75_cityscapes(self):
     def test_pretrained_pplite_t_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg75")
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg75")
-        model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_T_SEG75, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -809,11 +843,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_T_SEG75], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg50_cityscapes(self):
     def test_pretrained_pplite_b_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg50")
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg50")
-        model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_B_SEG50, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -825,11 +859,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_B_SEG50], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg75_cityscapes(self):
     def test_pretrained_pplite_b_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg75")
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg75")
-        model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_B_SEG75, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -841,7 +875,7 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_B_SEG75], delta=0.001)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
         if os.path.exists("~/.cache/torch/hub/"):
Discard
@@ -5,6 +5,7 @@ import unittest
 import pkg_resources
 import pkg_resources
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
@@ -215,7 +216,7 @@ class ConfigInspectTest(unittest.TestCase):
 
 
     def test_resnet18_cifar_arch_params(self):
     def test_resnet18_cifar_arch_params(self):
         arch_params = get_arch_params("resnet18_cifar_arch_params")
         arch_params = get_arch_params("resnet18_cifar_arch_params")
-        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture("resnet18", HpmStruct(**arch_params))
+        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(Models.RESNET18, HpmStruct(**arch_params))
 
 
         with raise_if_unused_params(arch_params) as tracked_arch_params:
         with raise_if_unused_params(arch_params) as tracked_arch_params:
             _ = architecture_cls(arch_params=tracked_arch_params)
             _ = architecture_cls(arch_params=tracked_arch_params)
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 
 
@@ -19,7 +20,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
 
 
         trainer = Trainer("dataset_statistics_visual_test")
         trainer = Trainer("dataset_statistics_visual_test")
 
 
-        model = models.get("yolox_s")
+        model = models.get(Models.YOLOX_S)
 
 
         training_params = {
         training_params = {
             "max_epochs": 1,  # we dont really need the actual training to run
             "max_epochs": 1,  # we dont really need the actual training to run
Discard
@@ -4,6 +4,7 @@ import unittest
 import numpy as np
 import numpy as np
 import torch.cuda
 import torch.cuda
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
@@ -16,7 +17,7 @@ from tests.core_test_utils import is_data_available
 class TestDetectionUtils(unittest.TestCase):
 class TestDetectionUtils(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = models.get("yolox_n", pretrained_weights="coco").to(self.device)
+        self.model = models.get(Models.YOLOX_N, pretrained_weights="coco").to(self.device)
         self.model.eval()
         self.model.eval()
 
 
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainTwiceTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,6 +1,7 @@
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from torchvision.transforms import Compose, Normalize, Resize
 from torchvision.transforms import Compose, Normalize, Resize
 from super_gradients.training.transforms import Standardize
 from super_gradients.training.transforms import Standardize
@@ -9,7 +10,7 @@ import os
 
 
 class TestModelsONNXExport(unittest.TestCase):
 class TestModelsONNXExport(unittest.TestCase):
     def test_models_onnx_export(self):
     def test_models_onnx_export(self):
-        pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
+        pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         with tempfile.TemporaryDirectory() as tmpdirname:
         with tempfile.TemporaryDirectory() as tmpdirname:
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
Discard
@@ -5,6 +5,7 @@ import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
@@ -15,7 +16,7 @@ from torch import nn
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
     def test_training_with_factories(self):
     def test_training_with_factories(self):
         trainer = Trainer("test_train_with_factories")
         trainer = Trainer("test_train_with_factories")
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,4 +1,6 @@
 import unittest
 import unittest
+
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -28,7 +30,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
     def test_resizing_with_forward_pass_prep_fn(self):
     def test_resizing_with_forward_pass_prep_fn(self):
         # Define Model
         # Define Model
         trainer = Trainer("ForwardpassPrepFNTest")
         trainer = Trainer("ForwardpassPrepFNTest")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
 
 
         sizes = []
         sizes = []
         phase_callbacks = [TestInputSizesCallback(sizes)]
         phase_callbacks = [TestInputSizesCallback(sizes)]
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -27,7 +28,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
 
 
     def test_train_with_dataloaders(self):
     def test_train_with_dataloaders(self):
         trainer = Trainer(experiment_name="test_name")
         trainer = Trainer(experiment_name="test_name")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
Discard
@@ -8,6 +8,7 @@ import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
+from super_gradients.common.object_names import Models
 
 
 
 
 class KDEMATest(unittest.TestCase):
 class KDEMATest(unittest.TestCase):
@@ -39,8 +40,8 @@ class KDEMATest(unittest.TestCase):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
 
 
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -59,8 +60,8 @@ class KDEMATest(unittest.TestCase):
         # Create a KD trainer and train it
         # Create a KD trainer and train it
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -74,8 +75,8 @@ class KDEMATest(unittest.TestCase):
 
 
         # Load the trained KD trainer
         # Load the trained KD trainer
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         train_params["resume"] = True
         train_params["resume"] = True
         kd_model.train(
         kd_model.train(
Discard
@@ -2,6 +2,7 @@ import os
 import unittest
 import unittest
 from copy import deepcopy
 from copy import deepcopy
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 import torch
@@ -57,8 +58,8 @@ class KDTrainerTest(unittest.TestCase):
         }
         }
 
 
     def test_teacher_sg_module_methods(self):
     def test_teacher_sg_module_methods(self):
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
 
 
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
@@ -87,8 +88,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_train_model_with_input_adapter(self):
     def test_train_model_with_input_adapter(self):
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         adapter = NormalizationAdapter(
         adapter = NormalizationAdapter(
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
@@ -108,8 +109,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_load_ckpt_best_for_student(self):
     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -121,14 +122,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
 
 
     def test_load_ckpt_best_for_student_with_ema(self):
     def test_load_ckpt_best_for_student_with_ema(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
@@ -141,14 +142,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
 
 
     def test_resume_kd_training(self):
     def test_resume_kd_training(self):
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -161,8 +162,8 @@ class KDTrainerTest(unittest.TestCase):
         latest_net = deepcopy(kd_trainer.net)
         latest_net = deepcopy(kd_trainer.net)
 
 
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         train_params["max_epochs"] = 2
         train_params["max_epochs"] = 2
         train_params["resume"] = True
         train_params["resume"] = True
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -27,12 +28,12 @@ class LocalCkptHeadReplacementTest(unittest.TestCase):
         }
         }
 
 
         # Define Model
         # Define Model
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         trainer = Trainer("test_resume_training")
         trainer = Trainer("test_resume_training")
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
 
 
-        net2 = models.get("resnet18", num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
+        net2 = models.get(Models.RESNET18, num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
         self.assertFalse(check_models_have_same_weights(net, net2))
         self.assertFalse(check_models_have_same_weights(net, net2))
 
 
         net.linear = None
         net.linear = None
Discard
@@ -4,6 +4,7 @@ from torch import Tensor
 from torchmetrics import Accuracy
 from torchmetrics import Accuracy
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 
 
@@ -29,7 +30,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -53,7 +54,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -77,7 +78,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 import super_gradients
 import super_gradients
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -11,21 +12,21 @@ import shutil
 class PretrainedModelsUnitTest(unittest.TestCase):
 class PretrainedModelsUnitTest(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
+        self.imagenet_pretrained_models = [Models.RESNET50, "repvgg_a0", "regnetY800"]
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
-        model = models.get("resnet50", pretrained_weights="imagenet")
+        model = models.get(Models.RESNET50, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
-        model = models.get("regnetY800", pretrained_weights="imagenet")
+        model = models.get(Models.REGNETY800, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
-        model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
+        model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
Discard
@@ -3,6 +3,8 @@ import torch
 import torchvision
 import torchvision
 from torch import nn
 from torch import nn
 
 
+from super_gradients.common.object_names import Models
+
 try:
 try:
     import super_gradients
     import super_gradients
     from pytorch_quantization import nn as quant_nn
     from pytorch_quantization import nn as quant_nn
@@ -787,7 +789,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         if Bottleneck in sq.mapping_instructions:
         if Bottleneck in sq.mapping_instructions:
             sq.mapping_instructions.pop(Bottleneck)
             sq.mapping_instructions.pop(Bottleneck)
 
 
-        resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_sg: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
 
         # PYTORCH-QUANTIZATION
         # PYTORCH-QUANTIZATION
@@ -803,7 +805,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
 
         quant_modules.initialize()
         quant_modules.initialize()
-        resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_pyquant: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
 
 
         quant_modules.deactivate()
         quant_modules.deactivate()
 
 
Discard
@@ -3,6 +3,7 @@ import os
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.common.object_names import Models
 
 
 
 
 class SaveCkptListUnitTest(unittest.TestCase):
 class SaveCkptListUnitTest(unittest.TestCase):
@@ -31,7 +32,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
         trainer = Trainer("save_ckpt_test")
         trainer = Trainer("save_ckpt_test")
 
 
         # Build Model
         # Build Model
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 10})
 
 
         # Train Model (and save ckpt_epoch_list)
         # Train Model (and save ckpt_epoch_list)
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
Discard
@@ -3,6 +3,7 @@ import shutil
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 import torch
 import torch
@@ -98,8 +99,8 @@ class StrictLoadEnumTest(unittest.TestCase):
 
 
     def test_strict_load_on(self):
     def test_strict_load_on(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -107,15 +108,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_off(self):
     def test_strict_load_off(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -125,17 +126,17 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
         del model.linear
         del model.linear
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_no_key_matching_sg_checkpoint(self):
     def test_strict_load_no_key_matching_sg_checkpoint(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -144,9 +145,9 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
Discard
@@ -8,6 +8,7 @@ from super_gradients.training import models
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
+from super_gradients.common.object_names import Models
 
 
 
 
 class TestWithoutTrainTest(unittest.TestCase):
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,20 +27,20 @@ class TestWithoutTrainTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=""):
     def get_detection_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("yolox_s", num_classes=5)
+        model = models.get(Models.YOLOX_S, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=""):
     def get_segmentation_trainer(name=""):
         shelfnet_lw_arch_params = {"num_classes": 5}
         shelfnet_lw_arch_params = {"num_classes": 5}
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("shelfnet34_lw", arch_params=shelfnet_lw_arch_params)
+        model = models.get(Models.SHELFNET34_LW, arch_params=shelfnet_lw_arch_params)
         return trainer, model
         return trainer, model
 
 
     def test_test_without_train(self):
     def test_test_without_train(self):
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test")
         trainer = Trainer("test_call_train_after_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -42,7 +43,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test_with_loss")
         trainer = Trainer("test_call_train_after_test_with_loss")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
 from torchmetrics import F1Score
 from torchmetrics import F1Score
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
@@ -22,7 +23,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -45,7 +46,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_optimizer_test")
         trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=0.1)
         optimizer = SGD(params=model.parameters(), lr=0.1)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
@@ -70,7 +71,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
@@ -95,7 +96,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_scheduler_test")
         trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD  # a class - not an instance
         optimizer = SGD  # a class - not an instance
 
 
         train_params = {
         train_params = {
@@ -117,7 +118,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
@@ -142,7 +143,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_metric_test")
         trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -172,7 +173,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -31,7 +32,7 @@ class TestViT(unittest.TestCase):
         Validate vit_base
         Validate vit_base
         """
         """
         trainer = Trainer("test_vit_base")
         trainer = Trainer("test_vit_base")
-        model = models.get("vit_base", arch_params=self.arch_params, num_classes=5)
+        model = models.get(Models.VIT_BASE, arch_params=self.arch_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
Discard