Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#614 Feature/sg 493 modelnames instead of strings

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-493_modelnames_instead_of_strings
@@ -5,6 +5,7 @@ import unittest
 import pkg_resources
 import pkg_resources
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
@@ -215,7 +216,7 @@ class ConfigInspectTest(unittest.TestCase):
 
 
     def test_resnet18_cifar_arch_params(self):
     def test_resnet18_cifar_arch_params(self):
         arch_params = get_arch_params("resnet18_cifar_arch_params")
         arch_params = get_arch_params("resnet18_cifar_arch_params")
-        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture("resnet18", HpmStruct(**arch_params))
+        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(Models.RESNET18, HpmStruct(**arch_params))
 
 
         with raise_if_unused_params(arch_params) as tracked_arch_params:
         with raise_if_unused_params(arch_params) as tracked_arch_params:
             _ = architecture_cls(arch_params=tracked_arch_params)
             _ = architecture_cls(arch_params=tracked_arch_params)
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 
 
@@ -19,7 +20,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
 
 
         trainer = Trainer("dataset_statistics_visual_test")
         trainer = Trainer("dataset_statistics_visual_test")
 
 
-        model = models.get("yolox_s")
+        model = models.get(Models.YOLOX_S)
 
 
         training_params = {
         training_params = {
             "max_epochs": 1,  # we dont really need the actual training to run
             "max_epochs": 1,  # we dont really need the actual training to run
Discard
@@ -4,6 +4,7 @@ import unittest
 import numpy as np
 import numpy as np
 import torch.cuda
 import torch.cuda
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
@@ -16,7 +17,7 @@ from tests.core_test_utils import is_data_available
 class TestDetectionUtils(unittest.TestCase):
 class TestDetectionUtils(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = models.get("yolox_n", pretrained_weights="coco").to(self.device)
+        self.model = models.get(Models.YOLOX_N, pretrained_weights="coco").to(self.device)
         self.model.eval()
         self.model.eval()
 
 
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainTwiceTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,6 +1,7 @@
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from torchvision.transforms import Compose, Normalize, Resize
 from torchvision.transforms import Compose, Normalize, Resize
 from super_gradients.training.transforms import Standardize
 from super_gradients.training.transforms import Standardize
@@ -9,7 +10,7 @@ import os
 
 
 class TestModelsONNXExport(unittest.TestCase):
 class TestModelsONNXExport(unittest.TestCase):
     def test_models_onnx_export(self):
     def test_models_onnx_export(self):
-        pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
+        pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         with tempfile.TemporaryDirectory() as tmpdirname:
         with tempfile.TemporaryDirectory() as tmpdirname:
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
Discard
@@ -5,6 +5,7 @@ import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
@@ -15,7 +16,7 @@ from torch import nn
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
     def test_training_with_factories(self):
     def test_training_with_factories(self):
         trainer = Trainer("test_train_with_factories")
         trainer = Trainer("test_train_with_factories")
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,4 +1,6 @@
 import unittest
 import unittest
+
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -28,7 +30,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
     def test_resizing_with_forward_pass_prep_fn(self):
     def test_resizing_with_forward_pass_prep_fn(self):
         # Define Model
         # Define Model
         trainer = Trainer("ForwardpassPrepFNTest")
         trainer = Trainer("ForwardpassPrepFNTest")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
 
 
         sizes = []
         sizes = []
         phase_callbacks = [TestInputSizesCallback(sizes)]
         phase_callbacks = [TestInputSizesCallback(sizes)]
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -27,7 +28,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
 
 
     def test_train_with_dataloaders(self):
     def test_train_with_dataloaders(self):
         trainer = Trainer(experiment_name="test_name")
         trainer = Trainer(experiment_name="test_name")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
Discard
@@ -8,6 +8,7 @@ import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
+from super_gradients.common.object_names import Models
 
 
 
 
 class KDEMATest(unittest.TestCase):
 class KDEMATest(unittest.TestCase):
@@ -39,8 +40,8 @@ class KDEMATest(unittest.TestCase):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
 
 
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -59,8 +60,8 @@ class KDEMATest(unittest.TestCase):
         # Create a KD trainer and train it
         # Create a KD trainer and train it
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -74,8 +75,8 @@ class KDEMATest(unittest.TestCase):
 
 
         # Load the trained KD trainer
         # Load the trained KD trainer
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         train_params["resume"] = True
         train_params["resume"] = True
         kd_model.train(
         kd_model.train(
Discard
@@ -2,6 +2,7 @@ import os
 import unittest
 import unittest
 from copy import deepcopy
 from copy import deepcopy
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 import torch
@@ -57,8 +58,8 @@ class KDTrainerTest(unittest.TestCase):
         }
         }
 
 
     def test_teacher_sg_module_methods(self):
     def test_teacher_sg_module_methods(self):
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
 
 
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
@@ -87,8 +88,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_train_model_with_input_adapter(self):
     def test_train_model_with_input_adapter(self):
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         adapter = NormalizationAdapter(
         adapter = NormalizationAdapter(
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
@@ -108,8 +109,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_load_ckpt_best_for_student(self):
     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -121,14 +122,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
 
 
     def test_load_ckpt_best_for_student_with_ema(self):
     def test_load_ckpt_best_for_student_with_ema(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
@@ -141,14 +142,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
 
 
     def test_resume_kd_training(self):
     def test_resume_kd_training(self):
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -161,8 +162,8 @@ class KDTrainerTest(unittest.TestCase):
         latest_net = deepcopy(kd_trainer.net)
         latest_net = deepcopy(kd_trainer.net)
 
 
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         train_params["max_epochs"] = 2
         train_params["max_epochs"] = 2
         train_params["resume"] = True
         train_params["resume"] = True
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -27,12 +28,12 @@ class LocalCkptHeadReplacementTest(unittest.TestCase):
         }
         }
 
 
         # Define Model
         # Define Model
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         trainer = Trainer("test_resume_training")
         trainer = Trainer("test_resume_training")
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
 
 
-        net2 = models.get("resnet18", num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
+        net2 = models.get(Models.RESNET18, num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
         self.assertFalse(check_models_have_same_weights(net, net2))
         self.assertFalse(check_models_have_same_weights(net, net2))
 
 
         net.linear = None
         net.linear = None
Discard
@@ -4,6 +4,7 @@ from torch import Tensor
 from torchmetrics import Accuracy
 from torchmetrics import Accuracy
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 
 
@@ -29,7 +30,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -53,7 +54,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -77,7 +78,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 import super_gradients
 import super_gradients
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -11,21 +12,21 @@ import shutil
 class PretrainedModelsUnitTest(unittest.TestCase):
 class PretrainedModelsUnitTest(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
+        self.imagenet_pretrained_models = [Models.RESNET50, "repvgg_a0", "regnetY800"]
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
-        model = models.get("resnet50", pretrained_weights="imagenet")
+        model = models.get(Models.RESNET50, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
-        model = models.get("regnetY800", pretrained_weights="imagenet")
+        model = models.get(Models.REGNETY800, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
-        model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
+        model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
Discard
@@ -3,6 +3,8 @@ import torch
 import torchvision
 import torchvision
 from torch import nn
 from torch import nn
 
 
+from super_gradients.common.object_names import Models
+
 try:
 try:
     import super_gradients
     import super_gradients
     from pytorch_quantization import nn as quant_nn
     from pytorch_quantization import nn as quant_nn
@@ -787,7 +789,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         if Bottleneck in sq.mapping_instructions:
         if Bottleneck in sq.mapping_instructions:
             sq.mapping_instructions.pop(Bottleneck)
             sq.mapping_instructions.pop(Bottleneck)
 
 
-        resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_sg: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
 
         # PYTORCH-QUANTIZATION
         # PYTORCH-QUANTIZATION
@@ -803,7 +805,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
 
         quant_modules.initialize()
         quant_modules.initialize()
-        resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_pyquant: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
 
 
         quant_modules.deactivate()
         quant_modules.deactivate()
 
 
Discard
@@ -3,6 +3,7 @@ import os
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.common.object_names import Models
 
 
 
 
 class SaveCkptListUnitTest(unittest.TestCase):
 class SaveCkptListUnitTest(unittest.TestCase):
@@ -31,7 +32,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
         trainer = Trainer("save_ckpt_test")
         trainer = Trainer("save_ckpt_test")
 
 
         # Build Model
         # Build Model
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 10})
 
 
         # Train Model (and save ckpt_epoch_list)
         # Train Model (and save ckpt_epoch_list)
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
Discard
@@ -3,6 +3,7 @@ import shutil
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 import torch
 import torch
@@ -98,8 +99,8 @@ class StrictLoadEnumTest(unittest.TestCase):
 
 
     def test_strict_load_on(self):
     def test_strict_load_on(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -107,15 +108,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_off(self):
     def test_strict_load_off(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -125,17 +126,17 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
         del model.linear
         del model.linear
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_no_key_matching_sg_checkpoint(self):
     def test_strict_load_no_key_matching_sg_checkpoint(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -144,9 +145,9 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
Discard
@@ -8,6 +8,7 @@ from super_gradients.training import models
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
+from super_gradients.common.object_names import Models
 
 
 
 
 class TestWithoutTrainTest(unittest.TestCase):
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,20 +27,20 @@ class TestWithoutTrainTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=""):
     def get_detection_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("yolox_s", num_classes=5)
+        model = models.get(Models.YOLOX_S, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=""):
     def get_segmentation_trainer(name=""):
         shelfnet_lw_arch_params = {"num_classes": 5}
         shelfnet_lw_arch_params = {"num_classes": 5}
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("shelfnet34_lw", arch_params=shelfnet_lw_arch_params)
+        model = models.get(Models.SHELFNET34_LW, arch_params=shelfnet_lw_arch_params)
         return trainer, model
         return trainer, model
 
 
     def test_test_without_train(self):
     def test_test_without_train(self):
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test")
         trainer = Trainer("test_call_train_after_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -42,7 +43,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test_with_loss")
         trainer = Trainer("test_call_train_after_test_with_loss")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
 from torchmetrics import F1Score
 from torchmetrics import F1Score
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
@@ -22,7 +23,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -45,7 +46,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_optimizer_test")
         trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=0.1)
         optimizer = SGD(params=model.parameters(), lr=0.1)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
@@ -70,7 +71,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
@@ -95,7 +96,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_scheduler_test")
         trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD  # a class - not an instance
         optimizer = SGD  # a class - not an instance
 
 
         train_params = {
         train_params = {
@@ -117,7 +118,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
@@ -142,7 +143,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_metric_test")
         trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -172,7 +173,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -31,7 +32,7 @@ class TestViT(unittest.TestCase):
         Validate vit_base
         Validate vit_base
         """
         """
         trainer = Trainer("test_vit_base")
         trainer = Trainer("test_vit_base")
-        model = models.get("vit_base", arch_params=self.arch_params, num_classes=5)
+        model = models.get(Models.VIT_BASE, arch_params=self.arch_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
Discard