Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#614 Feature/sg 493 modelnames instead of strings

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-493_modelnames_instead_of_strings
35 changed files with 403 additions and 306 deletions
  1. 7
    2
      README.md
  2. 8
    2
      docs/_sources/welcome.md.txt
  3. 2
    1
      documentation/source/welcome.md
  4. 46
    41
      src/super_gradients/examples/ddrnet_imagenet/ddrnet_classification_example.py
  5. 20
    9
      src/super_gradients/examples/early_stop/early_stop_example.py
  6. 2
    1
      src/super_gradients/examples/loggers_examples/clearml_logger_example.py
  7. 3
    1
      src/super_gradients/examples/loggers_examples/deci_platform_logger_example.py
  8. 2
    1
      src/super_gradients/examples/quantization/resnet_qat_example.py
  9. 41
    38
      src/super_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py
  10. 3
    1
      src/super_gradients/training/Computer_Vision_Models_Pretrained_Checkpoints.md
  11. 3
    2
      tests/end_to_end_tests/cifar_trainer_test.py
  12. 2
    1
      tests/end_to_end_tests/trainer_test.py
  13. 2
    1
      tests/integration_tests/ema_train_integration_test.py
  14. 2
    1
      tests/integration_tests/lr_test.py
  15. 174
    140
      tests/integration_tests/pretrained_models_test.py
  16. 2
    1
      tests/unit_tests/config_inspector_test.py
  17. 2
    1
      tests/unit_tests/dataset_statistics_test.py
  18. 2
    1
      tests/unit_tests/detection_utils_test.py
  19. 2
    1
      tests/unit_tests/double_training_test.py
  20. 2
    1
      tests/unit_tests/export_onnx_test.py
  21. 2
    1
      tests/unit_tests/factories_test.py
  22. 3
    1
      tests/unit_tests/forward_pass_prep_fn_test.py
  23. 2
    1
      tests/unit_tests/initialize_with_dataloaders_test.py
  24. 7
    6
      tests/unit_tests/kd_ema_test.py
  25. 15
    14
      tests/unit_tests/kd_trainer_test.py
  26. 3
    2
      tests/unit_tests/local_ckpt_head_replacement_test.py
  27. 4
    3
      tests/unit_tests/loss_loggings_test.py
  28. 5
    4
      tests/unit_tests/pretrained_models_unit_test.py
  29. 4
    2
      tests/unit_tests/quantization_utility_tests.py
  30. 2
    1
      tests/unit_tests/save_ckpt_test.py
  31. 12
    11
      tests/unit_tests/strictload_enum_test.py
  32. 4
    3
      tests/unit_tests/test_without_train_test.py
  33. 3
    2
      tests/unit_tests/train_after_test_test.py
  34. 8
    7
      tests/unit_tests/train_with_intialized_param_args_test.py
  35. 2
    1
      tests/unit_tests/vit_unit_test.py
@@ -44,7 +44,9 @@ ________________________________________________________________________________
 ```python
 ```python
 # Load model with pretrained weights
 # Load model with pretrained weights
 from super_gradients.training import models
 from super_gradients.training import models
-model = models.get("yolox_s", pretrained_weights="coco")
+from super_gradients.common.object_names import Models
+
+model = models.get(Models.YOLOX_S, pretrained_weights="coco")
 ```
 ```
 #### All Computer Vision Models - Pretrained Checkpoints can be found in the [Model Zoo](http://bit.ly/3EGfKD4)
 #### All Computer Vision Models - Pretrained Checkpoints can be found in the [Model Zoo](http://bit.ly/3EGfKD4)
 
 
@@ -81,7 +83,10 @@ More example on how and why to use recipes can be found in [Recipes](#recipes)
 All SuperGradients models’ are production ready in the sense that they are compatible with deployment tools such as TensorRT (Nvidia) and OpenVINO (Intel) and can be easily taken into production. With a few lines of code you can easily integrate the models into your codebase.
 All SuperGradients models’ are production ready in the sense that they are compatible with deployment tools such as TensorRT (Nvidia) and OpenVINO (Intel) and can be easily taken into production. With a few lines of code you can easily integrate the models into your codebase.
 ```python
 ```python
 # Load model with pretrained weights
 # Load model with pretrained weights
-model = models.get("yolox_s", pretrained_weights="coco")
+from super_gradients.training import models
+from super_gradients.common.object_names import Models
+
+model = models.get(Models.YOLOX_S, pretrained_weights="coco")
 
 
 # Prepare model for conversion
 # Prepare model for conversion
 # Input size is in format of [Batch x Channels x Width x Height] where 640 is the standart COCO dataset dimensions
 # Input size is in format of [Batch x Channels x Width x Height] where 640 is the standart COCO dataset dimensions
Discard
@@ -46,7 +46,10 @@ ________________________________________________________________________________
 ### Ready to deploy pre-trained SOTA models
 ### Ready to deploy pre-trained SOTA models
 ```python
 ```python
 # Load model with pretrained weights
 # Load model with pretrained weights
-model = models.get("yolox_s", pretrained_weights="coco")
+from super_gradients.common.object_names import Models
+from super_gradients.training import models
+
+model = models.get(Models.YOLOX_S, pretrained_weights="coco")
 ```
 ```
 
 
 #### Classification
 #### Classification
@@ -86,7 +89,10 @@ More example on how and why to use recipes can be found in [Recipes](#recipes)
 All SuperGradients models’ are production ready in the sense that they are compatible with deployment tools such as TensorRT (Nvidia) and OpenVINO (Intel) and can be easily taken into production. With a few lines of code you can easily integrate the models into your codebase.
 All SuperGradients models’ are production ready in the sense that they are compatible with deployment tools such as TensorRT (Nvidia) and OpenVINO (Intel) and can be easily taken into production. With a few lines of code you can easily integrate the models into your codebase.
 ```python
 ```python
 # Load model with pretrained weights
 # Load model with pretrained weights
-model = models.get("yolox_s", pretrained_weights="coco")
+from super_gradients.training import models
+from super_gradients.common.object_names import Models
+
+model = models.get(Models.YOLOX_S, pretrained_weights="coco")
 
 
 # Prepare model for conversion
 # Prepare model for conversion
 # Input size is in format of [Batch x Channels x Width x Height] where 640 is the standart COCO dataset dimensions
 # Input size is in format of [Batch x Channels x Width x Height] where 640 is the standart COCO dataset dimensions
Discard
@@ -131,9 +131,10 @@ Want to try our pre-trained models on your machine? Import SuperGradients, initi
     
     
 import super_gradients
 import super_gradients
 from super_gradients.training import Trainer, models, dataloaders
 from super_gradients.training import Trainer, models, dataloaders
+from super_gradients.common.object_names import Models
 
 
 trainer = Trainer(experiment_name="yoloxn_coco_experiment",ckpt_root_dir=<CHECKPOINT_DIRECTORY>)
 trainer = Trainer(experiment_name="yoloxn_coco_experiment",ckpt_root_dir=<CHECKPOINT_DIRECTORY>)
-model = models.get("yolox_n", pretrained_weights="coco", num_classes= 80)
+model = models.get(Models.YOLOX_N, pretrained_weights="coco", num_classes= 80)
 train_loader = dataloaders.coco2017_train()
 train_loader = dataloaders.coco2017_train()
 valid_loader = dataloaders.coco2017_val()
 valid_loader = dataloaders.coco2017_val()
 train_params = {...}
 train_params = {...}
Discard
@@ -15,6 +15,7 @@ Paper:              https://arxiv.org/pdf/2101.06085.pdf
 import torch
 import torch
 
 
 from super_gradients.common import MultiGPUMode
 from super_gradients.common import MultiGPUMode
+from super_gradients.common.object_names import Models
 from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation
 from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation
 from torchvision.transforms import RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
 from torchvision.transforms import RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
 import super_gradients
 import super_gradients
@@ -22,6 +23,7 @@ from super_gradients.training import Trainer, models, dataloaders
 import argparse
 import argparse
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.datasets.data_augmentation import RandomErase
 from super_gradients.training.datasets.data_augmentation import RandomErase
+
 parser = argparse.ArgumentParser()
 parser = argparse.ArgumentParser()
 super_gradients.init_trainer()
 super_gradients.init_trainer()
 
 
@@ -29,53 +31,56 @@ parser.add_argument("--reload", action="store_true")
 parser.add_argument("--max_epochs", type=int, default=100)
 parser.add_argument("--max_epochs", type=int, default=100)
 parser.add_argument("--batch", type=int, default=3)
 parser.add_argument("--batch", type=int, default=3)
 parser.add_argument("--experiment_name", type=str, default="ddrnet_23")
 parser.add_argument("--experiment_name", type=str, default="ddrnet_23")
-parser.add_argument("-s", "--slim", action="store_true", help='train the slim version of DDRNet23')
+parser.add_argument("-s", "--slim", action="store_true", help="train the slim version of DDRNet23")
 
 
 args, _ = parser.parse_known_args()
 args, _ = parser.parse_known_args()
 distributed = super_gradients.is_distributed()
 distributed = super_gradients.is_distributed()
 devices = torch.cuda.device_count() if not distributed else 1
 devices = torch.cuda.device_count() if not distributed else 1
 
 
-train_params_ddr = {"max_epochs": args.max_epochs,
-                    "lr_mode": "step",
-                    "lr_updates": [30, 60, 90],
-                    "lr_decay_factor": 0.1,
-                    "initial_lr": 0.1 * devices,
-                    "optimizer": "SGD",
-                    "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9, "nesterov": True},
-                    "loss": "cross_entropy",
-                    "train_metrics_list": [Accuracy(), Top5()],
-                    "valid_metrics_list": [Accuracy(), Top5()],
-
-                    "metric_to_watch": "Accuracy",
-                    "greater_metric_to_watch_is_better": True
-                    }
-
-dataset_params = {"batch_size": args.batch,
-                  "color_jitter": 0.4,
-                  "random_erase_prob": 0.2,
-                  "random_erase_value": 'random',
-                  "train_interpolation": 'random',
-                  }
-
-
-train_transforms = [RandomResizedCropAndInterpolation(size=224, interpolation="random"),
-                    RandomHorizontalFlip(),
-                    ColorJitter(0.4, 0.4, 0.4),
-                    ToTensor(),
-                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-                    RandomErase(0.2, "random")
-                    ]
-
-trainer = Trainer(experiment_name=args.experiment_name,
-                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
-                  device='cuda')
-
-train_loader = dataloaders.imagenet_train(dataset_params={"transforms": train_transforms},
-                                          dataloader_params={"batch_size": args.batch})
+train_params_ddr = {
+    "max_epochs": args.max_epochs,
+    "lr_mode": "step",
+    "lr_updates": [30, 60, 90],
+    "lr_decay_factor": 0.1,
+    "initial_lr": 0.1 * devices,
+    "optimizer": "SGD",
+    "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9, "nesterov": True},
+    "loss": "cross_entropy",
+    "train_metrics_list": [Accuracy(), Top5()],
+    "valid_metrics_list": [Accuracy(), Top5()],
+    "metric_to_watch": "Accuracy",
+    "greater_metric_to_watch_is_better": True,
+}
+
+dataset_params = {
+    "batch_size": args.batch,
+    "color_jitter": 0.4,
+    "random_erase_prob": 0.2,
+    "random_erase_value": "random",
+    "train_interpolation": "random",
+}
+
+
+train_transforms = [
+    RandomResizedCropAndInterpolation(size=224, interpolation="random"),
+    RandomHorizontalFlip(),
+    ColorJitter(0.4, 0.4, 0.4),
+    ToTensor(),
+    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    RandomErase(0.2, "random"),
+]
+
+trainer = Trainer(
+    experiment_name=args.experiment_name, multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL, device="cuda"
+)
+
+train_loader = dataloaders.imagenet_train(dataset_params={"transforms": train_transforms}, dataloader_params={"batch_size": args.batch})
 valid_loader = dataloaders.imagenet_val()
 valid_loader = dataloaders.imagenet_val()
 
 
-model = models.get("ddrnet_23_slim" if args.slim else "ddrnet_23",
-                   arch_params={"aux_head": False, "classification_mode": True, 'dropout_prob': 0.3},
-                   num_classes=1000)
+model = models.get(
+    Models.DDRNET_23_SLIM if args.slim else Models.DDRNET_23,
+    arch_params={"aux_head": False, "classification_mode": True, "dropout_prob": 0.3},
+    num_classes=1000,
+)
 
 
 trainer.train(model=model, training_params=train_params_ddr, train_loader=train_loader, valid_loader=valid_loader)
 trainer.train(model=model, training_params=train_params_ddr, train_loader=train_loader, valid_loader=valid_loader)
Discard
@@ -2,6 +2,7 @@
 # Reaches ~94.9 Accuracy after 250 Epochs
 # Reaches ~94.9 Accuracy after 250 Epochs
 import super_gradients
 import super_gradients
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models, dataloaders
 from super_gradients.training import models, dataloaders
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.utils.early_stopping import EarlyStop
 from super_gradients.training.utils.early_stopping import EarlyStop
@@ -13,18 +14,28 @@ super_gradients.init_trainer()
 early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Accuracy", mode="max", patience=3, verbose=True)
 early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Accuracy", mode="max", patience=3, verbose=True)
 early_stop_val_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LabelSmoothingCrossEntropyLoss", mode="min", patience=3, verbose=True)
 early_stop_val_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LabelSmoothingCrossEntropyLoss", mode="min", patience=3, verbose=True)
 
 
-train_params = {"max_epochs": 250, "lr_updates": [100, 150, 200], "lr_decay_factor": 0.1, "lr_mode": "step",
-                "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
-                "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
-                "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
-                "metric_to_watch": "Accuracy",
-                "greater_metric_to_watch_is_better": True, "phase_callbacks": [early_stop_acc, early_stop_val_loss]}
+train_params = {
+    "max_epochs": 250,
+    "lr_updates": [100, 150, 200],
+    "lr_decay_factor": 0.1,
+    "lr_mode": "step",
+    "lr_warmup_epochs": 0,
+    "initial_lr": 0.1,
+    "loss": "cross_entropy",
+    "optimizer": "SGD",
+    "criterion_params": {},
+    "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
+    "train_metrics_list": [Accuracy(), Top5()],
+    "valid_metrics_list": [Accuracy(), Top5()],
+    "metric_to_watch": "Accuracy",
+    "greater_metric_to_watch_is_better": True,
+    "phase_callbacks": [early_stop_acc, early_stop_val_loss],
+}
 
 
 # Define Model
 # Define Model
 trainer = Trainer("Callback_Example")
 trainer = Trainer("Callback_Example")
 
 
 # Build Model
 # Build Model
-model = models.get("resnet18_cifar", num_classes=10)
+model = models.get(Models.RESNET18_CIFAR, num_classes=10)
 
 
-trainer.train(model=model, training_params=train_params,
-              train_loader=dataloaders.cifar10_train(), valid_loader=dataloaders.cifar10_val())
+trainer.train(model=model, training_params=train_params, train_loader=dataloaders.cifar10_train(), valid_loader=dataloaders.cifar10_val())
Discard
@@ -1,10 +1,11 @@
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
 from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
 
 
 
 
 trainer = Trainer(experiment_name="demo-clearml-logger")
 trainer = Trainer(experiment_name="demo-clearml-logger")
-model = models.get("resnet18", num_classes=10)
+model = models.get(Models.RESNET18, num_classes=10)
 
 
 training_params = {
 training_params = {
     "max_epochs": 20,
     "max_epochs": 20,
Discard
@@ -1,4 +1,6 @@
 import os
 import os
+
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
 from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
@@ -7,7 +9,7 @@ os.environ["DECI_PLATFORM_TOKEN"] = "XXX"  # Replace XXX with your token
 
 
 
 
 trainer = Trainer(experiment_name="demo-deci-platform-logger")
 trainer = Trainer(experiment_name="demo-deci-platform-logger")
-model = models.get("resnet18", num_classes=10)
+model = models.get(Models.RESNET18, num_classes=10)
 training_params = {
 training_params = {
     "max_epochs": 20,
     "max_epochs": 20,
     "lr_updates": [5, 10, 15],
     "lr_updates": [5, 10, 15],
Discard
@@ -4,6 +4,7 @@ from torch import nn
 
 
 import super_gradients
 import super_gradients
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
 from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
 from super_gradients.training import MultiGPUMode
 from super_gradients.training import MultiGPUMode
 from super_gradients.training import models as sg_models
 from super_gradients.training import models as sg_models
@@ -55,7 +56,7 @@ def selective_quantize(model: nn.Module):
 
 
 
 
 def sg_vanilla_resnet50():
 def sg_vanilla_resnet50():
-    return sg_models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+    return sg_models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
 
 
 
 
 def sg_naive_qdq_resnet50():
 def sg_naive_qdq_resnet50():
Discard
@@ -1,20 +1,31 @@
+from super_gradients.common.object_names import Models
 from super_gradients.training import models, dataloaders
 from super_gradients.training import models, dataloaders
 
 
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.metrics import BinaryIOU
 from super_gradients.training.metrics import BinaryIOU
-from super_gradients.training.transforms.transforms import SegResize, SegRandomFlip, SegRandomRescale, SegCropImageAndMask, \
-    SegPadShortToCropSize, SegColorJitter
+from super_gradients.training.transforms.transforms import (
+    SegResize,
+    SegRandomFlip,
+    SegRandomRescale,
+    SegCropImageAndMask,
+    SegPadShortToCropSize,
+    SegColorJitter,
+)
 from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
 from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
 
 
 # DEFINE DATA TRANSFORMATIONS
 # DEFINE DATA TRANSFORMATIONS
 
 
 dl_train = dataloaders.supervisely_persons_train(
 dl_train = dataloaders.supervisely_persons_train(
-    dataset_params={"transforms": [SegColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
-                                   SegRandomFlip(),
-                                   SegRandomRescale(scales=[0.25, 1.]),
-                                   SegPadShortToCropSize([320, 480]),
-                                   SegCropImageAndMask(crop_size=[320, 480],
-                                                       mode="random")]})
+    dataset_params={
+        "transforms": [
+            SegColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
+            SegRandomFlip(),
+            SegRandomRescale(scales=[0.25, 1.0]),
+            SegPadShortToCropSize([320, 480]),
+            SegCropImageAndMask(crop_size=[320, 480], mode="random"),
+        ]
+    }
+)
 
 
 dl_val = dataloaders.supervisely_persons_val(dataset_params={"transforms": [SegResize(h=480, w=320)]})
 dl_val = dataloaders.supervisely_persons_val(dataset_params={"transforms": [SegResize(h=480, w=320)]})
 
 
@@ -23,35 +34,27 @@ trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_ep
 # THIS IS WHERE THE MAGIC HAPPENS- SINCE TRAINER'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
 # THIS IS WHERE THE MAGIC HAPPENS- SINCE TRAINER'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
 # TO OUR BINARY SEGMENTATION DATASET
 # TO OUR BINARY SEGMENTATION DATASET
-model = models.get("regseg48", pretrained_weights="cityscapes", num_classes=1)
+model = models.get(Models.REGSEG48, pretrained_weights="cityscapes", num_classes=1)
 
 
 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
-train_params = {"max_epochs": 50,
-                "lr_mode": "cosine",
-                "initial_lr": 0.0064,  # for batch_size=16
-                "optimizer_params": {"momentum": 0.843,
-                                     "weight_decay": 0.00036,
-                                     "nesterov": True},
-
-                "cosine_final_lr_ratio": 0.1,
-                "multiply_head_lr": 10,
-                "optimizer": "SGD",
-                "loss": "bce_dice_loss",
-                "ema": True,
-                "zero_weight_decay_on_bias_and_bn": True,
-                "average_best_models": True,
-                "mixed_precision": False,
-                "metric_to_watch": "mean_IOU",
-                "greater_metric_to_watch_is_better": True,
-                "train_metrics_list": [BinaryIOU()],
-                "valid_metrics_list": [BinaryIOU()],
-
-                "phase_callbacks": [BinarySegmentationVisualizationCallback(phase=Phase.VALIDATION_BATCH_END,
-                                                                            freq=1,
-                                                                            last_img_idx_in_batch=4)],
-                }
-
-trainer.train(model=model,
-              training_params=train_params,
-              train_loader=dl_train,
-              valid_loader=dl_val)
+train_params = {
+    "max_epochs": 50,
+    "lr_mode": "cosine",
+    "initial_lr": 0.0064,  # for batch_size=16
+    "optimizer_params": {"momentum": 0.843, "weight_decay": 0.00036, "nesterov": True},
+    "cosine_final_lr_ratio": 0.1,
+    "multiply_head_lr": 10,
+    "optimizer": "SGD",
+    "loss": "bce_dice_loss",
+    "ema": True,
+    "zero_weight_decay_on_bias_and_bn": True,
+    "average_best_models": True,
+    "mixed_precision": False,
+    "metric_to_watch": "mean_IOU",
+    "greater_metric_to_watch_is_better": True,
+    "train_metrics_list": [BinaryIOU()],
+    "valid_metrics_list": [BinaryIOU()],
+    "phase_callbacks": [BinarySegmentationVisualizationCallback(phase=Phase.VALIDATION_BATCH_END, freq=1, last_img_idx_in_batch=4)],
+}
+
+trainer.train(model=model, training_params=train_params, train_loader=dl_train, valid_loader=dl_val)
Discard
@@ -4,7 +4,9 @@
 You can load any of our pretrained model in 2 lines of code:
 You can load any of our pretrained model in 2 lines of code:
 ```python
 ```python
 from super_gradients.training import models
 from super_gradients.training import models
-model = models.get("yolox_s", pretrained_weights="coco")
+from super_gradients.common.object_names import Models
+
+model = models.get(Models.YOLOX_S, pretrained_weights="coco")
 ```
 ```
 
 
 All the available models are listed in the column `Model name`.
 All the available models are listed in the column `Model name`.
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 import super_gradients
 import super_gradients
@@ -18,7 +19,7 @@ class TestCifarTrainer(unittest.TestCase):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
         trainer = Trainer("test")
         trainer = Trainer("test")
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 10})
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
@@ -37,7 +38,7 @@ class TestCifarTrainer(unittest.TestCase):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
         trainer = Trainer("test")
         trainer = Trainer("test")
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 100})
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
Discard
@@ -1,6 +1,7 @@
 import shutil
 import shutil
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 import super_gradients
 import super_gradients
@@ -41,7 +42,7 @@ class TestTrainer(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     def test_train(self):
     def test_train(self):
Discard
@@ -1,3 +1,4 @@
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -22,7 +23,7 @@ class CallWrapper:
 class EMAIntegrationTest(unittest.TestCase):
 class EMAIntegrationTest(unittest.TestCase):
     def _init_model(self) -> None:
     def _init_model(self) -> None:
         self.trainer = Trainer("resnet18_cifar_ema_test")
         self.trainer = Trainer("resnet18_cifar_ema_test")
-        self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
+        self.model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 5})
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls) -> None:
     def tearDownClass(cls) -> None:
Discard
@@ -2,6 +2,7 @@ import shutil
 import unittest
 import unittest
 import os
 import os
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -34,7 +35,7 @@ class LRTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_trainer(name=""):
     def get_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18_cifar", num_classes=5)
+        model = models.get(Models.RESNET18_CIFAR, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     def test_function_lr(self):
     def test_function_lr(self):
Discard
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.dataloaders.dataloaders import (
@@ -35,19 +36,19 @@ class PretrainedModelsTest(unittest.TestCase):
         self.imagenet_pretrained_arch_params = {
         self.imagenet_pretrained_arch_params = {
             "resnet": {},
             "resnet": {},
             "regnet": {},
             "regnet": {},
-            "repvgg_a0": {"build_residual_branches": True},
-            "efficientnet_b0": {},
+            Models.REPVGG_A0: {"build_residual_branches": True},
+            Models.EFFICIENTNET_B0: {},
             "mobilenet": {},
             "mobilenet": {},
-            "vit_base": {"image_size": (224, 224), "patch_size": (16, 16)},
+            Models.VIT_BASE: {"image_size": (224, 224), "patch_size": (16, 16)},
         }
         }
 
 
         self.imagenet_pretrained_trainsfer_learning_arch_params = {
         self.imagenet_pretrained_trainsfer_learning_arch_params = {
             "resnet": {},
             "resnet": {},
             "regnet": {},
             "regnet": {},
-            "repvgg_a0": {"build_residual_branches": True},
-            "efficientnet_b0": {},
+            Models.REPVGG_A0: {"build_residual_branches": True},
+            Models.EFFICIENTNET_B0: {},
             "mobilenet": {},
             "mobilenet": {},
-            "vit_base": {"image_size": (224, 224), "patch_size": (16, 16)},
+            Models.VIT_BASE: {"image_size": (224, 224), "patch_size": (16, 16)},
         }
         }
 
 
         self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
         self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
@@ -55,21 +56,21 @@ class PretrainedModelsTest(unittest.TestCase):
         self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
         self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
 
 
         self.imagenet_pretrained_accuracies = {
         self.imagenet_pretrained_accuracies = {
-            "resnet50": 0.8191,
-            "resnet34": 0.7413,
-            "resnet18": 0.706,
-            "repvgg_a0": 0.7205,
-            "regnetY800": 0.7707,
-            "regnetY600": 0.7618,
-            "regnetY400": 0.7474,
-            "regnetY200": 0.7088,
-            "efficientnet_b0": 0.7762,
-            "mobilenet_v3_large": 0.7452,
-            "mobilenet_v3_small": 0.6745,
-            "mobilenet_v2": 0.7308,
-            "vit_base": 0.8415,
-            "vit_large": 0.8564,
-            "beit_base_patch16_224": 0.85,
+            Models.RESNET50: 0.8191,
+            Models.RESNET34: 0.7413,
+            Models.RESNET18: 0.706,
+            Models.REPVGG_A0: 0.7205,
+            Models.REGNETY800: 0.7707,
+            Models.REGNETY600: 0.7618,
+            Models.REGNETY400: 0.7474,
+            Models.REGNETY200: 0.7088,
+            Models.EFFICIENTNET_B0: 0.7762,
+            Models.MOBILENET_V3_LARGE: 0.7452,
+            Models.MOBILENET_V3_SMALL: 0.6745,
+            Models.MOBILENET_V2: 0.7308,
+            Models.VIT_BASE: 0.8415,
+            Models.VIT_LARGE: 0.8564,
+            Models.BEIT_BASE_PATCH16_224: 0.85,
         }
         }
         self.imagenet_dataset = imagenet_val(dataloader_params={"batch_size": 128})
         self.imagenet_dataset = imagenet_val(dataloader_params={"batch_size": 128})
 
 
@@ -91,7 +92,7 @@ class PretrainedModelsTest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
         }
         }
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
-        self.coco_pretrained_arch_params = {"ssd_lite_mobilenet_v2": {"num_classes": 80}, "coco_ssd_mobilenet_v1": {"num_classes": 80}}
+        self.coco_pretrained_arch_params = {Models.SSD_LITE_MOBILENET_V2: {"num_classes": 80}, Models.SSD_MOBILENET_V1: {"num_classes": 80}}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
 
 
         self.coco_dataset = {
         self.coco_dataset = {
@@ -102,13 +103,13 @@ class PretrainedModelsTest(unittest.TestCase):
         }
         }
 
 
         self.coco_pretrained_maps = {
         self.coco_pretrained_maps = {
-            "ssd_lite_mobilenet_v2": 0.2041,
-            "coco_ssd_mobilenet_v1": 0.243,
-            "yolox_s": 0.4047,
-            "yolox_m": 0.4640,
-            "yolox_l": 0.4925,
-            "yolox_n": 0.2677,
-            "yolox_t": 0.3718,
+            Models.SSD_LITE_MOBILENET_V2: 0.2041,
+            Models.SSD_MOBILENET_V1: 0.243,
+            Models.YOLOX_S: 0.4047,
+            Models.YOLOX_M: 0.4640,
+            Models.YOLOX_L: 0.4925,
+            Models.YOLOX_N: 0.2677,
+            Models.YOLOX_T: 0.3718,
         }
         }
 
 
         self.transfer_detection_dataset = detection_test_dataloader()
         self.transfer_detection_dataset = detection_test_dataloader()
@@ -151,27 +152,27 @@ class PretrainedModelsTest(unittest.TestCase):
         self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
         self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
         self.coco_segmentation_dataset = coco_segmentation_val()
         self.coco_segmentation_dataset = coco_segmentation_val()
 
 
-        self.cityscapes_pretrained_models = ["ddrnet_23", "ddrnet_23_slim", "stdc1_seg50", "regseg48"]
+        self.cityscapes_pretrained_models = [Models.DDRNET_23, Models.DDRNET_23_SLIM, Models.STDC1_SEG50, Models.REGSEG48]
         self.cityscapes_pretrained_arch_params = {
         self.cityscapes_pretrained_arch_params = {
-            "ddrnet_23": {"aux_head": True},
-            "regseg48": {},
+            Models.DDRNET_23: {"aux_head": True},
+            Models.REGSEG48: {},
             "stdc": {"use_aux_heads": True, "aux_head": True},
             "stdc": {"use_aux_heads": True, "aux_head": True},
             "pplite_seg": {"use_aux_heads": True},
             "pplite_seg": {"use_aux_heads": True},
         }
         }
 
 
         self.cityscapes_pretrained_ckpt_params = {"pretrained_weights": "cityscapes"}
         self.cityscapes_pretrained_ckpt_params = {"pretrained_weights": "cityscapes"}
         self.cityscapes_pretrained_mious = {
         self.cityscapes_pretrained_mious = {
-            "ddrnet_23": 0.8026,
-            "ddrnet_23_slim": 0.7801,
-            "stdc1_seg50": 0.7511,
-            "stdc1_seg75": 0.7687,
-            "stdc2_seg50": 0.7644,
-            "stdc2_seg75": 0.7893,
-            "regseg48": 0.7815,
-            "pp_lite_t_seg50": 0.7492,
-            "pp_lite_t_seg75": 0.7756,
-            "pp_lite_b_seg50": 0.7648,
-            "pp_lite_b_seg75": 0.7852,
+            Models.DDRNET_23: 0.8026,
+            Models.DDRNET_23_SLIM: 0.7801,
+            Models.STDC1_SEG50: 0.7511,
+            Models.STDC1_SEG75: 0.7687,
+            Models.STDC2_SEG50: 0.7644,
+            Models.STDC2_SEG75: 0.7893,
+            Models.REGSEG48: 0.7815,
+            Models.PP_LITE_T_SEG50: 0.7492,
+            Models.PP_LITE_T_SEG75: 0.7756,
+            Models.PP_LITE_B_SEG50: 0.7648,
+            Models.PP_LITE_B_SEG75: 0.7852,
         }
         }
 
 
         self.cityscapes_dataset = cityscapes_val()
         self.cityscapes_dataset = cityscapes_val()
@@ -229,13 +230,13 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50")
         trainer = Trainer("imagenet_pretrained_resnet50")
-        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET50, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET50], delta=0.001)
 
 
     def test_transfer_learning_resnet50_imagenet(self):
     def test_transfer_learning_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning")
-        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET50, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -246,13 +247,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_resnet34_imagenet(self):
     def test_pretrained_resnet34_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet34")
         trainer = Trainer("imagenet_pretrained_resnet34")
 
 
-        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET34, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET34], delta=0.001)
 
 
     def test_transfer_learning_resnet34_imagenet(self):
     def test_transfer_learning_resnet34_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning")
-        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET34, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -263,13 +264,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_resnet18_imagenet(self):
     def test_pretrained_resnet18_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet18")
         trainer = Trainer("imagenet_pretrained_resnet18")
 
 
-        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.RESNET18, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.RESNET18], delta=0.001)
 
 
     def test_transfer_learning_resnet18_imagenet(self):
     def test_transfer_learning_resnet18_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning")
         trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning")
-        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.RESNET18, arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -280,13 +281,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800")
         trainer = Trainer("imagenet_pretrained_regnetY800")
 
 
-        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY800, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY800], delta=0.001)
 
 
     def test_transfer_learning_regnetY800_imagenet(self):
     def test_transfer_learning_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning")
-        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY800, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -297,13 +298,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY600_imagenet(self):
     def test_pretrained_regnetY600_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY600")
         trainer = Trainer("imagenet_pretrained_regnetY600")
 
 
-        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY600, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY600], delta=0.001)
 
 
     def test_transfer_learning_regnetY600_imagenet(self):
     def test_transfer_learning_regnetY600_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning")
-        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY600, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -314,13 +315,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY400_imagenet(self):
     def test_pretrained_regnetY400_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY400")
         trainer = Trainer("imagenet_pretrained_regnetY400")
 
 
-        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY400, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY400], delta=0.001)
 
 
     def test_transfer_learning_regnetY400_imagenet(self):
     def test_transfer_learning_regnetY400_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning")
-        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY400, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -331,13 +332,13 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_regnetY200_imagenet(self):
     def test_pretrained_regnetY200_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY200")
         trainer = Trainer("imagenet_pretrained_regnetY200")
 
 
-        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REGNETY200, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REGNETY200], delta=0.001)
 
 
     def test_transfer_learning_regnetY200_imagenet(self):
     def test_transfer_learning_regnetY200_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning")
         trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning")
-        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.REGNETY200, arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -348,13 +349,15 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0")
         trainer = Trainer("imagenet_pretrained_repvgg_a0")
 
 
-        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.REPVGG_A0, arch_params=self.imagenet_pretrained_arch_params[Models.REPVGG_A0], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.REPVGG_A0], delta=0.001)
 
 
     def test_transfer_learning_repvgg_a0_imagenet(self):
     def test_transfer_learning_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning")
         trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning")
-        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.REPVGG_A0, arch_params=self.imagenet_pretrained_arch_params[Models.REPVGG_A0], **self.imagenet_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -364,7 +367,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_regseg48_cityscapes(self):
     def test_pretrained_regseg48_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_regseg48")
         trainer = Trainer("cityscapes_pretrained_regseg48")
-        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.REGSEG48, arch_params=self.cityscapes_pretrained_arch_params[Models.REGSEG48], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -372,11 +375,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.REGSEG48], delta=0.001)
 
 
     def test_transfer_learning_regseg48_cityscapes(self):
     def test_transfer_learning_regseg48_cityscapes(self):
         trainer = Trainer("regseg48_cityscapes_transfer_learning")
         trainer = Trainer("regseg48_cityscapes_transfer_learning")
-        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.REGSEG48, arch_params=self.cityscapes_pretrained_arch_params[Models.REGSEG48], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             train_loader=self.transfer_segmentation_dataset,
             train_loader=self.transfer_segmentation_dataset,
@@ -386,7 +389,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_ddrnet23_cityscapes(self):
     def test_pretrained_ddrnet23_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23")
         trainer = Trainer("cityscapes_pretrained_ddrnet23")
-        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.DDRNET_23, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -394,11 +397,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.DDRNET_23], delta=0.001)
 
 
     def test_pretrained_ddrnet23_slim_cityscapes(self):
     def test_pretrained_ddrnet23_slim_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim")
-        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(
+            Models.DDRNET_23_SLIM, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params
+        )
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
                 model=model, test_loader=self.cityscapes_dataset, test_metrics_list=[IoU(num_classes=20, ignore_index=19)], metrics_progress_verbose=True
@@ -406,11 +411,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.DDRNET_23_SLIM], delta=0.001)
 
 
     def test_transfer_learning_ddrnet23_cityscapes(self):
     def test_transfer_learning_ddrnet23_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning")
-        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.DDRNET_23, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.ddrnet_transfer_segmentation_train_params,
             training_params=self.ddrnet_transfer_segmentation_train_params,
@@ -420,7 +425,9 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning")
-        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(
+            Models.DDRNET_23_SLIM, arch_params=self.cityscapes_pretrained_arch_params[Models.DDRNET_23], **self.cityscapes_pretrained_ckpt_params
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.ddrnet_transfer_segmentation_train_params,
             training_params=self.ddrnet_transfer_segmentation_train_params,
@@ -441,15 +448,20 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_efficientnet_b0_imagenet(self):
     def test_pretrained_efficientnet_b0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_efficientnet_b0")
         trainer = Trainer("imagenet_pretrained_efficientnet_b0")
 
 
-        model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(
+            Models.EFFICIENTNET_B0, arch_params=self.imagenet_pretrained_arch_params[Models.EFFICIENTNET_B0], **self.imagenet_pretrained_ckpt_params
+        )
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.EFFICIENTNET_B0], delta=0.001)
 
 
     def test_transfer_learning_efficientnet_b0_imagenet(self):
     def test_transfer_learning_efficientnet_b0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning")
         trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.EFFICIENTNET_B0,
+            arch_params=self.imagenet_pretrained_arch_params[Models.EFFICIENTNET_B0],
+            **self.imagenet_pretrained_ckpt_params,
+            num_classes=5,
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -460,7 +472,9 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer("coco_ssd_lite_mobilenet_v2")
         trainer = Trainer("coco_ssd_lite_mobilenet_v2")
-        model = models.get("ssd_lite_mobilenet_v2", arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"], **self.coco_pretrained_ckpt_params)
+        model = models.get(
+            Models.SSD_LITE_MOBILENET_V2, arch_params=self.coco_pretrained_arch_params[Models.SSD_LITE_MOBILENET_V2], **self.coco_pretrained_ckpt_params
+        )
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -468,13 +482,13 @@ class PretrainedModelsTest(unittest.TestCase):
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             metrics_progress_verbose=True,
             metrics_progress_verbose=True,
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.SSD_LITE_MOBILENET_V2], delta=0.001)
 
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning")
         trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning")
-        transfer_arch_params = self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"].copy()
+        transfer_arch_params = self.coco_pretrained_arch_params[Models.SSD_LITE_MOBILENET_V2].copy()
         transfer_arch_params["num_classes"] = 5
         transfer_arch_params["num_classes"] = 5
-        model = models.get("ssd_lite_mobilenet_v2", arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.SSD_LITE_MOBILENET_V2, arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_detection_train_params_ssd,
             training_params=self.transfer_detection_train_params_ssd,
@@ -483,8 +497,8 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = Trainer("coco_ssd_mobilenet_v1")
-        model = models.get("ssd_mobilenet_v1", arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"], **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.SSD_MOBILENET_V1)
+        model = models.get(Models.SSD_MOBILENET_V1, arch_params=self.coco_pretrained_arch_params[Models.SSD_MOBILENET_V1], **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -492,63 +506,63 @@ class PretrainedModelsTest(unittest.TestCase):
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
             metrics_progress_verbose=True,
             metrics_progress_verbose=True,
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.SSD_MOBILENET_V1], delta=0.001)
 
 
     def test_pretrained_yolox_s_coco(self):
     def test_pretrained_yolox_s_coco(self):
-        trainer = Trainer("yolox_s")
+        trainer = Trainer(Models.YOLOX_S)
 
 
-        model = models.get("yolox_s", **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.YOLOX_S, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_S], delta=0.001)
 
 
     def test_pretrained_yolox_m_coco(self):
     def test_pretrained_yolox_m_coco(self):
-        trainer = Trainer("yolox_m")
-        model = models.get("yolox_m", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_M)
+        model = models.get(Models.YOLOX_M, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_M], delta=0.001)
 
 
     def test_pretrained_yolox_l_coco(self):
     def test_pretrained_yolox_l_coco(self):
-        trainer = Trainer("yolox_l")
-        model = models.get("yolox_l", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_L)
+        model = models.get(Models.YOLOX_L, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_L], delta=0.001)
 
 
     def test_pretrained_yolox_n_coco(self):
     def test_pretrained_yolox_n_coco(self):
-        trainer = Trainer("yolox_n")
+        trainer = Trainer(Models.YOLOX_N)
 
 
-        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
+        model = models.get(Models.YOLOX_N, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_N], delta=0.001)
 
 
     def test_pretrained_yolox_t_coco(self):
     def test_pretrained_yolox_t_coco(self):
-        trainer = Trainer("yolox_t")
-        model = models.get("yolox_t", **self.coco_pretrained_ckpt_params)
+        trainer = Trainer(Models.YOLOX_T)
+        model = models.get(Models.YOLOX_T, **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
             test_loader=self.coco_dataset["yolox"],
             test_loader=self.coco_dataset["yolox"],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
             test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=80, normalize_targets=True)],
         )[2]
         )[2]
-        self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_t"], delta=0.001)
+        self.assertAlmostEqual(res, self.coco_pretrained_maps[Models.YOLOX_T], delta=0.001)
 
 
     def test_transfer_learning_yolox_n_coco(self):
     def test_transfer_learning_yolox_n_coco(self):
         trainer = Trainer("test_transfer_learning_yolox_n_coco")
         trainer = Trainer("test_transfer_learning_yolox_n_coco")
-        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
+        model = models.get(Models.YOLOX_N, **self.coco_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_detection_train_params_yolox,
             training_params=self.transfer_detection_train_params_yolox,
@@ -560,7 +574,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.MOBILENET_V3_LARGE, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -572,15 +586,15 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v3_large")
         trainer = Trainer("imagenet_mobilenet_v3_large")
 
 
-        model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V3_LARGE, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V3_LARGE], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning")
 
 
         model = models.get(
         model = models.get(
-            "mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.MOBILENET_V3_SMALL, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -592,14 +606,16 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v3_small")
         trainer = Trainer("imagenet_mobilenet_v3_small")
 
 
-        model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V3_SMALL, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V3_SMALL], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v2_imagenet(self):
     def test_transfer_learning_mobilenet_v2_imagenet(self):
         trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning")
         trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning")
 
 
-        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.MOBILENET_V2, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -610,13 +626,14 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_pretrained_mobilenet_v2_imagenet(self):
     def test_pretrained_mobilenet_v2_imagenet(self):
         trainer = Trainer("imagenet_mobilenet_v2")
         trainer = Trainer("imagenet_mobilenet_v2")
 
 
-        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.MOBILENET_V2, arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.MOBILENET_V2], delta=0.001)
 
 
     def test_pretrained_stdc1_seg50_cityscapes(self):
     def test_pretrained_stdc1_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50")
-        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+
+        model = models.get(Models.STDC1_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -627,11 +644,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC1_SEG50], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning")
-        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC1_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -641,7 +660,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
     def test_pretrained_stdc1_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75")
-        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC1_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -652,11 +671,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC1_SEG75], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning")
-        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC1_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -666,7 +687,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
     def test_pretrained_stdc2_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50")
-        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC2_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -677,11 +698,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC2_SEG50], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning")
-        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC2_SEG50, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -691,7 +714,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
     def test_pretrained_stdc2_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75")
-        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.STDC2_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
                 model=model,
                 model=model,
@@ -702,11 +725,13 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.STDC2_SEG75], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning")
         trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning")
-        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.STDC2_SEG75, arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.stdc_transfer_segmentation_train_params,
             training_params=self.stdc_transfer_segmentation_train_params,
@@ -717,7 +742,9 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_transfer_learning_vit_base_imagenet21k(self):
     def test_transfer_learning_vit_base_imagenet21k(self):
         trainer = Trainer("imagenet21k_pretrained_vit_base")
         trainer = Trainer("imagenet21k_pretrained_vit_base")
 
 
-        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.VIT_BASE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet21k_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -728,7 +755,9 @@ class PretrainedModelsTest(unittest.TestCase):
     def test_transfer_learning_vit_large_imagenet21k(self):
     def test_transfer_learning_vit_large_imagenet21k(self):
         trainer = Trainer("imagenet21k_pretrained_vit_large")
         trainer = Trainer("imagenet21k_pretrained_vit_large")
 
 
-        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
+        model = models.get(
+            Models.VIT_LARGE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet21k_pretrained_ckpt_params, num_classes=5
+        )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params=self.transfer_classification_train_params,
             training_params=self.transfer_classification_train_params,
@@ -738,39 +767,44 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_vit_base_imagenet(self):
     def test_pretrained_vit_base_imagenet(self):
         trainer = Trainer("imagenet_pretrained_vit_base")
         trainer = Trainer("imagenet_pretrained_vit_base")
-        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.VIT_BASE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.VIT_BASE], delta=0.001)
 
 
     def test_pretrained_vit_large_imagenet(self):
     def test_pretrained_vit_large_imagenet(self):
         trainer = Trainer("imagenet_pretrained_vit_large")
         trainer = Trainer("imagenet_pretrained_vit_large")
-        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(Models.VIT_LARGE, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.VIT_LARGE], delta=0.001)
 
 
     def test_pretrained_beit_base_imagenet(self):
     def test_pretrained_beit_base_imagenet(self):
         trainer = Trainer("imagenet_pretrained_beit_base")
         trainer = Trainer("imagenet_pretrained_beit_base")
-        model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
+        model = models.get(
+            Models.BEIT_BASE_PATCH16_224, arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE], **self.imagenet_pretrained_ckpt_params
+        )
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
+        self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies[Models.BEIT_BASE_PATCH16_224], delta=0.001)
 
 
     def test_transfer_learning_beit_base_imagenet(self):
     def test_transfer_learning_beit_base_imagenet(self):
         trainer = Trainer("test_transfer_learning_beit_base_imagenet")
         trainer = Trainer("test_transfer_learning_beit_base_imagenet")
 
 
         model = models.get(
         model = models.get(
-            "beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params, num_classes=5
+            Models.BEIT_BASE_PATCH16_224,
+            arch_params=self.imagenet_pretrained_arch_params[Models.VIT_BASE],
+            **self.imagenet_pretrained_ckpt_params,
+            num_classes=5,
         )
         )
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -781,7 +815,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_pretrained_pplite_t_seg50_cityscapes(self):
     def test_pretrained_pplite_t_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg50")
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg50")
-        model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_T_SEG50, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -793,11 +827,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_T_SEG50], delta=0.001)
 
 
     def test_pretrained_pplite_t_seg75_cityscapes(self):
     def test_pretrained_pplite_t_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg75")
         trainer = Trainer("cityscapes_pretrained_pplite_t_seg75")
-        model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_T_SEG75, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -809,11 +843,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_T_SEG75], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg50_cityscapes(self):
     def test_pretrained_pplite_b_seg50_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg50")
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg50")
-        model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_B_SEG50, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -825,11 +859,11 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_B_SEG50], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg75_cityscapes(self):
     def test_pretrained_pplite_b_seg75_cityscapes(self):
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg75")
         trainer = Trainer("cityscapes_pretrained_pplite_b_seg75")
-        model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
+        model = models.get(Models.PP_LITE_B_SEG75, arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -841,7 +875,7 @@ class PretrainedModelsTest(unittest.TestCase):
             .cpu()
             .cpu()
             .item()
             .item()
         )
         )
-        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg75"], delta=0.001)
+        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious[Models.PP_LITE_B_SEG75], delta=0.001)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
         if os.path.exists("~/.cache/torch/hub/"):
Discard
@@ -5,6 +5,7 @@ import unittest
 import pkg_resources
 import pkg_resources
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models import SgModule, get_arch_params
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.models.model_factory import get_architecture
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
@@ -215,7 +216,7 @@ class ConfigInspectTest(unittest.TestCase):
 
 
     def test_resnet18_cifar_arch_params(self):
     def test_resnet18_cifar_arch_params(self):
         arch_params = get_arch_params("resnet18_cifar_arch_params")
         arch_params = get_arch_params("resnet18_cifar_arch_params")
-        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture("resnet18", HpmStruct(**arch_params))
+        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(Models.RESNET18, HpmStruct(**arch_params))
 
 
         with raise_if_unused_params(arch_params) as tracked_arch_params:
         with raise_if_unused_params(arch_params) as tracked_arch_params:
             _ = architecture_cls(arch_params=tracked_arch_params)
             _ = architecture_cls(arch_params=tracked_arch_params)
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 
 
@@ -19,7 +20,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
 
 
         trainer = Trainer("dataset_statistics_visual_test")
         trainer = Trainer("dataset_statistics_visual_test")
 
 
-        model = models.get("yolox_s")
+        model = models.get(Models.YOLOX_S)
 
 
         training_params = {
         training_params = {
             "max_epochs": 1,  # we dont really need the actual training to run
             "max_epochs": 1,  # we dont really need the actual training to run
Discard
@@ -4,6 +4,7 @@ import unittest
 import numpy as np
 import numpy as np
 import torch.cuda
 import torch.cuda
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.dataloaders.dataloaders import coco2017_val
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
@@ -16,7 +17,7 @@ from tests.core_test_utils import is_data_available
 class TestDetectionUtils(unittest.TestCase):
 class TestDetectionUtils(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = models.get("yolox_n", pretrained_weights="coco").to(self.device)
+        self.model = models.get(Models.YOLOX_N, pretrained_weights="coco").to(self.device)
         self.model.eval()
         self.model.eval()
 
 
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainTwiceTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,6 +1,7 @@
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from torchvision.transforms import Compose, Normalize, Resize
 from torchvision.transforms import Compose, Normalize, Resize
 from super_gradients.training.transforms import Standardize
 from super_gradients.training.transforms import Standardize
@@ -9,7 +10,7 @@ import os
 
 
 class TestModelsONNXExport(unittest.TestCase):
 class TestModelsONNXExport(unittest.TestCase):
     def test_models_onnx_export(self):
     def test_models_onnx_export(self):
-        pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
+        pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
         with tempfile.TemporaryDirectory() as tmpdirname:
         with tempfile.TemporaryDirectory() as tmpdirname:
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
             out_path = os.path.join(tmpdirname, "resnet18.onnx")
Discard
@@ -5,6 +5,7 @@ import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss
@@ -15,7 +16,7 @@ from torch import nn
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
     def test_training_with_factories(self):
     def test_training_with_factories(self):
         trainer = Trainer("test_train_with_factories")
         trainer = Trainer("test_train_with_factories")
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,4 +1,6 @@
 import unittest
 import unittest
+
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -28,7 +30,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
     def test_resizing_with_forward_pass_prep_fn(self):
     def test_resizing_with_forward_pass_prep_fn(self):
         # Define Model
         # Define Model
         trainer = Trainer("ForwardpassPrepFNTest")
         trainer = Trainer("ForwardpassPrepFNTest")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
 
 
         sizes = []
         sizes = []
         phase_callbacks = [TestInputSizesCallback(sizes)]
         phase_callbacks = [TestInputSizesCallback(sizes)]
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -27,7 +28,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
 
 
     def test_train_with_dataloaders(self):
     def test_train_with_dataloaders(self):
         trainer = Trainer(experiment_name="test_name")
         trainer = Trainer(experiment_name="test_name")
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
             training_params={
             training_params={
Discard
@@ -8,6 +8,7 @@ import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
+from super_gradients.common.object_names import Models
 
 
 
 
 class KDEMATest(unittest.TestCase):
 class KDEMATest(unittest.TestCase):
@@ -39,8 +40,8 @@ class KDEMATest(unittest.TestCase):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
 
 
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
         kd_model = KDTrainer("test_teacher_ema_not_duplicated")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -59,8 +60,8 @@ class KDEMATest(unittest.TestCase):
         # Create a KD trainer and train it
         # Create a KD trainer and train it
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         kd_model.train(
         kd_model.train(
             training_params=self.kd_train_params,
             training_params=self.kd_train_params,
@@ -74,8 +75,8 @@ class KDEMATest(unittest.TestCase):
 
 
         # Load the trained KD trainer
         # Load the trained KD trainer
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         kd_model = KDTrainer("test_kd_ema_ckpt_reload")
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         train_params["resume"] = True
         train_params["resume"] = True
         kd_model.train(
         kd_model.train(
Discard
@@ -2,6 +2,7 @@ import os
 import unittest
 import unittest
 from copy import deepcopy
 from copy import deepcopy
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 import torch
@@ -57,8 +58,8 @@ class KDTrainerTest(unittest.TestCase):
         }
         }
 
 
     def test_teacher_sg_module_methods(self):
     def test_teacher_sg_module_methods(self):
-        student = models.get("resnet18", arch_params={"num_classes": 1000})
-        teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
         kd_module = KDModule(arch_params={}, student=student, teacher=teacher)
 
 
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
         initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
@@ -87,8 +88,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_train_model_with_input_adapter(self):
     def test_train_model_with_input_adapter(self):
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         adapter = NormalizationAdapter(
         adapter = NormalizationAdapter(
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
             mean_original=[0.485, 0.456, 0.406], std_original=[0.229, 0.224, 0.225], mean_required=[0.5, 0.5, 0.5], std_required=[0.5, 0.5, 0.5]
@@ -108,8 +109,8 @@ class KDTrainerTest(unittest.TestCase):
 
 
     def test_load_ckpt_best_for_student(self):
     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -121,14 +122,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
 
 
     def test_load_ckpt_best_for_student_with_ema(self):
     def test_load_ckpt_best_for_student_with_ema(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
         kd_trainer = KDTrainer("test_load_ckpt_best")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
@@ -141,14 +142,14 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_reloaded = models.get("resnet18", arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
+        student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
 
 
     def test_resume_kd_training(self):
     def test_resume_kd_training(self):
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         kd_trainer.train(
         kd_trainer.train(
@@ -161,8 +162,8 @@ class KDTrainerTest(unittest.TestCase):
         latest_net = deepcopy(kd_trainer.net)
         latest_net = deepcopy(kd_trainer.net)
 
 
         kd_trainer = KDTrainer("test_resume_training_start")
         kd_trainer = KDTrainer("test_resume_training_start")
-        student = models.get("resnet18", arch_params={"num_classes": 5})
-        teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
+        student = models.get(Models.RESNET18, arch_params={"num_classes": 5})
+        teacher = models.get(Models.RESNET50, arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
         train_params["max_epochs"] = 2
         train_params["max_epochs"] = 2
         train_params["resume"] = True
         train_params["resume"] = True
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -27,12 +28,12 @@ class LocalCkptHeadReplacementTest(unittest.TestCase):
         }
         }
 
 
         # Define Model
         # Define Model
-        net = models.get("resnet18", num_classes=5)
+        net = models.get(Models.RESNET18, num_classes=5)
         trainer = Trainer("test_resume_training")
         trainer = Trainer("test_resume_training")
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
         ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
 
 
-        net2 = models.get("resnet18", num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
+        net2 = models.get(Models.RESNET18, num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
         self.assertFalse(check_models_have_same_weights(net, net2))
         self.assertFalse(check_models_have_same_weights(net, net2))
 
 
         net.linear = None
         net.linear = None
Discard
@@ -4,6 +4,7 @@ from torch import Tensor
 from torchmetrics import Accuracy
 from torchmetrics import Accuracy
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 
 
@@ -29,7 +30,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         trainer = Trainer("test_single_item_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -53,7 +54,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_unnamed_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -77,7 +78,7 @@ class LossLoggingsTest(unittest.TestCase):
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         trainer = Trainer("test_multiple_named_components_loss_logging", model_checkpoints_location="local")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 import super_gradients
 import super_gradients
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -11,21 +12,21 @@ import shutil
 class PretrainedModelsUnitTest(unittest.TestCase):
 class PretrainedModelsUnitTest(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
+        self.imagenet_pretrained_models = [Models.RESNET50, "repvgg_a0", "regnetY800"]
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
         trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
-        model = models.get("resnet50", pretrained_weights="imagenet")
+        model = models.get(Models.RESNET50, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
         trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
-        model = models.get("regnetY800", pretrained_weights="imagenet")
+        model = models.get(Models.REGNETY800, pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
         trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
-        model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
+        model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
Discard
@@ -3,6 +3,8 @@ import torch
 import torchvision
 import torchvision
 from torch import nn
 from torch import nn
 
 
+from super_gradients.common.object_names import Models
+
 try:
 try:
     import super_gradients
     import super_gradients
     from pytorch_quantization import nn as quant_nn
     from pytorch_quantization import nn as quant_nn
@@ -787,7 +789,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         if Bottleneck in sq.mapping_instructions:
         if Bottleneck in sq.mapping_instructions:
             sq.mapping_instructions.pop(Bottleneck)
             sq.mapping_instructions.pop(Bottleneck)
 
 
-        resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_sg: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
 
         # PYTORCH-QUANTIZATION
         # PYTORCH-QUANTIZATION
@@ -803,7 +805,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
 
         quant_modules.initialize()
         quant_modules.initialize()
-        resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_pyquant: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
 
 
         quant_modules.deactivate()
         quant_modules.deactivate()
 
 
Discard
@@ -3,6 +3,7 @@ import os
 from super_gradients.training import Trainer, models
 from super_gradients.training import Trainer, models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.common.object_names import Models
 
 
 
 
 class SaveCkptListUnitTest(unittest.TestCase):
 class SaveCkptListUnitTest(unittest.TestCase):
@@ -31,7 +32,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
         trainer = Trainer("save_ckpt_test")
         trainer = Trainer("save_ckpt_test")
 
 
         # Build Model
         # Build Model
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        model = models.get(Models.RESNET18_CIFAR, arch_params={"num_classes": 10})
 
 
         # Train Model (and save ckpt_epoch_list)
         # Train Model (and save ckpt_epoch_list)
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
Discard
@@ -3,6 +3,7 @@ import shutil
 import tempfile
 import tempfile
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 import torch
 import torch
@@ -98,8 +99,8 @@ class StrictLoadEnumTest(unittest.TestCase):
 
 
     def test_strict_load_on(self):
     def test_strict_load_on(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -107,15 +108,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_off(self):
     def test_strict_load_off(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -125,17 +126,17 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
         torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
         del model.linear
         del model.linear
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_no_key_matching_sg_checkpoint(self):
     def test_strict_load_no_key_matching_sg_checkpoint(self):
         # Define Model
         # Define Model
-        model = models.get("resnet18", arch_params={"num_classes": 1000})
-        pretrained_model = models.get("resnet18", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
+        pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(model, pretrained_model)
         assert not self.check_models_have_same_weights(model, pretrained_model)
@@ -144,9 +145,9 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
         torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
 
 
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
-            models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
+            models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
         assert self.check_models_have_same_weights(model, pretrained_model)
         assert self.check_models_have_same_weights(model, pretrained_model)
 
 
Discard
@@ -8,6 +8,7 @@ from super_gradients.training import models
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
+from super_gradients.common.object_names import Models
 
 
 
 
 class TestWithoutTrainTest(unittest.TestCase):
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,20 +27,20 @@ class TestWithoutTrainTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=""):
     def get_detection_trainer(name=""):
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("yolox_s", num_classes=5)
+        model = models.get(Models.YOLOX_S, num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=""):
     def get_segmentation_trainer(name=""):
         shelfnet_lw_arch_params = {"num_classes": 5}
         shelfnet_lw_arch_params = {"num_classes": 5}
         trainer = Trainer(name)
         trainer = Trainer(name)
-        model = models.get("shelfnet34_lw", arch_params=shelfnet_lw_arch_params)
+        model = models.get(Models.SHELFNET34_LW, arch_params=shelfnet_lw_arch_params)
         return trainer, model
         return trainer, model
 
 
     def test_test_without_train(self):
     def test_test_without_train(self):
Discard
@@ -1,6 +1,7 @@
 import unittest
 import unittest
 import torch
 import torch
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -18,7 +19,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test")
         trainer = Trainer("test_call_train_after_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -42,7 +43,7 @@ class CallTrainAfterTestTest(unittest.TestCase):
         trainer = Trainer("test_call_train_after_test_with_loss")
         trainer = Trainer("test_call_train_after_test_with_loss")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
 from torchmetrics import F1Score
 from torchmetrics import F1Score
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
+from super_gradients.common.object_names import Models
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
 from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
@@ -22,7 +23,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_criterion_test")
         trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -45,7 +46,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_optimizer_test")
         trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=0.1)
         optimizer = SGD(params=model.parameters(), lr=0.1)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
@@ -70,7 +71,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
@@ -95,7 +96,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_scheduler_test")
         trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD  # a class - not an instance
         optimizer = SGD  # a class - not an instance
 
 
         train_params = {
         train_params = {
@@ -117,7 +118,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         optimizer = SGD(params=model.parameters(), lr=lr)
         optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
@@ -142,7 +143,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer = Trainer("external_metric_test")
         trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
-        model = models.get("resnet18", arch_params={"num_classes": 5})
+        model = models.get(Models.RESNET18, arch_params={"num_classes": 5})
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
@@ -172,7 +173,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)
 
 
-        model = models.get("resnet18", num_classes=5)
+        model = models.get(Models.RESNET18, num_classes=5)
         train_params = {
         train_params = {
             "max_epochs": 2,
             "max_epochs": 2,
             "lr_updates": [1],
             "lr_updates": [1],
Discard
@@ -1,5 +1,6 @@
 import unittest
 import unittest
 
 
+from super_gradients.common.object_names import Models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -31,7 +32,7 @@ class TestViT(unittest.TestCase):
         Validate vit_base
         Validate vit_base
         """
         """
         trainer = Trainer("test_vit_base")
         trainer = Trainer("test_vit_base")
-        model = models.get("vit_base", arch_params=self.arch_params, num_classes=5)
+        model = models.get(Models.VIT_BASE, arch_params=self.arch_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
Discard