Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#744 Feature/sg 691 register dynamically metrics and datasets

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-691-register_dynamically_metrics_and_datasets
100 changed files with 647 additions and 437 deletions
  1. 2
    1
      src/super_gradients/__init__.py
  2. 1
    1
      src/super_gradients/common/factories/callbacks_factory.py
  3. 1
    1
      src/super_gradients/common/factories/collate_functions_factory.py
  4. 1
    1
      src/super_gradients/common/factories/datasets_factory.py
  5. 1
    1
      src/super_gradients/common/factories/detection_modules_factory.py
  6. 1
    1
      src/super_gradients/common/factories/metrics_factory.py
  7. 1
    1
      src/super_gradients/common/factories/optimizers_type_factory.py
  8. 2
    2
      src/super_gradients/common/factories/pre_launch_callbacks_factory.py
  9. 1
    1
      src/super_gradients/common/factories/samplers_factory.py
  10. 1
    1
      src/super_gradients/common/factories/target_generator_factory.py
  11. 3
    1
      src/super_gradients/common/factories/transforms_factory.py
  12. 78
    0
      src/super_gradients/common/object_names.py
  13. 31
    0
      src/super_gradients/common/registry/albumentation.py
  14. 91
    20
      src/super_gradients/common/registry/registry.py
  15. 2
    7
      src/super_gradients/common/sg_loggers/__init__.py
  16. 3
    0
      src/super_gradients/common/sg_loggers/base_sg_logger.py
  17. 4
    1
      src/super_gradients/common/sg_loggers/clearml_sg_logger.py
  18. 2
    0
      src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py
  19. 2
    0
      src/super_gradients/common/sg_loggers/wandb_sg_logger.py
  20. 16
    8
      src/super_gradients/modules/__init__.py
  21. 0
    29
      src/super_gradients/modules/all_detection_modules.py
  22. 9
    0
      src/super_gradients/modules/detection_modules.py
  23. 2
    0
      src/super_gradients/modules/pose_estimation_modules.py
  24. 0
    2
      src/super_gradients/training/__init__.py
  25. 62
    62
      src/super_gradients/training/dataloaders/dataloaders.py
  26. 0
    13
      src/super_gradients/training/datasets/all_collate_functions.py
  27. 0
    31
      src/super_gradients/training/datasets/all_datasets.py
  28. 0
    3
      src/super_gradients/training/datasets/all_target_generators.py
  29. 3
    0
      src/super_gradients/training/datasets/auto_augment.py
  30. 5
    1
      src/super_gradients/training/datasets/classification_datasets/cifar.py
  31. 3
    0
      src/super_gradients/training/datasets/classification_datasets/imagenet_dataset.py
  32. 5
    0
      src/super_gradients/training/datasets/data_augmentation.py
  33. 6
    0
      src/super_gradients/training/datasets/datasets_utils.py
  34. 3
    0
      src/super_gradients/training/datasets/detection_datasets/coco_detection.py
  35. 3
    0
      src/super_gradients/training/datasets/detection_datasets/detection_dataset.py
  36. 3
    0
      src/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.py
  37. 2
    0
      src/super_gradients/training/datasets/mixup.py
  38. 2
    0
      src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py
  39. 3
    0
      src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py
  40. 3
    0
      src/super_gradients/training/datasets/pose_estimation_datasets/target_generators.py
  41. 2
    1
      src/super_gradients/training/datasets/samplers/__init__.py
  42. 0
    16
      src/super_gradients/training/datasets/samplers/all_samplers.py
  43. 4
    0
      src/super_gradients/training/datasets/samplers/infinite_sampler.py
  44. 3
    0
      src/super_gradients/training/datasets/samplers/repeated_augmentation_sampler.py
  45. 4
    0
      src/super_gradients/training/datasets/segmentation_datasets/cityscape_segmentation.py
  46. 3
    0
      src/super_gradients/training/datasets/segmentation_datasets/coco_segmentation.py
  47. 3
    0
      src/super_gradients/training/datasets/segmentation_datasets/mapillary_dataset.py
  48. 5
    0
      src/super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.py
  49. 3
    0
      src/super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.py
  50. 3
    0
      src/super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.py
  51. 1
    1
      src/super_gradients/training/kd_trainer/kd_trainer.py
  52. 5
    3
      src/super_gradients/training/losses/__init__.py
  53. 0
    0
      src/super_gradients/training/losses/all_losses.py
  54. 3
    2
      src/super_gradients/training/losses/dice_ce_edge_loss.py
  55. 1
    3
      src/super_gradients/training/losses/r_squared_loss.py
  56. 0
    1
      src/super_gradients/training/losses/shelfnet_semantic_encoding_loss.py
  57. 2
    1
      src/super_gradients/training/metrics/__init__.py
  58. 0
    31
      src/super_gradients/training/metrics/all_metrics.py
  59. 6
    1
      src/super_gradients/training/metrics/classification_metrics.py
  60. 7
    0
      src/super_gradients/training/metrics/detection_metrics.py
  61. 3
    0
      src/super_gradients/training/metrics/pose_estimation_metrics.py
  62. 9
    0
      src/super_gradients/training/metrics/segmentation_metrics.py
  63. 12
    1
      src/super_gradients/training/models/__init__.py
  64. 0
    154
      src/super_gradients/training/models/all_architectures.py
  65. 5
    0
      src/super_gradients/training/models/classification_models/beit.py
  66. 8
    0
      src/super_gradients/training/models/classification_models/densenet.py
  67. 14
    0
      src/super_gradients/training/models/classification_models/efficientnet.py
  68. 4
    0
      src/super_gradients/training/models/classification_models/googlenet.py
  69. 6
    0
      src/super_gradients/training/models/classification_models/mobilenetv2.py
  70. 6
    0
      src/super_gradients/training/models/classification_models/mobilenetv3.py
  71. 9
    0
      src/super_gradients/training/models/classification_models/regnet.py
  72. 11
    0
      src/super_gradients/training/models/classification_models/repvgg.py
  73. 13
    0
      src/super_gradients/training/models/classification_models/resnet.py
  74. 5
    0
      src/super_gradients/training/models/classification_models/resnext.py
  75. 7
    0
      src/super_gradients/training/models/classification_models/shufflenetv2.py
  76. 7
    1
      src/super_gradients/training/models/classification_models/vit.py
  77. 3
    0
      src/super_gradients/training/models/detection_models/csp_darknet53.py
  78. 2
    0
      src/super_gradients/training/models/detection_models/csp_resnet.py
  79. 4
    0
      src/super_gradients/training/models/detection_models/darknet53.py
  80. 4
    1
      src/super_gradients/training/models/detection_models/pp_yolo_e/pan.py
  81. 6
    0
      src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py
  82. 4
    0
      src/super_gradients/training/models/detection_models/ssd.py
  83. 9
    0
      src/super_gradients/training/models/detection_models/yolox.py
  84. 5
    0
      src/super_gradients/training/models/kd_modules/kd_module.py
  85. 1
    1
      src/super_gradients/training/models/model_factory.py
  86. 4
    0
      src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py
  87. 3
    0
      src/super_gradients/training/models/pose_estimation_models/pose_ddrnet39.py
  88. 3
    0
      src/super_gradients/training/models/pose_estimation_models/pose_ppyolo.py
  89. 7
    2
      src/super_gradients/training/models/segmentation_models/ddrnet.py
  90. 2
    0
      src/super_gradients/training/models/segmentation_models/ddrnet_backbones.py
  91. 9
    0
      src/super_gradients/training/models/segmentation_models/ppliteseg.py
  92. 4
    0
      src/super_gradients/training/models/segmentation_models/regseg.py
  93. 7
    0
      src/super_gradients/training/models/segmentation_models/shelfnet.py
  94. 12
    0
      src/super_gradients/training/models/segmentation_models/stdc.py
  95. 4
    0
      src/super_gradients/training/models/segmentation_models/unet/unet.py
  96. 4
    7
      src/super_gradients/training/models/segmentation_models/unet/unet_decoder.py
  97. 8
    9
      src/super_gradients/training/models/segmentation_models/unet/unet_encoder.py
  98. 1
    5
      src/super_gradients/training/pre_launch_callbacks/__init__.py
  99. 3
    0
      src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py
  100. 6
    7
      src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1,5 +1,6 @@
 from super_gradients.common import init_trainer, is_distributed, object_names
 from super_gradients.common import init_trainer, is_distributed, object_names
-from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer
+from super_gradients.training import losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer
+from super_gradients.common.registry.registry import ARCHITECTURES
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.sanity_check import env_sanity_check
 from super_gradients.sanity_check import env_sanity_check
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.utils.callbacks import CALLBACKS
+from super_gradients.common.registry.registry import CALLBACKS
 
 
 
 
 class CallbacksFactory(BaseFactory):
 class CallbacksFactory(BaseFactory):
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.all_collate_functions import ALL_COLLATE_FUNCTIONS
+from super_gradients.common.registry.registry import ALL_COLLATE_FUNCTIONS
 
 
 
 
 class CollateFunctionsFactory(BaseFactory):
 class CollateFunctionsFactory(BaseFactory):
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.all_datasets import ALL_DATASETS
+from super_gradients.common.registry.registry import ALL_DATASETS
 
 
 
 
 class DatasetsFactory(BaseFactory):
 class DatasetsFactory(BaseFactory):
Discard
@@ -4,7 +4,7 @@ from omegaconf import DictConfig
 
 
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.modules.all_detection_modules import ALL_DETECTION_MODULES
+from super_gradients.common.registry.registry import ALL_DETECTION_MODULES
 
 
 
 
 class DetectionModulesFactory(BaseFactory):
 class DetectionModulesFactory(BaseFactory):
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.metrics import METRICS
+from super_gradients.common.registry.registry import METRICS
 
 
 
 
 class MetricsFactory(BaseFactory):
 class MetricsFactory(BaseFactory):
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.type_factory import TypeFactory
 from super_gradients.common.factories.type_factory import TypeFactory
-from super_gradients.training.utils.optimizers import OPTIMIZERS
+from super_gradients.common.registry.registry import OPTIMIZERS
 
 
 
 
 class OptimizersTypeFactory(TypeFactory):
 class OptimizersTypeFactory(TypeFactory):
Discard
@@ -1,7 +1,7 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training import pre_launch_callbacks
+from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS
 
 
 
 
 class PreLaunchCallbacksFactory(BaseFactory):
 class PreLaunchCallbacksFactory(BaseFactory):
     def __init__(self):
     def __init__(self):
-        super().__init__(pre_launch_callbacks.ALL_PRE_LAUNCH_CALLBACKS)
+        super().__init__(ALL_PRE_LAUNCH_CALLBACKS)
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.samplers import SAMPLERS
+from super_gradients.common.registry.registry import SAMPLERS
 
 
 
 
 class SamplersFactory(BaseFactory):
 class SamplersFactory(BaseFactory):
Discard
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.all_target_generators import ALL_TARGET_GENERATORS
+from super_gradients.common.registry.registry import ALL_TARGET_GENERATORS
 
 
 
 
 class TargetGeneratorsFactory(BaseFactory):
 class TargetGeneratorsFactory(BaseFactory):
Discard
@@ -4,7 +4,9 @@ from omegaconf import ListConfig
 
 
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.list_factory import ListFactory
-from super_gradients.training.transforms import TRANSFORMS, ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS, imported_albumentations_failure
+from super_gradients.common.registry.registry import TRANSFORMS
+from super_gradients.common.registry.albumentation import ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS
+from super_gradients.common.registry.albumentation import imported_albumentations_failure
 from super_gradients.training.transforms.pipeline_adaptors import AlbumentationsAdaptor
 from super_gradients.training.transforms.pipeline_adaptors import AlbumentationsAdaptor
 
 
 
 
Discard
@@ -126,6 +126,7 @@ class Optimizers:
     RMS_PROP = "RMSprop"
     RMS_PROP = "RMSprop"
     RMS_PROP_TF = "RMSpropTF"
     RMS_PROP_TF = "RMSpropTF"
     LAMB = "Lamb"
     LAMB = "Lamb"
+    LION = "Lion"
 
 
 
 
 class Callbacks:
 class Callbacks:
@@ -311,3 +312,80 @@ class ConcatenatedTensorFormats:
     LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
     LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
     LABEL_NORMALIZED_XYWH = "LABEL_NORMALIZED_XYWH"
     LABEL_NORMALIZED_XYWH = "LABEL_NORMALIZED_XYWH"
     LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
     LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
+
+
+class Dataloaders:
+    COCO2017_TRAIN = "coco2017_train"
+    COCO2017_VAL = "coco2017_val"
+    COCO2017_TRAIN_YOLOX = "coco2017_train_yolox"
+    COCO2017_VAL_YOLOX = "coco2017_val_yolox"
+    COCO2017_TRAIN_DECIYOLO = "coco2017_train_deci_yolo"
+    COCO2017_VAL_DECIYOLO = "coco2017_val_deci_yolo"
+    COCO2017_TRAIN_PPYOLOE = "coco2017_train_ppyoloe"
+    COCO2017_VAL_PPYOLOE = "coco2017_val_ppyoloe"
+    COCO2017_TRAIN_SSD_LITE_MOBILENET_V2 = "coco2017_train_ssd_lite_mobilenet_v2"
+    COCO2017_VAL_SSD_LITE_MOBILENET_V2 = "coco2017_val_ssd_lite_mobilenet_v2"
+    COCO2017_POSE_TRAIN = "coco2017_pose_train"
+    COCO2017_POSE_VAL = "coco2017_pose_val"
+    IMAGENET_TRAIN = "imagenet_train"
+    IMAGENET_VAL = "imagenet_val"
+    IMAGENET_EFFICIENTNET_TRAIN = "imagenet_efficientnet_train"
+    IMAGENET_EFFICIENTNET_VAL = "imagenet_efficientnet_val"
+    IMAGENET_MOBILENETV2_TRAIN = "imagenet_mobilenetv2_train"
+    IMAGENET_MOBILENETV2_VAL = "imagenet_mobilenetv2_val"
+    IMAGENET_MOBILENETV3_TRAIN = "imagenet_mobilenetv3_train"
+    IMAGENET_MOBILENETV3_VAL = "imagenet_mobilenetv3_val"
+    IMAGENET_REGNETY_TRAIN = "imagenet_regnetY_train"
+    IMAGENET_REGNETY_VAL = "imagenet_regnetY_val"
+    IMAGENET_RESNET50_TRAIN = "imagenet_resnet50_train"
+    IMAGENET_RESNET50_VAL = "imagenet_resnet50_val"
+    IMAGENET_RESNET50_KD_TRAIN = "imagenet_resnet50_kd_train"
+    IMAGENET_RESNET50_KD_VAL = "imagenet_resnet50_kd_val"
+    IMAGENET_VIT_BASE_TRAIN = "imagenet_vit_base_train"
+    IMAGENET_VIT_BASE_VAL = "imagenet_vit_base_val"
+    TINY_IMAGENET_TRAIN = "tiny_imagenet_train"
+    TINY_IMAGENET_VAL = "tiny_imagenet_val"
+    CIFAR10_TRAIN = "cifar10_train"
+    CIFAR10_VAL = "cifar10_val"
+    CIFAR100_TRAIN = "cifar100_train"
+    CIFAR100_VAL = "cifar100_val"
+    CITYSCAPES_TRAIN = "cityscapes_train"
+    CITYSCAPES_VAL = "cityscapes_val"
+    CITYSCAPES_STDC_SEG50_TRAIN = "cityscapes_stdc_seg50_train"
+    CITYSCAPES_STDC_SEG50_VAL = "cityscapes_stdc_seg50_val"
+    CITYSCAPES_STDC_SEG75_TRAIN = "cityscapes_stdc_seg75_train"
+    CITYSCAPES_STDC_SEG75_VAL = "cityscapes_stdc_seg75_val"
+    CITYSCAPES_REGSEG48_TRAIN = "cityscapes_regseg48_train"
+    CITYSCAPES_REGSEG48_VAL = "cityscapes_regseg48_val"
+    CITYSCAPES_DDRNET_TRAIN = "cityscapes_ddrnet_train"
+    CITYSCAPES_DDRNET_VAL = "cityscapes_ddrnet_val"
+    COCO_SEGMENTATION_TRAIN = "coco_segmentation_train"
+    COCO_SEGMENTATION_VAL = "coco_segmentation_val"
+    MAPILLARY_TRAIN = "mapillary_train"
+    MAPILLARY_VAL = "mapillary_val"
+    PASCAL_AUG_SEGMENTATION_TRAIN = "pascal_aug_segmentation_train"
+    PASCAL_AUG_SEGMENTATION_VAL = "pascal_aug_segmentation_val"
+    PASCAL_VOC_SEGMENTATION_TRAIN = "pascal_voc_segmentation_train"
+    PASCAL_VOC_SEGMENTATION_VAL = "pascal_voc_segmentation_val"
+    SUPERVISELY_PERSONS_TRAIN = "supervisely_persons_train"
+    SUPERVISELY_PERSONS_VAL = "supervisely_persons_val"
+    PASCAL_VOC_DETECTION_TRAIN = "pascal_voc_detection_train"
+    PASCAL_VOC_DETECTION_VAL = "pascal_voc_detection_val"
+
+
+class Datasets:
+    CIFAR_10 = "Cifar10"
+    CIFAR_100 = "Cifar100"
+    IMAGENET_DATASET = "ImageNetDataset"
+    COCO_DETECTION_DATASET = "COCODetectionDataset"
+    DETECTION_DATASET = "DetectionDataset"
+    PASCAL_VOC_DETECTION_DATASET = "PascalVOCDetectionDataset"
+    SEGMENTATION_DATASET = "SegmentationDataSet"
+    COCO_SEGMENTATION_DATASET = "CoCoSegmentationDataSet"
+    PASCAL_AUG_2012_SEGMENTATION_DATASET = "PascalAUG2012SegmentationDataSet"
+    PASCAL_VOC_2012_SEGMENTATION_DATASET = "PascalVOC2012SegmentationDataSet"
+    CITYSCAPES_DATASET = "CityscapesDataset"
+    MAPILLARY_DATASET = "MapillaryDataset"
+    SUPERVISELY_PERSONS_DATASET = "SuperviselyPersonsDataset"
+    PASCAL_VOC_AND_AUG_UNIFIED_DATASET = "PascalVOCAndAUGUnifiedDataset"
+    COCO_KEY_POINTS_DATASET = "COCOKeypointsDataset"
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  1. import importlib
  2. import inspect
  3. from super_gradients.common.abstractions.abstract_logger import get_logger
  4. logger = get_logger(__name__)
  5. try:
  6. from albumentations import BasicTransform, BaseCompose
  7. except (ImportError, NameError, ModuleNotFoundError) as import_err:
  8. logger.debug("Failed to import albumentations")
  9. imported_albumentations_failure = import_err
  10. if imported_albumentations_failure is None:
  11. ALBUMENTATIONS_TRANSFORMS = {
  12. name: cls for name, cls in inspect.getmembers(importlib.import_module("albumentations"), inspect.isclass) if issubclass(cls, BasicTransform)
  13. }
  14. ALBUMENTATIONS_TRANSFORMS.update(
  15. {name: cls for name, cls in inspect.getmembers(importlib.import_module("albumentations.pytorch"), inspect.isclass) if issubclass(cls, BasicTransform)}
  16. )
  17. ALBUMENTATIONS_COMP_TRANSFORMS = {
  18. name: cls
  19. for name, cls in inspect.getmembers(importlib.import_module("albumentations.core.composition"), inspect.isclass)
  20. if issubclass(cls, BaseCompose)
  21. }
  22. ALBUMENTATIONS_TRANSFORMS.update(ALBUMENTATIONS_COMP_TRANSFORMS)
  23. else:
  24. ALBUMENTATIONS_TRANSFORMS = None
  25. ALBUMENTATIONS_COMP_TRANSFORMS = None
Discard
@@ -1,25 +1,11 @@
 import inspect
 import inspect
 from typing import Callable, Dict, Optional
 from typing import Callable, Dict, Optional
 
 
-from torch import nn
-
-from super_gradients.common import object_names
-from super_gradients.training.utils.callbacks import LR_SCHEDULERS_CLS_DICT
-from super_gradients.common.sg_loggers import SG_LOGGERS
-from super_gradients.training.dataloaders.dataloaders import ALL_DATALOADERS
-from super_gradients.training.models.all_architectures import ARCHITECTURES
-from super_gradients.training.metrics.all_metrics import METRICS
-from super_gradients.modules.all_detection_modules import ALL_DETECTION_MODULES
-from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
-from super_gradients.training.transforms.all_transforms import TRANSFORMS
-from super_gradients.training.datasets.all_datasets import ALL_DATASETS
-from super_gradients.training.pre_launch_callbacks import ALL_PRE_LAUNCH_CALLBACKS
-from super_gradients.training.models.segmentation_models.unet.unet_encoder import BACKBONE_STAGES
-from super_gradients.training.models.segmentation_models.unet.unet_decoder import UP_FUSE_BLOCKS
-from super_gradients.training.datasets.all_target_generators import ALL_TARGET_GENERATORS
-from super_gradients.training.datasets.all_collate_functions import ALL_COLLATE_FUNCTIONS
-from super_gradients.training.datasets.samplers.all_samplers import SAMPLERS
-from super_gradients.training.utils.optimizers import OPTIMIZERS
+import torch
+from torch import nn, optim
+import torchvision
+
+from super_gradients.common.object_names import Losses, Transforms, Samplers, Optimizers
 
 
 
 
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
@@ -54,23 +40,108 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
     return register
     return register
 
 
 
 
+ARCHITECTURES = {}
 register_model = create_register_decorator(registry=ARCHITECTURES)
 register_model = create_register_decorator(registry=ARCHITECTURES)
+
+KD_ARCHITECTURES = {}
+register_kd_model = create_register_decorator(registry=KD_ARCHITECTURES)
+
+ALL_DETECTION_MODULES = {}
 register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
 register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
+
+METRICS = {}
 register_metric = create_register_decorator(registry=METRICS)
 register_metric = create_register_decorator(registry=METRICS)
 
 
-LOSSES = {object_names.Losses.MSE: nn.MSELoss}
+LOSSES = {Losses.MSE: nn.MSELoss}
 register_loss = create_register_decorator(registry=LOSSES)
 register_loss = create_register_decorator(registry=LOSSES)
 
 
+
+ALL_DATALOADERS = {}
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
+
+CALLBACKS = {}
 register_callback = create_register_decorator(registry=CALLBACKS)
 register_callback = create_register_decorator(registry=CALLBACKS)
+
+TRANSFORMS = {
+    Transforms.Compose: torchvision.transforms.Compose,
+    Transforms.ToTensor: torchvision.transforms.ToTensor,
+    Transforms.PILToTensor: torchvision.transforms.PILToTensor,
+    Transforms.ConvertImageDtype: torchvision.transforms.ConvertImageDtype,
+    Transforms.ToPILImage: torchvision.transforms.ToPILImage,
+    Transforms.Normalize: torchvision.transforms.Normalize,
+    Transforms.Resize: torchvision.transforms.Resize,
+    Transforms.CenterCrop: torchvision.transforms.CenterCrop,
+    Transforms.Pad: torchvision.transforms.Pad,
+    Transforms.Lambda: torchvision.transforms.Lambda,
+    Transforms.RandomApply: torchvision.transforms.RandomApply,
+    Transforms.RandomChoice: torchvision.transforms.RandomChoice,
+    Transforms.RandomOrder: torchvision.transforms.RandomOrder,
+    Transforms.RandomCrop: torchvision.transforms.RandomCrop,
+    Transforms.RandomHorizontalFlip: torchvision.transforms.RandomHorizontalFlip,
+    Transforms.RandomVerticalFlip: torchvision.transforms.RandomVerticalFlip,
+    Transforms.RandomResizedCrop: torchvision.transforms.RandomResizedCrop,
+    Transforms.FiveCrop: torchvision.transforms.FiveCrop,
+    Transforms.TenCrop: torchvision.transforms.TenCrop,
+    Transforms.LinearTransformation: torchvision.transforms.LinearTransformation,
+    Transforms.ColorJitter: torchvision.transforms.ColorJitter,
+    Transforms.RandomRotation: torchvision.transforms.RandomRotation,
+    Transforms.RandomAffine: torchvision.transforms.RandomAffine,
+    Transforms.Grayscale: torchvision.transforms.Grayscale,
+    Transforms.RandomGrayscale: torchvision.transforms.RandomGrayscale,
+    Transforms.RandomPerspective: torchvision.transforms.RandomPerspective,
+    Transforms.RandomErasing: torchvision.transforms.RandomErasing,
+    Transforms.GaussianBlur: torchvision.transforms.GaussianBlur,
+    Transforms.InterpolationMode: torchvision.transforms.InterpolationMode,
+    Transforms.RandomInvert: torchvision.transforms.RandomInvert,
+    Transforms.RandomPosterize: torchvision.transforms.RandomPosterize,
+    Transforms.RandomSolarize: torchvision.transforms.RandomSolarize,
+    Transforms.RandomAdjustSharpness: torchvision.transforms.RandomAdjustSharpness,
+    Transforms.RandomAutocontrast: torchvision.transforms.RandomAutocontrast,
+    Transforms.RandomEqualize: torchvision.transforms.RandomEqualize,
+}
 register_transform = create_register_decorator(registry=TRANSFORMS)
 register_transform = create_register_decorator(registry=TRANSFORMS)
+
+ALL_DATASETS = {}
 register_dataset = create_register_decorator(registry=ALL_DATASETS)
 register_dataset = create_register_decorator(registry=ALL_DATASETS)
+
+ALL_PRE_LAUNCH_CALLBACKS = {}
 register_pre_launch_callback = create_register_decorator(registry=ALL_PRE_LAUNCH_CALLBACKS)
 register_pre_launch_callback = create_register_decorator(registry=ALL_PRE_LAUNCH_CALLBACKS)
+
+BACKBONE_STAGES = {}
 register_unet_backbone_stage = create_register_decorator(registry=BACKBONE_STAGES)
 register_unet_backbone_stage = create_register_decorator(registry=BACKBONE_STAGES)
+
+UP_FUSE_BLOCKS = {}
 register_unet_up_block = create_register_decorator(registry=UP_FUSE_BLOCKS)
 register_unet_up_block = create_register_decorator(registry=UP_FUSE_BLOCKS)
+
+ALL_TARGET_GENERATORS = {}
 register_target_generator = create_register_decorator(registry=ALL_TARGET_GENERATORS)
 register_target_generator = create_register_decorator(registry=ALL_TARGET_GENERATORS)
+
+LR_SCHEDULERS_CLS_DICT = {}
 register_lr_scheduler = create_register_decorator(registry=LR_SCHEDULERS_CLS_DICT)
 register_lr_scheduler = create_register_decorator(registry=LR_SCHEDULERS_CLS_DICT)
+
+LR_WARMUP_CLS_DICT = {}
+register_lr_warmup = create_register_decorator(registry=LR_WARMUP_CLS_DICT)
+
+SG_LOGGERS = {}
 register_sg_logger = create_register_decorator(registry=SG_LOGGERS)
 register_sg_logger = create_register_decorator(registry=SG_LOGGERS)
+
+ALL_COLLATE_FUNCTIONS = {}
 register_collate_function = create_register_decorator(registry=ALL_COLLATE_FUNCTIONS)
 register_collate_function = create_register_decorator(registry=ALL_COLLATE_FUNCTIONS)
+
+SAMPLERS = {
+    Samplers.DISTRIBUTED: torch.utils.data.DistributedSampler,
+    Samplers.SEQUENTIAL: torch.utils.data.SequentialSampler,
+    Samplers.SUBSET_RANDOM: torch.utils.data.SubsetRandomSampler,
+    Samplers.RANDOM: torch.utils.data.RandomSampler,
+    Samplers.WEIGHTED_RANDOM: torch.utils.data.WeightedRandomSampler,
+}
 register_sampler = create_register_decorator(registry=SAMPLERS)
 register_sampler = create_register_decorator(registry=SAMPLERS)
+
+
+OPTIMIZERS = {
+    Optimizers.SGD: optim.SGD,
+    Optimizers.ADAM: optim.Adam,
+    Optimizers.ADAMW: optim.AdamW,
+    Optimizers.RMS_PROP: optim.RMSprop,
+}
 register_optimizer = create_register_decorator(registry=OPTIMIZERS)
 register_optimizer = create_register_decorator(registry=OPTIMIZERS)
Discard
@@ -1,11 +1,6 @@
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
+from super_gradients.common.sg_loggers.clearml_sg_logger import ClearMLSGLogger
 from super_gradients.common.sg_loggers.deci_platform_sg_logger import DeciPlatformSGLogger
 from super_gradients.common.sg_loggers.deci_platform_sg_logger import DeciPlatformSGLogger
 from super_gradients.common.sg_loggers.wandb_sg_logger import WandBSGLogger
 from super_gradients.common.sg_loggers.wandb_sg_logger import WandBSGLogger
-from super_gradients.common.sg_loggers.clearml_sg_logger import ClearMLSGLogger
 
 
-SG_LOGGERS = {
-    "base_sg_logger": BaseSGLogger,
-    "deci_platform_sg_logger": DeciPlatformSGLogger,
-    "wandb_sg_logger": WandBSGLogger,
-    "clearml_sg_logger": ClearMLSGLogger,
-}
+__all__ = ["BaseSGLogger", "ClearMLSGLogger", "DeciPlatformSGLogger", "WandBSGLogger"]
Discard
@@ -9,6 +9,8 @@ import numpy as np
 import psutil
 import psutil
 import torch
 import torch
 from PIL import Image
 from PIL import Image
+
+from super_gradients.common.registry.registry import register_sg_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.decorators.code_save_decorator import saved_codes
 from super_gradients.common.decorators.code_save_decorator import saved_codes
@@ -27,6 +29,7 @@ LOGGER_LOGS_PREFIX = "logs"
 CONSOLE_LOGS_PREFIX = "console"
 CONSOLE_LOGS_PREFIX = "console"
 
 
 
 
+@register_sg_logger("base_sg_logger")
 class BaseSGLogger(AbstractSGLogger):
 class BaseSGLogger(AbstractSGLogger):
     def __init__(
     def __init__(
         self,
         self,
Discard
@@ -6,8 +6,10 @@ import numpy as np
 from PIL import Image
 from PIL import Image
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
 import torch
 import torch
-from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
+
+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.registry.registry import register_sg_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 
 
@@ -22,6 +24,7 @@ except (ImportError, NameError, ModuleNotFoundError) as import_err:
     _imported_clear_ml_failure = import_err
     _imported_clear_ml_failure = import_err
 
 
 
 
+@register_sg_logger("clearml_sg_logger")
 class ClearMLSGLogger(BaseSGLogger):
 class ClearMLSGLogger(BaseSGLogger):
     def __init__(
     def __init__(
         self,
         self,
Discard
@@ -3,6 +3,7 @@ import io
 from contextlib import contextmanager
 from contextlib import contextmanager
 from typing import Optional
 from typing import Optional
 
 
+from super_gradients.common.registry.registry import register_sg_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger, EXPERIMENT_LOGS_PREFIX, LOGGER_LOGS_PREFIX, CONSOLE_LOGS_PREFIX
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger, EXPERIMENT_LOGS_PREFIX, LOGGER_LOGS_PREFIX, CONSOLE_LOGS_PREFIX
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.ddp_utils import multi_process_safe
@@ -15,6 +16,7 @@ logger = get_logger(__name__)
 TENSORBOARD_EVENTS_PREFIX = "events.out.tfevents"
 TENSORBOARD_EVENTS_PREFIX = "events.out.tfevents"
 
 
 
 
+@register_sg_logger("deci_platform_sg_logger")
 class DeciPlatformSGLogger(BaseSGLogger):
 class DeciPlatformSGLogger(BaseSGLogger):
     """Logger responsible to push logs and tensorboard artifacts to Deci platform."""
     """Logger responsible to push logs and tensorboard artifacts to Deci platform."""
 
 
Discard
@@ -7,6 +7,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
 import torch
 import torch
 
 
+from super_gradients.common.registry.registry import register_sg_logger
 from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
@@ -25,6 +26,7 @@ WANDB_ID_PREFIX = "wandb_id."
 WANDB_INCLUDE_FILE_NAME = ".wandbinclude"
 WANDB_INCLUDE_FILE_NAME = ".wandbinclude"
 
 
 
 
+@register_sg_logger("wandb_sg_logger")
 class WandBSGLogger(BaseSGLogger):
 class WandBSGLogger(BaseSGLogger):
     def __init__(
     def __init__(
         self,
         self,
Discard
@@ -1,15 +1,23 @@
-from .conv_bn_act_block import ConvBNAct
-from .conv_bn_relu_block import ConvBNReLU
-from .repvgg_block import RepVGGBlock
-from .qarepvgg_block import QARepVGGBlock
-from .se_blocks import SEBlock, EffectiveSEBlock
-from .skip_connections import Residual, SkipConnection, CrossModelSkipConnection, BackboneInternalSkipConnection, HeadInternalSkipConnection
-from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.modules.anti_alias import AntiAliasDownsample
 from super_gradients.modules.pose_estimation_modules import LightweightDEKRHead
 from super_gradients.modules.pose_estimation_modules import LightweightDEKRHead
-from .all_detection_modules import ALL_DETECTION_MODULES
+from super_gradients.modules.conv_bn_act_block import ConvBNAct
+from super_gradients.modules.conv_bn_relu_block import ConvBNReLU
+from super_gradients.modules.repvgg_block import RepVGGBlock
+from super_gradients.modules.qarepvgg_block import QARepVGGBlock
+from super_gradients.modules.se_blocks import SEBlock, EffectiveSEBlock
+from super_gradients.modules.skip_connections import (
+    Residual,
+    SkipConnection,
+    CrossModelSkipConnection,
+    BackboneInternalSkipConnection,
+    HeadInternalSkipConnection,
+)
+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.registry.registry import ALL_DETECTION_MODULES
 
 
 __all__ = [
 __all__ = [
     "ALL_DETECTION_MODULES",
     "ALL_DETECTION_MODULES",
+    "AntiAliasDownsample",
     "ConvBNAct",
     "ConvBNAct",
     "ConvBNReLU",
     "ConvBNReLU",
     "RepVGGBlock",
     "RepVGGBlock",
Discard
@@ -1,29 +0,0 @@
-from super_gradients.modules.pose_estimation_modules import LightweightDEKRHead
-from super_gradients.modules.detection_modules import (
-    MobileNetV1Backbone,
-    MobileNetV2Backbone,
-    SSDInvertedResidualNeck,
-    SSDBottleneckNeck,
-    NStageBackbone,
-    PANNeck,
-    NHeads,
-    SSDHead,
-)
-from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
-from super_gradients.training.models.detection_models.pp_yolo_e.pan import CustomCSPPAN
-from super_gradients.training.models.segmentation_models.ddrnet_backbones import DDRNet39Backbone
-
-ALL_DETECTION_MODULES = {
-    "MobileNetV1Backbone": MobileNetV1Backbone,
-    "MobileNetV2Backbone": MobileNetV2Backbone,
-    "SSDInvertedResidualNeck": SSDInvertedResidualNeck,
-    "SSDBottleneckNeck": SSDBottleneckNeck,
-    "SSDHead": SSDHead,
-    "NStageBackbone": NStageBackbone,
-    "PANNeck": PANNeck,
-    "NHeads": NHeads,
-    "LightweightDEKRHead": LightweightDEKRHead,
-    "CustomCSPPAN": CustomCSPPAN,
-    "CSPResNetBackbone": CSPResNetBackbone,
-    "DDRNet39Backbone": DDRNet39Backbone,
-}
Discard
@@ -6,6 +6,7 @@ from torch import nn
 from omegaconf.listconfig import ListConfig
 from omegaconf.listconfig import ListConfig
 from omegaconf import DictConfig
 from omegaconf import DictConfig
 
 
+from super_gradients.common.registry.registry import register_detection_module
 from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
 from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.models import MobileNet, MobileNetV2
 from super_gradients.training.models import MobileNet, MobileNetV2
@@ -33,6 +34,7 @@ class BaseDetectionModule(nn.Module, ABC):
         raise NotImplementedError()
         raise NotImplementedError()
 
 
 
 
+@register_detection_module()
 class NStageBackbone(BaseDetectionModule):
 class NStageBackbone(BaseDetectionModule):
     """
     """
     A backbone with a stem -> N stages -> context module
     A backbone with a stem -> N stages -> context module
@@ -89,6 +91,7 @@ class NStageBackbone(BaseDetectionModule):
         return outputs
         return outputs
 
 
 
 
+@register_detection_module()
 class PANNeck(BaseDetectionModule):
 class PANNeck(BaseDetectionModule):
     """
     """
     A PAN (path aggregation network) neck with 4 stages (2 up-sampling and 2 down-sampling stages)
     A PAN (path aggregation network) neck with 4 stages (2 up-sampling and 2 down-sampling stages)
@@ -135,6 +138,7 @@ class PANNeck(BaseDetectionModule):
         return p3, p4, p5
         return p3, p4, p5
 
 
 
 
+@register_detection_module()
 class NHeads(BaseDetectionModule):
 class NHeads(BaseDetectionModule):
     """
     """
     Apply N heads in parallel and combine predictions into the shape expected by SG detection losses
     Apply N heads in parallel and combine predictions into the shape expected by SG detection losses
@@ -197,6 +201,7 @@ class MultiOutputBackbone(BaseDetectionModule):
         return self.multi_output_backbone(x)
         return self.multi_output_backbone(x)
 
 
 
 
+@register_detection_module()
 class MobileNetV1Backbone(MultiOutputBackbone):
 class MobileNetV1Backbone(MultiOutputBackbone):
     """MobileNetV1 backbone with an option to return output of any layer"""
     """MobileNetV1 backbone with an option to return output of any layer"""
 
 
@@ -205,6 +210,7 @@ class MobileNetV1Backbone(MultiOutputBackbone):
         super().__init__(in_channels, backbone, out_layers)
         super().__init__(in_channels, backbone, out_layers)
 
 
 
 
+@register_detection_module()
 class MobileNetV2Backbone(MultiOutputBackbone):
 class MobileNetV2Backbone(MultiOutputBackbone):
     """MobileNetV2 backbone with an option to return output of any layer"""
     """MobileNetV2 backbone with an option to return output of any layer"""
 
 
@@ -254,6 +260,7 @@ class SSDNeck(BaseDetectionModule, ABC):
         return outputs
         return outputs
 
 
 
 
+@register_detection_module()
 class SSDInvertedResidualNeck(SSDNeck):
 class SSDInvertedResidualNeck(SSDNeck):
     """
     """
     Consecutive InvertedResidual blocks each starting with stride 2
     Consecutive InvertedResidual blocks each starting with stride 2
@@ -268,6 +275,7 @@ class SSDInvertedResidualNeck(SSDNeck):
         return neck_blocks
         return neck_blocks
 
 
 
 
+@register_detection_module()
 class SSDBottleneckNeck(SSDNeck):
 class SSDBottleneckNeck(SSDNeck):
     """
     """
     Consecutive bottleneck blocks
     Consecutive bottleneck blocks
@@ -305,6 +313,7 @@ def SeperableConv2d(in_channels: int, out_channels: int, kernel_size: int = 1, s
     )
     )
 
 
 
 
+@register_detection_module()
 class SSDHead(BaseDetectionModule):
 class SSDHead(BaseDetectionModule):
     """
     """
     A one-layer conv head attached to each input feature map.
     A one-layer conv head attached to each input feature map.
Discard
@@ -6,8 +6,10 @@ from super_gradients.common.factories.activations_type_factory import Activation
 from torch import nn, Tensor
 from torch import nn, Tensor
 
 
 from super_gradients.modules.detection_modules import BaseDetectionModule
 from super_gradients.modules.detection_modules import BaseDetectionModule
+from super_gradients.common.registry.registry import register_detection_module
 
 
 
 
+@register_detection_module()
 class LightweightDEKRHead(BaseDetectionModule):
 class LightweightDEKRHead(BaseDetectionModule):
     """
     """
     Prediction head for pose estimation task that mimics approach from DEKR (https://arxiv.org/abs/2104.02300) paper,
     Prediction head for pose estimation task that mimics approach from DEKR (https://arxiv.org/abs/2104.02300) paper,
Discard
@@ -1,7 +1,6 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
 import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
 from super_gradients.training.datasets import datasets_utils, DataAugmentation
 from super_gradients.training.datasets import datasets_utils, DataAugmentation
-from super_gradients.training.models import ARCHITECTURES
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.kd_trainer import KDTrainer
 from super_gradients.training.kd_trainer import KDTrainer
 from super_gradients.training.qat_trainer import QATTrainer
 from super_gradients.training.qat_trainer import QATTrainer
@@ -11,7 +10,6 @@ __all__ = [
     "distributed_training_utils",
     "distributed_training_utils",
     "datasets_utils",
     "datasets_utils",
     "DataAugmentation",
     "DataAugmentation",
-    "ARCHITECTURES",
     "Trainer",
     "Trainer",
     "KDTrainer",
     "KDTrainer",
     "QATTrainer",
     "QATTrainer",
Discard
@@ -2,12 +2,16 @@ from typing import Dict
 
 
 import hydra
 import hydra
 import numpy as np
 import numpy as np
-import super_gradients
 import torch
 import torch
+from torch.utils.data import BatchSampler, DataLoader, TensorDataset
+
+import super_gradients
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.registry.registry import register_dataloader, ALL_DATALOADERS
 from super_gradients.common.factories.collate_functions_factory import CollateFunctionsFactory
 from super_gradients.common.factories.collate_functions_factory import CollateFunctionsFactory
 from super_gradients.common.factories.datasets_factory import DatasetsFactory
 from super_gradients.common.factories.datasets_factory import DatasetsFactory
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.common.factories.samplers_factory import SamplersFactory
+from super_gradients.common.object_names import Dataloaders
 from super_gradients.training.datasets import ImageNetDataset
 from super_gradients.training.datasets import ImageNetDataset
 from super_gradients.training.datasets.classification_datasets.cifar import (
 from super_gradients.training.datasets.classification_datasets.cifar import (
     Cifar10,
     Cifar10,
@@ -34,7 +38,7 @@ from super_gradients.training.utils.distributed_training_utils import (
 )
 )
 from super_gradients.training.utils.utils import override_default_params_without_nones
 from super_gradients.training.utils.utils import override_default_params_without_nones
 from super_gradients.common.environment.cfg_utils import load_dataset_params
 from super_gradients.common.environment.cfg_utils import load_dataset_params
-from torch.utils.data import BatchSampler, DataLoader, TensorDataset
+
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -131,6 +135,7 @@ def _instantiate_sampler(dataset, dataloader_params):
     return dataloader_params
     return dataloader_params
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_TRAIN)
 def coco2017_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_dataset_params",
         config_name="coco_detection_dataset_params",
@@ -141,6 +146,7 @@ def coco2017_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_VAL)
 def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_dataset_params",
         config_name="coco_detection_dataset_params",
@@ -151,6 +157,7 @@ def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_TRAIN_DECIYOLO)
 def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_deci_yolo_dataset_params",
         config_name="coco_detection_deci_yolo_dataset_params",
@@ -161,6 +168,7 @@ def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dic
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_VAL_DECIYOLO)
 def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_deci_yolo_dataset_params",
         config_name="coco_detection_deci_yolo_dataset_params",
@@ -171,6 +179,7 @@ def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_TRAIN_PPYOLOE)
 def coco2017_train_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_train_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_ppyoloe_dataset_params",
         config_name="coco_detection_ppyoloe_dataset_params",
@@ -181,6 +190,7 @@ def coco2017_train_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_VAL_PPYOLOE)
 def coco2017_val_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_val_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_ppyoloe_dataset_params",
         config_name="coco_detection_ppyoloe_dataset_params",
@@ -191,14 +201,17 @@ def coco2017_val_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict =
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_TRAIN_YOLOX)
 def coco2017_train_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_train_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
     return coco2017_train(dataset_params, dataloader_params)
     return coco2017_train(dataset_params, dataloader_params)
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_VAL_YOLOX)
 def coco2017_val_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_val_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
     return coco2017_val(dataset_params, dataloader_params)
     return coco2017_val(dataset_params, dataloader_params)
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_TRAIN_SSD_LITE_MOBILENET_V2)
 def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
         config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
@@ -209,6 +222,7 @@ def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_VAL_SSD_LITE_MOBILENET_V2)
 def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
         config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
@@ -219,6 +233,7 @@ def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_p
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_TRAIN)
 def imagenet_train(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
 def imagenet_train(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
     return get_data_loader(
     return get_data_loader(
         config_name=config_name,
         config_name=config_name,
@@ -229,6 +244,7 @@ def imagenet_train(dataset_params=None, dataloader_params=None, config_name="ima
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_VAL)
 def imagenet_val(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
 def imagenet_val(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
     return get_data_loader(
     return get_data_loader(
         config_name=config_name,
         config_name=config_name,
@@ -239,6 +255,7 @@ def imagenet_val(dataset_params=None, dataloader_params=None, config_name="image
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_EFFICIENTNET_TRAIN)
 def imagenet_efficientnet_train(dataset_params=None, dataloader_params=None):
 def imagenet_efficientnet_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -247,6 +264,7 @@ def imagenet_efficientnet_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_EFFICIENTNET_VAL)
 def imagenet_efficientnet_val(dataset_params=None, dataloader_params=None):
 def imagenet_efficientnet_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -255,6 +273,7 @@ def imagenet_efficientnet_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_MOBILENETV2_TRAIN)
 def imagenet_mobilenetv2_train(dataset_params=None, dataloader_params=None):
 def imagenet_mobilenetv2_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -263,6 +282,7 @@ def imagenet_mobilenetv2_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_MOBILENETV2_VAL)
 def imagenet_mobilenetv2_val(dataset_params=None, dataloader_params=None):
 def imagenet_mobilenetv2_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -271,6 +291,7 @@ def imagenet_mobilenetv2_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_MOBILENETV3_TRAIN)
 def imagenet_mobilenetv3_train(dataset_params=None, dataloader_params=None):
 def imagenet_mobilenetv3_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -279,6 +300,7 @@ def imagenet_mobilenetv3_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_MOBILENETV3_VAL)
 def imagenet_mobilenetv3_val(dataset_params=None, dataloader_params=None):
 def imagenet_mobilenetv3_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -287,14 +309,17 @@ def imagenet_mobilenetv3_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_REGNETY_TRAIN)
 def imagenet_regnetY_train(dataset_params=None, dataloader_params=None):
 def imagenet_regnetY_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_REGNETY_VAL)
 def imagenet_regnetY_val(dataset_params=None, dataloader_params=None):
 def imagenet_regnetY_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_RESNET50_TRAIN)
 def imagenet_resnet50_train(dataset_params=None, dataloader_params=None):
 def imagenet_resnet50_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -303,6 +328,7 @@ def imagenet_resnet50_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_RESNET50_VAL)
 def imagenet_resnet50_val(dataset_params=None, dataloader_params=None):
 def imagenet_resnet50_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -311,6 +337,7 @@ def imagenet_resnet50_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_RESNET50_KD_TRAIN)
 def imagenet_resnet50_kd_train(dataset_params=None, dataloader_params=None):
 def imagenet_resnet50_kd_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -319,6 +346,7 @@ def imagenet_resnet50_kd_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_RESNET50_KD_VAL)
 def imagenet_resnet50_kd_val(dataset_params=None, dataloader_params=None):
 def imagenet_resnet50_kd_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -327,6 +355,7 @@ def imagenet_resnet50_kd_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_VIT_BASE_TRAIN)
 def imagenet_vit_base_train(dataset_params=None, dataloader_params=None):
 def imagenet_vit_base_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(
     return imagenet_train(
         dataset_params,
         dataset_params,
@@ -335,6 +364,7 @@ def imagenet_vit_base_train(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.IMAGENET_VIT_BASE_VAL)
 def imagenet_vit_base_val(dataset_params=None, dataloader_params=None):
 def imagenet_vit_base_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(
     return imagenet_val(
         dataset_params,
         dataset_params,
@@ -343,6 +373,7 @@ def imagenet_vit_base_val(dataset_params=None, dataloader_params=None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.TINY_IMAGENET_TRAIN)
 def tiny_imagenet_train(
 def tiny_imagenet_train(
     dataset_params=None,
     dataset_params=None,
     dataloader_params=None,
     dataloader_params=None,
@@ -357,6 +388,7 @@ def tiny_imagenet_train(
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.TINY_IMAGENET_VAL)
 def tiny_imagenet_val(
 def tiny_imagenet_val(
     dataset_params=None,
     dataset_params=None,
     dataloader_params=None,
     dataloader_params=None,
@@ -371,6 +403,7 @@ def tiny_imagenet_val(
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CIFAR10_TRAIN)
 def cifar10_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cifar10_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cifar10_dataset_params",
         config_name="cifar10_dataset_params",
@@ -381,6 +414,7 @@ def cifar10_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CIFAR10_VAL)
 def cifar10_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cifar10_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cifar10_dataset_params",
         config_name="cifar10_dataset_params",
@@ -391,6 +425,7 @@ def cifar10_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CIFAR100_TRAIN)
 def cifar100_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cifar100_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cifar100_dataset_params",
         config_name="cifar100_dataset_params",
@@ -401,6 +436,7 @@ def cifar100_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CIFAR100_VAL)
 def cifar100_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cifar100_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cifar100_dataset_params",
         config_name="cifar100_dataset_params",
@@ -435,6 +471,7 @@ def segmentation_test_dataloader(batch_size: int = 5, image_size: int = 512, dat
     return DataLoader(dataset=dataset, batch_size=batch_size)
     return DataLoader(dataset=dataset, batch_size=batch_size)
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_TRAIN)
 def cityscapes_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_dataset_params",
         config_name="cityscapes_dataset_params",
@@ -445,6 +482,7 @@ def cityscapes_train(dataset_params: Dict = None, dataloader_params: Dict = None
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_VAL)
 def cityscapes_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_dataset_params",
         config_name="cityscapes_dataset_params",
@@ -455,6 +493,7 @@ def cityscapes_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_STDC_SEG50_TRAIN)
 def cityscapes_stdc_seg50_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_stdc_seg50_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_stdc_seg50_dataset_params",
         config_name="cityscapes_stdc_seg50_dataset_params",
@@ -465,6 +504,7 @@ def cityscapes_stdc_seg50_train(dataset_params: Dict = None, dataloader_params:
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_STDC_SEG50_VAL)
 def cityscapes_stdc_seg50_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_stdc_seg50_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_stdc_seg50_dataset_params",
         config_name="cityscapes_stdc_seg50_dataset_params",
@@ -475,6 +515,7 @@ def cityscapes_stdc_seg50_val(dataset_params: Dict = None, dataloader_params: Di
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_STDC_SEG75_TRAIN)
 def cityscapes_stdc_seg75_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_stdc_seg75_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_stdc_seg75_dataset_params",
         config_name="cityscapes_stdc_seg75_dataset_params",
@@ -485,6 +526,7 @@ def cityscapes_stdc_seg75_train(dataset_params: Dict = None, dataloader_params:
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_STDC_SEG75_VAL)
 def cityscapes_stdc_seg75_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_stdc_seg75_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_stdc_seg75_dataset_params",
         config_name="cityscapes_stdc_seg75_dataset_params",
@@ -495,6 +537,7 @@ def cityscapes_stdc_seg75_val(dataset_params: Dict = None, dataloader_params: Di
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_REGSEG48_TRAIN)
 def cityscapes_regseg48_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_regseg48_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_regseg48_dataset_params",
         config_name="cityscapes_regseg48_dataset_params",
@@ -505,6 +548,7 @@ def cityscapes_regseg48_train(dataset_params: Dict = None, dataloader_params: Di
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_REGSEG48_VAL)
 def cityscapes_regseg48_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_regseg48_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_regseg48_dataset_params",
         config_name="cityscapes_regseg48_dataset_params",
@@ -515,6 +559,7 @@ def cityscapes_regseg48_val(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_DDRNET_TRAIN)
 def cityscapes_ddrnet_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_ddrnet_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_ddrnet_dataset_params",
         config_name="cityscapes_ddrnet_dataset_params",
@@ -525,6 +570,7 @@ def cityscapes_ddrnet_train(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.CITYSCAPES_DDRNET_VAL)
 def cityscapes_ddrnet_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def cityscapes_ddrnet_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="cityscapes_ddrnet_dataset_params",
         config_name="cityscapes_ddrnet_dataset_params",
@@ -535,6 +581,7 @@ def cityscapes_ddrnet_val(dataset_params: Dict = None, dataloader_params: Dict =
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO_SEGMENTATION_TRAIN)
 def coco_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_segmentation_dataset_params",
         config_name="coco_segmentation_dataset_params",
@@ -545,6 +592,7 @@ def coco_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO_SEGMENTATION_VAL)
 def coco_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_segmentation_dataset_params",
         config_name="coco_segmentation_dataset_params",
@@ -555,6 +603,7 @@ def coco_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict =
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_AUG_SEGMENTATION_TRAIN)
 def pascal_aug_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_aug_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="pascal_aug_segmentation_dataset_params",
         config_name="pascal_aug_segmentation_dataset_params",
@@ -565,10 +614,12 @@ def pascal_aug_segmentation_train(dataset_params: Dict = None, dataloader_params
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_AUG_SEGMENTATION_VAL)
 def pascal_aug_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_aug_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return pascal_voc_segmentation_val(dataset_params=dataset_params, dataloader_params=dataloader_params)
     return pascal_voc_segmentation_val(dataset_params=dataset_params, dataloader_params=dataloader_params)
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_VOC_SEGMENTATION_TRAIN)
 def pascal_voc_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_voc_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="pascal_voc_segmentation_dataset_params",
         config_name="pascal_voc_segmentation_dataset_params",
@@ -579,6 +630,7 @@ def pascal_voc_segmentation_train(dataset_params: Dict = None, dataloader_params
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_VOC_SEGMENTATION_VAL)
 def pascal_voc_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_voc_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="pascal_voc_segmentation_dataset_params",
         config_name="pascal_voc_segmentation_dataset_params",
@@ -589,6 +641,7 @@ def pascal_voc_segmentation_val(dataset_params: Dict = None, dataloader_params:
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.SUPERVISELY_PERSONS_TRAIN)
 def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="supervisely_persons_dataset_params",
         config_name="supervisely_persons_dataset_params",
@@ -599,6 +652,7 @@ def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Di
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.SUPERVISELY_PERSONS_VAL)
 def supervisely_persons_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def supervisely_persons_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="supervisely_persons_dataset_params",
         config_name="supervisely_persons_dataset_params",
@@ -609,6 +663,7 @@ def supervisely_persons_val(dataset_params: Dict = None, dataloader_params: Dict
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.MAPILLARY_TRAIN)
 def mapillary_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def mapillary_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="mapillary_dataset_params",
         config_name="mapillary_dataset_params",
@@ -619,6 +674,7 @@ def mapillary_train(dataset_params: Dict = None, dataloader_params: Dict = None)
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.MAPILLARY_VAL)
 def mapillary_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def mapillary_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="mapillary_dataset_params",
         config_name="mapillary_dataset_params",
@@ -629,6 +685,7 @@ def mapillary_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_VOC_DETECTION_TRAIN)
 def pascal_voc_detection_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_voc_detection_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="pascal_voc_detection_dataset_params",
         config_name="pascal_voc_detection_dataset_params",
@@ -639,6 +696,7 @@ def pascal_voc_detection_train(dataset_params: Dict = None, dataloader_params: D
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.PASCAL_VOC_DETECTION_VAL)
 def pascal_voc_detection_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def pascal_voc_detection_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="pascal_voc_detection_dataset_params",
         config_name="pascal_voc_detection_dataset_params",
@@ -649,6 +707,7 @@ def pascal_voc_detection_val(dataset_params: Dict = None, dataloader_params: Dic
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_POSE_TRAIN)
 def coco2017_pose_train(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_pose_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_pose_estimation_dataset_params",
         config_name="coco_pose_estimation_dataset_params",
@@ -659,6 +718,7 @@ def coco2017_pose_train(dataset_params: Dict = None, dataloader_params: Dict = N
     )
     )
 
 
 
 
+@register_dataloader(Dataloaders.COCO2017_POSE_VAL)
 def coco2017_pose_val(dataset_params: Dict = None, dataloader_params: Dict = None):
 def coco2017_pose_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
     return get_data_loader(
         config_name="coco_pose_estimation_dataset_params",
         config_name="coco_pose_estimation_dataset_params",
@@ -669,66 +729,6 @@ def coco2017_pose_val(dataset_params: Dict = None, dataloader_params: Dict = Non
     )
     )
 
 
 
 
-ALL_DATALOADERS = {
-    "coco2017_train": coco2017_train,
-    "coco2017_val": coco2017_val,
-    "coco2017_train_yolox": coco2017_train_yolox,
-    "coco2017_val_yolox": coco2017_val_yolox,
-    "coco2017_train_ppyoloe": coco2017_train_ppyoloe,
-    "coco2017_val_ppyoloe": coco2017_val_ppyoloe,
-    "coco2017_train_ssd_lite_mobilenet_v2": coco2017_train_ssd_lite_mobilenet_v2,
-    "coco2017_val_ssd_lite_mobilenet_v2": coco2017_val_ssd_lite_mobilenet_v2,
-    "coco2017_pose_train": coco2017_pose_train,
-    "coco2017_pose_val": coco2017_pose_val,
-    "coco2017_train_deci_yolo": coco2017_train_deci_yolo,
-    "coco2017_val_deci_yolo": coco2017_val_deci_yolo,
-    "imagenet_train": imagenet_train,
-    "imagenet_val": imagenet_val,
-    "imagenet_efficientnet_train": imagenet_efficientnet_train,
-    "imagenet_efficientnet_val": imagenet_efficientnet_val,
-    "imagenet_mobilenetv2_train": imagenet_mobilenetv2_train,
-    "imagenet_mobilenetv2_val": imagenet_mobilenetv2_val,
-    "imagenet_mobilenetv3_train": imagenet_mobilenetv3_train,
-    "imagenet_mobilenetv3_val": imagenet_mobilenetv3_val,
-    "imagenet_regnetY_train": imagenet_regnetY_train,
-    "imagenet_regnetY_val": imagenet_regnetY_val,
-    "imagenet_resnet50_train": imagenet_resnet50_train,
-    "imagenet_resnet50_val": imagenet_resnet50_val,
-    "imagenet_resnet50_kd_train": imagenet_resnet50_kd_train,
-    "imagenet_resnet50_kd_val": imagenet_resnet50_kd_val,
-    "imagenet_vit_base_train": imagenet_vit_base_train,
-    "imagenet_vit_base_val": imagenet_vit_base_val,
-    "tiny_imagenet_train": tiny_imagenet_train,
-    "tiny_imagenet_val": tiny_imagenet_val,
-    "cifar10_train": cifar10_train,
-    "cifar10_val": cifar10_val,
-    "cifar100_train": cifar100_train,
-    "cifar100_val": cifar100_val,
-    "cityscapes_train": cityscapes_train,
-    "cityscapes_val": cityscapes_val,
-    "cityscapes_stdc_seg50_train": cityscapes_stdc_seg50_train,
-    "cityscapes_stdc_seg50_val": cityscapes_stdc_seg50_val,
-    "cityscapes_stdc_seg75_train": cityscapes_stdc_seg75_train,
-    "cityscapes_stdc_seg75_val": cityscapes_stdc_seg75_val,
-    "cityscapes_regseg48_train": cityscapes_regseg48_train,
-    "cityscapes_regseg48_val": cityscapes_regseg48_val,
-    "cityscapes_ddrnet_train": cityscapes_ddrnet_train,
-    "cityscapes_ddrnet_val": cityscapes_ddrnet_val,
-    "coco_segmentation_train": coco_segmentation_train,
-    "coco_segmentation_val": coco_segmentation_val,
-    "mapillary_train": mapillary_train,
-    "mapillary_val": mapillary_val,
-    "pascal_aug_segmentation_train": pascal_aug_segmentation_train,
-    "pascal_aug_segmentation_val": pascal_aug_segmentation_val,
-    "pascal_voc_segmentation_train": pascal_voc_segmentation_train,
-    "pascal_voc_segmentation_val": pascal_voc_segmentation_val,
-    "supervisely_persons_train": supervisely_persons_train,
-    "supervisely_persons_val": supervisely_persons_val,
-    "pascal_voc_detection_train": pascal_voc_detection_train,
-    "pascal_voc_detection_val": pascal_voc_detection_val,
-}
-
-
 def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader:
 def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader:
     """
     """
     Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
     Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
  1. from super_gradients.training.datasets.datasets_utils import ComposedCollateFunction, MultiScaleCollateFunction
  2. from super_gradients.training.datasets.mixup import CollateMixup
  3. from super_gradients.training.datasets.pose_estimation_datasets import KeypointsCollate
  4. from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
  5. ALL_COLLATE_FUNCTIONS = {
  6. "ComposedCollateFunction": ComposedCollateFunction,
  7. "MultiScaleCollateFunction": MultiScaleCollateFunction,
  8. "CollateMixup": CollateMixup,
  9. "KeypointsCollate": KeypointsCollate,
  10. "DetectionCollateFN": DetectionCollateFN,
  11. "CrowdDetectionCollateFN": CrowdDetectionCollateFN,
  12. }
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  1. from super_gradients.training.datasets.classification_datasets import Cifar10, Cifar100, ImageNetDataset
  2. from super_gradients.training.datasets.detection_datasets import COCODetectionDataset, DetectionDataset, PascalVOCDetectionDataset
  3. from super_gradients.training.datasets.segmentation_datasets import (
  4. SegmentationDataSet,
  5. CoCoSegmentationDataSet,
  6. PascalAUG2012SegmentationDataSet,
  7. PascalVOC2012SegmentationDataSet,
  8. CityscapesDataset,
  9. SuperviselyPersonsDataset,
  10. PascalVOCAndAUGUnifiedDataset,
  11. MapillaryDataset,
  12. )
  13. from super_gradients.training.datasets.pose_estimation_datasets import COCOKeypointsDataset
  14. ALL_DATASETS = {
  15. "Cifar10": Cifar10,
  16. "Cifar100": Cifar100,
  17. "ImageNetDataset": ImageNetDataset,
  18. "COCODetectionDataset": COCODetectionDataset,
  19. "DetectionDataset": DetectionDataset,
  20. "PascalVOCDetectionDataset": PascalVOCDetectionDataset,
  21. "SegmentationDataSet": SegmentationDataSet,
  22. "CoCoSegmentationDataSet": CoCoSegmentationDataSet,
  23. "PascalAUG2012SegmentationDataSet": PascalAUG2012SegmentationDataSet,
  24. "PascalVOC2012SegmentationDataSet": PascalVOC2012SegmentationDataSet,
  25. "CityscapesDataset": CityscapesDataset,
  26. "MapillaryDataset": MapillaryDataset,
  27. "SuperviselyPersonsDataset": SuperviselyPersonsDataset,
  28. "PascalVOCAndAUGUnifiedDataset": PascalVOCAndAUGUnifiedDataset,
  29. "COCOKeypointsDataset": COCOKeypointsDataset,
  30. }
Discard
1
2
3
  1. from super_gradients.training.datasets.pose_estimation_datasets.target_generators import DEKRTargetsGenerator
  2. ALL_TARGET_GENERATORS = {"DEKRTargetsGenerator": DEKRTargetsGenerator}
Discard
@@ -15,6 +15,8 @@ from typing import List
 from PIL import Image, ImageOps, ImageEnhance
 from PIL import Image, ImageOps, ImageEnhance
 import numpy as np
 import numpy as np
 
 
+from super_gradients.common.object_names import Transforms
+from super_gradients.common.registry.registry import register_transform
 
 
 _FILL = (128, 128, 128)
 _FILL = (128, 128, 128)
 
 
@@ -393,6 +395,7 @@ class RandAugment:
         return img
         return img
 
 
 
 
+@register_transform(Transforms.RandAugmentTransform)
 def rand_augment_transform(config_str, crop_size: int, img_mean: List[float]):
 def rand_augment_transform(config_str, crop_size: int, img_mean: List[float]):
     """
     """
     Create a RandAugment transform
     Create a RandAugment transform
Discard
@@ -1,12 +1,15 @@
 from typing import Optional, Callable, Union
 from typing import Optional, Callable, Union
 
 
+from torchvision.datasets import CIFAR10, CIFAR100
 from torchvision.transforms import Compose
 from torchvision.transforms import Compose
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from torchvision.datasets import CIFAR10, CIFAR100
 
 
 
 
+@register_dataset(Datasets.CIFAR_10)
 class Cifar10(CIFAR10):
 class Cifar10(CIFAR10):
     """
     """
     CIFAR10 Dataset
     CIFAR10 Dataset
@@ -41,6 +44,7 @@ class Cifar10(CIFAR10):
         )
         )
 
 
 
 
+@register_dataset(Datasets.CIFAR_100)
 class Cifar100(CIFAR100):
 class Cifar100(CIFAR100):
     @resolve_param("transforms", TransformsFactory())
     @resolve_param("transforms", TransformsFactory())
     def __init__(
     def __init__(
Discard
@@ -3,10 +3,13 @@ from typing import Union
 import torchvision.datasets as torch_datasets
 import torchvision.datasets as torch_datasets
 from torchvision.transforms import Compose
 from torchvision.transforms import Compose
 
 
+from super_gradients.common.registry.registry import register_dataset
+from super_gradients.common.object_names import Datasets
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 
 
 
 
+@register_dataset(Datasets.IMAGENET_DATASET)
 class ImageNetDataset(torch_datasets.ImageFolder):
 class ImageNetDataset(torch_datasets.ImageFolder):
     """ImageNetDataset dataset.
     """ImageNetDataset dataset.
 
 
Discard
@@ -2,6 +2,9 @@ import numpy as np
 import torch
 import torch
 from torchvision.transforms import RandomErasing
 from torchvision.transforms import RandomErasing
 
 
+from super_gradients.common.object_names import Transforms
+from super_gradients.common.registry.registry import register_transform
+
 
 
 class DataAugmentation:
 class DataAugmentation:
     @staticmethod
     @staticmethod
@@ -68,6 +71,7 @@ IMAGENET_PCA = {
 }
 }
 
 
 
 
+@register_transform(Transforms.Lighting)
 class Lighting(object):
 class Lighting(object):
     """
     """
     Lighting noise(AlexNet - style PCA - based noise)
     Lighting noise(AlexNet - style PCA - based noise)
@@ -92,6 +96,7 @@ class Lighting(object):
         return img.add(rgb.view(3, 1, 1).expand_as(img))
         return img.add(rgb.view(3, 1, 1).expand_as(img))
 
 
 
 
+@register_transform(Transforms.RandomErase)
 class RandomErase(RandomErasing):
 class RandomErase(RandomErasing):
     """
     """
     A simple class that translates the parameters supported in SuperGradient's code base
     A simple class that translates the parameters supported in SuperGradient's code base
Discard
@@ -20,6 +20,8 @@ from torchvision.transforms import transforms, InterpolationMode, RandomResizedC
 from tqdm import tqdm
 from tqdm import tqdm
 
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.object_names import Callbacks, Transforms
+from super_gradients.common.registry.registry import register_collate_function, register_callback, register_transform
 from super_gradients.training.datasets.auto_augment import rand_augment_transform
 from super_gradients.training.datasets.auto_augment import rand_augment_transform
 from super_gradients.training.utils.detection_utils import DetectionVisualization, Anchors
 from super_gradients.training.utils.detection_utils import DetectionVisualization, Anchors
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
@@ -102,6 +104,7 @@ class AbstractCollateFunction(ABC):
         pass
         pass
 
 
 
 
+@register_collate_function()
 class ComposedCollateFunction(AbstractCollateFunction):
 class ComposedCollateFunction(AbstractCollateFunction):
     """
     """
     A function (for torch DataLoader) which executes a sequence of sub collate functions
     A function (for torch DataLoader) which executes a sequence of sub collate functions
@@ -127,6 +130,7 @@ class AtomicInteger:
         return self._value.value
         return self._value.value
 
 
 
 
+@register_collate_function()
 class MultiScaleCollateFunction(AbstractCollateFunction):
 class MultiScaleCollateFunction(AbstractCollateFunction):
     """
     """
     a collate function to implement multi-scale data augmentation
     a collate function to implement multi-scale data augmentation
@@ -269,6 +273,7 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
         return inputs, targets
         return inputs, targets
 
 
 
 
+@register_callback(Callbacks.DETECTION_MULTISCALE_PREPREDICTION)
 class DetectionMultiscalePrePredictionCallback(MultiscalePrePredictionCallback):
 class DetectionMultiscalePrePredictionCallback(MultiscalePrePredictionCallback):
     """
     """
     Mutiscalepre-prediction callback for object detection.
     Mutiscalepre-prediction callback for object detection.
@@ -335,6 +340,7 @@ def _pil_interp(method):
 _RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
 _RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
 
 
 
 
+@register_transform(Transforms.RandomResizedCropAndInterpolation)
 class RandomResizedCropAndInterpolation(RandomResizedCrop):
 class RandomResizedCropAndInterpolation(RandomResizedCrop):
     """
     """
     Crop the given PIL Image to random size and aspect ratio with explicitly chosen or random interpolation.
     Crop the given PIL Image to random size and aspect ratio with explicitly chosen or random interpolation.
Discard
@@ -7,6 +7,8 @@ from pycocotools.coco import COCO
 
 
 from contextlib import redirect_stdout
 from contextlib import redirect_stdout
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
 from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
@@ -14,6 +16,7 @@ from super_gradients.training.datasets.data_formats.default_formats import XYXY_
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
+@register_dataset(Datasets.COCO_DETECTION_DATASET)
 class COCODetectionDataset(DetectionDataset):
 class COCODetectionDataset(DetectionDataset):
     """Dataset for COCO object detection.
     """Dataset for COCO object detection.
 
 
Discard
@@ -13,6 +13,8 @@ import numpy as np
 from tqdm import tqdm
 from tqdm import tqdm
 from torch.utils.data import Dataset
 from torch.utils.data import Dataset
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.training.utils.detection_utils import get_cls_posx_in_target
 from super_gradients.training.utils.detection_utils import get_cls_posx_in_target
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -26,6 +28,7 @@ from super_gradients.training.datasets.data_formats.formats import ConcatenatedT
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
+@register_dataset(Datasets.DETECTION_DATASET)
 class DetectionDataset(Dataset):
 class DetectionDataset(Dataset):
     """Detection dataset.
     """Detection dataset.
 
 
Discard
@@ -9,6 +9,8 @@ from tqdm import tqdm
 
 
 import numpy as np
 import numpy as np
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.transforms.transforms import DetectionTransform
 from super_gradients.training.transforms.transforms import DetectionTransform
 from super_gradients.training.utils.utils import download_and_untar_from_url, get_image_size_from_path
 from super_gradients.training.utils.utils import download_and_untar_from_url, get_image_size_from_path
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
@@ -19,6 +21,7 @@ from super_gradients.training.datasets.datasets_conf import PASCAL_VOC_2012_CLAS
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
+@register_dataset(Datasets.PASCAL_VOC_DETECTION_DATASET)
 class PascalVOCDetectionDataset(DetectionDataset):
 class PascalVOCDetectionDataset(DetectionDataset):
     """Dataset for Pascal VOC object detection
     """Dataset for Pascal VOC object detection
 
 
Discard
@@ -15,6 +15,7 @@ from typing import List, Union
 import numpy as np
 import numpy as np
 import torch
 import torch
 
 
+from super_gradients.common.registry.registry import register_collate_function
 from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
 from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
 
 
 
 
@@ -99,6 +100,7 @@ def cutmix_bbox_and_lam(img_shape: tuple, lam: float, ratio_minmax: Union[tuple,
     return (yl, yu, xl, xu), lam
     return (yl, yu, xl, xu), lam
 
 
 
 
+@register_collate_function()
 class CollateMixup:
 class CollateMixup:
     """
     """
     Collate with Mixup/Cutmix that applies different params to each element or whole batch
     Collate with Mixup/Cutmix that applies different params to each element or whole batch
Discard
@@ -6,6 +6,7 @@ import torch
 from torch.utils.data import default_collate, Dataset
 from torch.utils.data import default_collate, Dataset
 
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.registry.registry import register_collate_function
 from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator
 from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator
 from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform
 from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform
 
 
@@ -95,6 +96,7 @@ class BaseKeypointsDataset(Dataset):
         return joints
         return joints
 
 
 
 
+@register_collate_function()
 class KeypointsCollate:
 class KeypointsCollate:
     """
     """
     Collate image & targets, return extras as is.
     Collate image & targets, return extras as is.
Discard
@@ -8,6 +8,8 @@ from pycocotools.coco import COCO
 from torch import Tensor
 from torch import Tensor
 
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.target_generator_factory import TargetGeneratorsFactory
 from super_gradients.common.factories.target_generator_factory import TargetGeneratorsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
@@ -17,6 +19,7 @@ from super_gradients.training.transforms.keypoint_transforms import KeypointTran
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
+@register_dataset(Datasets.COCO_KEY_POINTS_DATASET)
 class COCOKeypointsDataset(BaseKeypointsDataset):
 class COCOKeypointsDataset(BaseKeypointsDataset):
     """
     """
     Dataset class for training pose estimation models on COCO Keypoints dataset.
     Dataset class for training pose estimation models on COCO Keypoints dataset.
Discard
@@ -5,6 +5,8 @@ import cv2
 import numpy as np
 import numpy as np
 from torch import Tensor
 from torch import Tensor
 
 
+from super_gradients.common.registry.registry import register_target_generator
+
 __all__ = ["KeypointsTargetsGenerator", "DEKRTargetsGenerator"]
 __all__ = ["KeypointsTargetsGenerator", "DEKRTargetsGenerator"]
 
 
 
 
@@ -24,6 +26,7 @@ class KeypointsTargetsGenerator:
         raise NotImplementedError()
         raise NotImplementedError()
 
 
 
 
+@register_target_generator()
 class DEKRTargetsGenerator(KeypointsTargetsGenerator):
 class DEKRTargetsGenerator(KeypointsTargetsGenerator):
     """
     """
     Target generator for pose estimation task tailored for the DEKR paper (https://arxiv.org/abs/2104.02300)
     Target generator for pose estimation task tailored for the DEKR paper (https://arxiv.org/abs/2104.02300)
Discard
@@ -1,6 +1,7 @@
 from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
 from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
-from super_gradients.training.datasets.samplers.all_samplers import SAMPLERS, Samplers
+from super_gradients.common.object_names import Samplers
+from super_gradients.common.registry.registry import SAMPLERS
 
 
 
 
 __all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler"]
 __all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler"]
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  1. from super_gradients.common.object_names import Samplers
  2. from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
  3. from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
  4. from torch.utils.data.distributed import DistributedSampler
  5. from torch.utils.data.sampler import SequentialSampler, SubsetRandomSampler, RandomSampler, WeightedRandomSampler
  6. SAMPLERS = {
  7. Samplers.INFINITE: InfiniteSampler,
  8. Samplers.REPEAT_AUG: RepeatAugSampler,
  9. Samplers.DISTRIBUTED: DistributedSampler,
  10. Samplers.SEQUENTIAL: SequentialSampler,
  11. Samplers.SUBSET_RANDOM: SubsetRandomSampler,
  12. Samplers.RANDOM: RandomSampler,
  13. Samplers.WEIGHTED_RANDOM: WeightedRandomSampler,
  14. }
Discard
@@ -9,7 +9,11 @@ import torch.distributed as dist
 from torch.utils.data.sampler import Sampler
 from torch.utils.data.sampler import Sampler
 from deprecate import deprecated
 from deprecate import deprecated
 
 
+from super_gradients.common.object_names import Samplers
+from super_gradients.common.registry.registry import register_sampler
 
 
+
+@register_sampler(Samplers.INFINITE)
 class InfiniteSampler(Sampler):
 class InfiniteSampler(Sampler):
     """
     """
     In training, we only care about the "infinite stream" of training data.
     In training, we only care about the "infinite stream" of training data.
Discard
@@ -3,10 +3,13 @@ import torch
 from torch.utils.data import Sampler
 from torch.utils.data import Sampler
 import torch.distributed as dist
 import torch.distributed as dist
 
 
+from super_gradients.common.object_names import Samplers
+from super_gradients.common.registry.registry import register_sampler
 
 
 # TODO: Add unit test for RepeatAugSampler once DDP unit tests are supported.
 # TODO: Add unit test for RepeatAugSampler once DDP unit tests are supported.
 
 
 
 
+@register_sampler(Samplers.REPEAT_AUG)
 class RepeatAugSampler(Sampler):
 class RepeatAugSampler(Sampler):
     """
     """
     Sampler that restricts data loading to a subset of the dataset for distributed,
     Sampler that restricts data loading to a subset of the dataset for distributed,
Discard
@@ -2,6 +2,9 @@ import os
 import cv2
 import cv2
 import numpy as np
 import numpy as np
 from PIL import Image, ImageColor
 from PIL import Image, ImageColor
+
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 
 
 # TODO - ADD COARSE DATA - right now cityscapes dataset includes fine annotations. It's optional to use extra coarse
 # TODO - ADD COARSE DATA - right now cityscapes dataset includes fine annotations. It's optional to use extra coarse
@@ -11,6 +14,7 @@ from super_gradients.training.datasets.segmentation_datasets.segmentation_datase
 CITYSCAPES_IGNORE_LABEL = 19
 CITYSCAPES_IGNORE_LABEL = 19
 
 
 
 
+@register_dataset(Datasets.CITYSCAPES_DATASET)
 class CityscapesDataset(SegmentationDataSet):
 class CityscapesDataset(SegmentationDataSet):
     """
     """
     CityscapesDataset - Segmentation Data Set Class for Cityscapes Segmentation Data Set,
     CityscapesDataset - Segmentation Data Set Class for Cityscapes Segmentation Data Set,
Discard
@@ -11,6 +11,8 @@ try:
 except ModuleNotFoundError as ex:
 except ModuleNotFoundError as ex:
     print("[WARNING]" + str(ex))
     print("[WARNING]" + str(ex))
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.datasets_conf import COCO_DEFAULT_CLASSES_TUPLES_LIST
 from super_gradients.training.datasets.datasets_conf import COCO_DEFAULT_CLASSES_TUPLES_LIST
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 
 
@@ -19,6 +21,7 @@ class EmptyCoCoClassesSelectionException(Exception):
     pass
     pass
 
 
 
 
+@register_dataset(Datasets.COCO_SEGMENTATION_DATASET)
 class CoCoSegmentationDataSet(SegmentationDataSet):
 class CoCoSegmentationDataSet(SegmentationDataSet):
     """
     """
     Segmentation Data Set Class for COCO 2017 Segmentation Data Set
     Segmentation Data Set Class for COCO 2017 Segmentation Data Set
Discard
@@ -4,9 +4,12 @@ import os
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 
 
 
 
+@register_dataset(Datasets.MAPILLARY_DATASET)
 class MapillaryDataset(SegmentationDataSet):
 class MapillaryDataset(SegmentationDataSet):
     """
     """
     Mapillary Vistas is a large-scale urban street-view dataset.
     Mapillary Vistas is a large-scale urban street-view dataset.
Discard
@@ -5,6 +5,8 @@ import scipy.io
 from PIL import Image
 from PIL import Image
 from torch.utils.data import ConcatDataset
 from torch.utils.data import ConcatDataset
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
@@ -35,6 +37,7 @@ PASCAL_VOC_2012_CLASSES = [
 ]
 ]
 
 
 
 
+@register_dataset(Datasets.PASCAL_VOC_2012_SEGMENTATION_DATASET)
 class PascalVOC2012SegmentationDataSet(SegmentationDataSet):
 class PascalVOC2012SegmentationDataSet(SegmentationDataSet):
     """
     """
     Segmentation Data Set Class for Pascal VOC 2012 Data Set.
     Segmentation Data Set Class for Pascal VOC 2012 Data Set.
@@ -173,6 +176,7 @@ class PascalVOC2012SegmentationDataSet(SegmentationDataSet):
         )
         )
 
 
 
 
+@register_dataset(Datasets.PASCAL_AUG_2012_SEGMENTATION_DATASET)
 class PascalAUG2012SegmentationDataSet(PascalVOC2012SegmentationDataSet):
 class PascalAUG2012SegmentationDataSet(PascalVOC2012SegmentationDataSet):
     """
     """
     Segmentation Data Set Class for Pascal AUG 2012 Data Set
     Segmentation Data Set Class for Pascal AUG 2012 Data Set
@@ -218,6 +222,7 @@ class PascalAUG2012SegmentationDataSet(PascalVOC2012SegmentationDataSet):
         return Image.fromarray(mask)
         return Image.fromarray(mask)
 
 
 
 
+@register_dataset(Datasets.PASCAL_VOC_AND_AUG_UNIFIED_DATASET)
 class PascalVOCAndAUGUnifiedDataset(ConcatDataset):
 class PascalVOCAndAUGUnifiedDataset(ConcatDataset):
     """
     """
     Pascal VOC + AUG train dataset, aka `SBD` dataset contributed in "Semantic contours from inverse detectors".
     Pascal VOC + AUG train dataset, aka `SBD` dataset contributed in "Semantic contours from inverse detectors".
Discard
@@ -7,11 +7,14 @@ import torchvision.transforms as transform
 from PIL import Image
 from PIL import Image
 from tqdm import tqdm
 from tqdm import tqdm
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.training.datasets.sg_dataset import DirectoryDataSet, ListDataset
 from super_gradients.training.datasets.sg_dataset import DirectoryDataSet, ListDataset
 
 
 
 
+@register_dataset(Datasets.SEGMENTATION_DATASET)
 class SegmentationDataSet(DirectoryDataSet, ListDataset):
 class SegmentationDataSet(DirectoryDataSet, ListDataset):
     @resolve_param("transforms", factory=TransformsFactory())
     @resolve_param("transforms", factory=TransformsFactory())
     def __init__(
     def __init__(
Discard
@@ -1,9 +1,12 @@
 import csv
 import csv
 import os
 import os
 
 
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 
 
 
 
+@register_dataset(Datasets.SUPERVISELY_PERSONS_DATASET)
 class SuperviselyPersonsDataset(SegmentationDataSet):
 class SuperviselyPersonsDataset(SegmentationDataSet):
     """
     """
     SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,
     SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,
Discard
@@ -18,7 +18,7 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import (
     UndefinedNumClassesException,
     UndefinedNumClassesException,
 )
 )
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
-from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
+from super_gradients.common.registry.registry import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.sg_trainer import Trainer
Discard
@@ -1,6 +1,3 @@
-from super_gradients.common.registry.registry import LOSSES
-from super_gradients.common.object_names import Losses
-
 from super_gradients.training.losses.focal_loss import FocalLoss
 from super_gradients.training.losses.focal_loss import FocalLoss
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.losses.label_smoothing_cross_entropy_loss import LabelSmoothingCrossEntropyLoss
 from super_gradients.training.losses.label_smoothing_cross_entropy_loss import LabelSmoothingCrossEntropyLoss
@@ -13,6 +10,10 @@ from super_gradients.training.losses.bce_dice_loss import BCEDiceLoss
 from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss
 from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss
 from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
 from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
 from super_gradients.training.losses.dekr_loss import DEKRLoss
 from super_gradients.training.losses.dekr_loss import DEKRLoss
+from super_gradients.training.losses.stdc_loss import STDCLoss
+
+from super_gradients.common.object_names import Losses
+from super_gradients.common.registry.registry import LOSSES
 
 
 __all__ = [
 __all__ = [
     "LOSSES",
     "LOSSES",
@@ -30,4 +31,5 @@ __all__ = [
     "DiceCEEdgeLoss",
     "DiceCEEdgeLoss",
     "PPYoloELoss",
     "PPYoloELoss",
     "DEKRLoss",
     "DEKRLoss",
+    "STDCLoss",
 ]
 ]
Discard
    Discard
    @@ -1,10 +1,11 @@
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
    -from super_gradients.training.losses.dice_loss import DiceLoss, BinaryDiceLoss
    -from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
     from typing import Union, Tuple
     from typing import Union, Tuple
     
     
    +from super_gradients.training.losses.dice_loss import DiceLoss, BinaryDiceLoss
    +from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
    +
     from super_gradients.common.object_names import Losses
     from super_gradients.common.object_names import Losses
     from super_gradients.common.registry.registry import register_loss
     from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.losses.mask_loss import MaskAttentionLoss
     from super_gradients.training.losses.mask_loss import MaskAttentionLoss
    Discard
    @@ -4,11 +4,9 @@ import torch
     import torch.nn as nn
     import torch.nn as nn
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
     
     
    -from super_gradients.training.utils import convert_to_tensor
    -
    -
     from super_gradients.common.object_names import Losses
     from super_gradients.common.object_names import Losses
     from super_gradients.common.registry.registry import register_loss
     from super_gradients.common.registry.registry import register_loss
    +from super_gradients.training.utils import convert_to_tensor
     
     
     
     
     @register_loss(Losses.R_SQUARED_LOSS)
     @register_loss(Losses.R_SQUARED_LOSS)
    Discard
    @@ -2,7 +2,6 @@ import torch
     from torch import nn
     from torch import nn
     from torch.autograd import Variable
     from torch.autograd import Variable
     
     
    -
     from super_gradients.common.object_names import Losses
     from super_gradients.common.object_names import Losses
     from super_gradients.common.registry.registry import register_loss
     from super_gradients.common.registry.registry import register_loss
     
     
    Discard
    @@ -4,7 +4,8 @@ from super_gradients.training.metrics.classification_metrics import accuracy, Ac
     from super_gradients.training.metrics.detection_metrics import DetectionMetrics, DetectionMetrics_050, DetectionMetrics_075, DetectionMetrics_050_095
     from super_gradients.training.metrics.detection_metrics import DetectionMetrics, DetectionMetrics_050, DetectionMetrics_075, DetectionMetrics_050_095
     from super_gradients.training.metrics.segmentation_metrics import PreprocessSegmentationMetricsArgs, PixelAccuracy, IoU, Dice, BinaryIOU, BinaryDice
     from super_gradients.training.metrics.segmentation_metrics import PreprocessSegmentationMetricsArgs, PixelAccuracy, IoU, Dice, BinaryIOU, BinaryDice
     from super_gradients.training.metrics.pose_estimation_metrics import PoseEstimationMetrics
     from super_gradients.training.metrics.pose_estimation_metrics import PoseEstimationMetrics
    -from super_gradients.training.metrics.all_metrics import METRICS, Metrics
    +from super_gradients.common.object_names import Metrics
    +from super_gradients.common.registry.registry import METRICS
     
     
     __all__ = [
     __all__ = [
         "METRICS",
         "METRICS",
    Discard
    @@ -1,31 +0,0 @@
    -from super_gradients.common.object_names import Metrics
    -from super_gradients.training.metrics import (
    -    Accuracy,
    -    Top5,
    -    DetectionMetrics,
    -    IoU,
    -    PixelAccuracy,
    -    BinaryIOU,
    -    Dice,
    -    BinaryDice,
    -    DetectionMetrics_050,
    -    DetectionMetrics_075,
    -    DetectionMetrics_050_095,
    -    PoseEstimationMetrics,
    -)
    -
    -
    -METRICS = {
    -    Metrics.ACCURACY: Accuracy,
    -    Metrics.TOP5: Top5,
    -    Metrics.DETECTION_METRICS: DetectionMetrics,
    -    Metrics.DETECTION_METRICS_050: DetectionMetrics_050,
    -    Metrics.DETECTION_METRICS_075: DetectionMetrics_075,
    -    Metrics.DETECTION_METRICS_050_095: DetectionMetrics_050_095,
    -    Metrics.IOU: IoU,
    -    Metrics.BINARY_IOU: BinaryIOU,
    -    Metrics.DICE: Dice,
    -    Metrics.BINARY_DICE: BinaryDice,
    -    Metrics.PIXEL_ACCURACY: PixelAccuracy,
    -    Metrics.POSE_ESTIMATION_METRICS: PoseEstimationMetrics,
    -}
    Discard
    @@ -1,8 +1,11 @@
    -from super_gradients.training.utils import convert_to_tensor
     import torch
     import torch
     import torchmetrics
     import torchmetrics
     from torchmetrics import Metric
     from torchmetrics import Metric
     
     
    +from super_gradients.common.object_names import Metrics
    +from super_gradients.common.registry.registry import register_metric
    +from super_gradients.training.utils import convert_to_tensor
    +
     
     
     def accuracy(output, target, topk=(1,)):
     def accuracy(output, target, topk=(1,)):
         """Computes the precision@k for the specified values of k
         """Computes the precision@k for the specified values of k
    @@ -34,6 +37,7 @@ def accuracy(output, target, topk=(1,)):
         return res
         return res
     
     
     
     
    +@register_metric(Metrics.ACCURACY)
     class Accuracy(torchmetrics.Accuracy):
     class Accuracy(torchmetrics.Accuracy):
         def __init__(self, dist_sync_on_step=False):
         def __init__(self, dist_sync_on_step=False):
             super().__init__(dist_sync_on_step=dist_sync_on_step)
             super().__init__(dist_sync_on_step=dist_sync_on_step)
    @@ -45,6 +49,7 @@ class Accuracy(torchmetrics.Accuracy):
             super().update(preds=preds.argmax(1), target=target)
             super().update(preds=preds.argmax(1), target=target)
     
     
     
     
    +@register_metric(Metrics.TOP5)
     class Top5(Metric):
     class Top5(Metric):
         def __init__(self, dist_sync_on_step=False):
         def __init__(self, dist_sync_on_step=False):
             super().__init__(dist_sync_on_step=dist_sync_on_step)
             super().__init__(dist_sync_on_step=dist_sync_on_step)
    Discard
    @@ -1,7 +1,10 @@
     from typing import Dict, Optional, Union
     from typing import Dict, Optional, Union
     import torch
     import torch
     from torchmetrics import Metric
     from torchmetrics import Metric
    +
     import super_gradients
     import super_gradients
    +from super_gradients.common.object_names import Metrics
    +from super_gradients.common.registry.registry import register_metric
     from super_gradients.training.utils import tensor_container_to_device
     from super_gradients.training.utils import tensor_container_to_device
     from super_gradients.training.utils.detection_utils import compute_detection_matching, compute_detection_metrics
     from super_gradients.training.utils.detection_utils import compute_detection_matching, compute_detection_metrics
     from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback, IouThreshold
     from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback, IouThreshold
    @@ -10,6 +13,7 @@ from super_gradients.common.abstractions.abstract_logger import get_logger
     logger = get_logger(__name__)
     logger = get_logger(__name__)
     
     
     
     
    +@register_metric(Metrics.DETECTION_METRICS)
     class DetectionMetrics(Metric):
     class DetectionMetrics(Metric):
         """
         """
         DetectionMetrics
         DetectionMetrics
    @@ -176,6 +180,7 @@ class DetectionMetrics(Metric):
             return "@%.2f" % self.iou_thresholds[0] if not len(self.iou_thresholds) > 1 else "@%.2f:%.2f" % (self.iou_thresholds[0], self.iou_thresholds[-1])
             return "@%.2f" % self.iou_thresholds[0] if not len(self.iou_thresholds) > 1 else "@%.2f:%.2f" % (self.iou_thresholds[0], self.iou_thresholds[-1])
     
     
     
     
    +@register_metric(Metrics.DETECTION_METRICS_050)
     class DetectionMetrics_050(DetectionMetrics):
     class DetectionMetrics_050(DetectionMetrics):
         def __init__(
         def __init__(
             self,
             self,
    @@ -202,6 +207,7 @@ class DetectionMetrics_050(DetectionMetrics):
             )
             )
     
     
     
     
    +@register_metric(Metrics.DETECTION_METRICS_075)
     class DetectionMetrics_075(DetectionMetrics):
     class DetectionMetrics_075(DetectionMetrics):
         def __init__(
         def __init__(
             self,
             self,
    @@ -220,6 +226,7 @@ class DetectionMetrics_075(DetectionMetrics):
             )
             )
     
     
     
     
    +@register_metric(Metrics.DETECTION_METRICS_050_095)
     class DetectionMetrics_050_095(DetectionMetrics):
     class DetectionMetrics_050_095(DetectionMetrics):
         def __init__(
         def __init__(
             self,
             self,
    Discard
    @@ -8,6 +8,8 @@ from torchmetrics import Metric
     
     
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.environment.ddp_utils import is_distributed
     from super_gradients.common.environment.ddp_utils import is_distributed
    +from super_gradients.common.object_names import Metrics
    +from super_gradients.common.registry.registry import register_metric
     from super_gradients.training.metrics.pose_estimation_utils import compute_img_keypoint_matching, compute_visible_bbox_xywh
     from super_gradients.training.metrics.pose_estimation_utils import compute_img_keypoint_matching, compute_visible_bbox_xywh
     from super_gradients.training.utils.detection_utils import compute_detection_metrics_per_cls
     from super_gradients.training.utils.detection_utils import compute_detection_metrics_per_cls
     
     
    @@ -16,6 +18,7 @@ logger = get_logger(__name__)
     __all__ = ["PoseEstimationMetrics"]
     __all__ = ["PoseEstimationMetrics"]
     
     
     
     
    +@register_metric(Metrics.POSE_ESTIMATION_METRICS)
     class PoseEstimationMetrics(Metric):
     class PoseEstimationMetrics(Metric):
         """
         """
         Implementation of COCO Keypoint evaluation metric.
         Implementation of COCO Keypoint evaluation metric.
    Discard
    @@ -7,6 +7,10 @@ from torchmetrics.utilities.distributed import reduce
     from abc import ABC, abstractmethod
     from abc import ABC, abstractmethod
     
     
     
     
    +from super_gradients.common.object_names import Metrics
    +from super_gradients.common.registry.registry import register_metric
    +
    +
     def batch_pix_accuracy(predict, target):
     def batch_pix_accuracy(predict, target):
         """Batch Pixel Accuracy
         """Batch Pixel Accuracy
         Args:
         Args:
    @@ -160,6 +164,7 @@ class PreprocessSegmentationMetricsArgs(AbstractMetricsArgsPrepFn):
             return preds, target
             return preds, target
     
     
     
     
    +@register_metric(Metrics.PIXEL_ACCURACY)
     class PixelAccuracy(Metric):
     class PixelAccuracy(Metric):
         def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
         def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
             super().__init__(dist_sync_on_step=dist_sync_on_step)
             super().__init__(dist_sync_on_step=dist_sync_on_step)
    @@ -185,6 +190,7 @@ class PixelAccuracy(Metric):
             return pix_acc
             return pix_acc
     
     
     
     
    +@register_metric(Metrics.IOU)
     class IoU(torchmetrics.JaccardIndex):
     class IoU(torchmetrics.JaccardIndex):
         def __init__(
         def __init__(
             self,
             self,
    @@ -208,6 +214,7 @@ class IoU(torchmetrics.JaccardIndex):
             super().update(preds=preds, target=target)
             super().update(preds=preds, target=target)
     
     
     
     
    +@register_metric(Metrics.DICE)
     class Dice(torchmetrics.JaccardIndex):
     class Dice(torchmetrics.JaccardIndex):
         def __init__(
         def __init__(
             self,
             self,
    @@ -235,6 +242,7 @@ class Dice(torchmetrics.JaccardIndex):
             return _dice_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)
             return _dice_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)
     
     
     
     
    +@register_metric(Metrics.BINARY_IOU)
     class BinaryIOU(IoU):
     class BinaryIOU(IoU):
         def __init__(
         def __init__(
             self,
             self,
    @@ -264,6 +272,7 @@ class BinaryIOU(IoU):
             return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()}
             return {"target_IOU": ious[1], "background_IOU": ious[0], "mean_IOU": ious.mean()}
     
     
     
     
    +@register_metric(Metrics.BINARY_DICE)
     class BinaryDice(Dice):
     class BinaryDice(Dice):
         def __init__(
         def __init__(
             self,
             self,
    Discard
    @@ -67,6 +67,7 @@ from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import
     from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base
     from super_gradients.training.models.detection_models.darknet53 import Darknet53, Darknet53Base
     from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
     from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
     from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback
     from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloPostPredictionCallback
    +from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     
     
     # Segmentation models
     # Segmentation models
    @@ -108,7 +109,10 @@ import super_gradients.training.models.user_models as user_models
     from super_gradients.training.models.model_factory import get
     from super_gradients.training.models.model_factory import get
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.conversion import convert_to_onnx, convert_from_config
     from super_gradients.training.models.conversion import convert_to_onnx, convert_from_config
    -from super_gradients.training.models.all_architectures import ARCHITECTURES, Models
    +
    +
    +from super_gradients.common.object_names import Models
    +from super_gradients.common.registry.registry import ARCHITECTURES
     
     
     __all__ = [
     __all__ = [
         "SgModule",
         "SgModule",
    @@ -198,6 +202,13 @@ __all__ = [
         "SSDMobileNetV1",
         "SSDMobileNetV1",
         "SSDLiteMobileNetV2",
         "SSDLiteMobileNetV2",
         "YoloBase",
         "YoloBase",
    +    "YoloX_N",
    +    "YoloX_T",
    +    "YoloX_S",
    +    "YoloX_M",
    +    "YoloX_L",
    +    "YoloX_X",
    +    "CustomYoloX",
         "YoloPostPredictionCallback",
         "YoloPostPredictionCallback",
         "CustomizableDetector",
         "CustomizableDetector",
         "ShelfNet50",
         "ShelfNet50",
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    1. from super_gradients.common.object_names import Models
    2. from super_gradients.training.models import ResNeXt50, ResNeXt101, GoogleNetV1
    3. from super_gradients.training.models.classification_models import repvgg, efficientnet, densenet, resnet, regnet
    4. from super_gradients.training.models.classification_models.mobilenetv2 import MobileNetV2Base, MobileNetV2_135, CustomMobileNetV2
    5. from super_gradients.training.models.classification_models.mobilenetv3 import mobilenetv3_large, mobilenetv3_small, mobilenetv3_custom
    6. from super_gradients.training.models.classification_models.shufflenetv2 import (
    7. ShufflenetV2_x0_5,
    8. ShufflenetV2_x1_0,
    9. ShufflenetV2_x1_5,
    10. ShufflenetV2_x2_0,
    11. CustomizedShuffleNetV2,
    12. )
    13. from super_gradients.training.models.classification_models.vit import ViTBase, ViTLarge, ViTHuge
    14. from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53
    15. from super_gradients.training.models.detection_models.darknet53 import Darknet53
    16. from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE_M, PPYoloE_L, PPYoloE_X, PPYoloE_S
    17. from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
    18. from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X, CustomYoloX
    19. from super_gradients.training.models.pose_estimation_models.dekr_hrnet import DEKRPoseEstimationModel, DEKRW32
    20. from super_gradients.training.models.pose_estimation_models.pose_ddrnet39 import PoseDDRNet39
    21. from super_gradients.training.models.pose_estimation_models.pose_ppyolo import PosePPYoloL
    22. from super_gradients.training.models.segmentation_models.ddrnet import DDRNet23, DDRNet23Slim, AnyBackBoneDDRNet23, DDRNet39
    23. from super_gradients.training.models.segmentation_models.regseg import RegSeg48
    24. from super_gradients.training.models.segmentation_models.shelfnet import ShelfNet18_LW, ShelfNet34_LW, ShelfNet50, ShelfNet503343, ShelfNet101
    25. from super_gradients.training.models.segmentation_models.stdc import (
    26. STDC1Classification,
    27. STDC2Classification,
    28. STDC1Seg,
    29. STDC2Seg,
    30. CustomSTDCSegmentation,
    31. STDCClassification,
    32. )
    33. from super_gradients.training.models.kd_modules.kd_module import KDModule
    34. from super_gradients.training.models.classification_models.beit import BeitBasePatch16_224, BeitLargePatch16_224
    35. from super_gradients.training.models.segmentation_models.ppliteseg import PPLiteSegT, PPLiteSegB
    36. from super_gradients.training.models.segmentation_models.unet import UNetCustom, UnetClassificationCustom, UNet
    37. ARCHITECTURES = {
    38. Models.RESNET18: resnet.ResNet18,
    39. Models.RESNET34: resnet.ResNet34,
    40. Models.RESNET50_3343: resnet.ResNet50_3343,
    41. Models.RESNET50: resnet.ResNet50,
    42. Models.RESNET101: resnet.ResNet101,
    43. Models.RESNET152: resnet.ResNet152,
    44. Models.RESNET18_CIFAR: resnet.ResNet18Cifar,
    45. Models.CUSTOM_RESNET: resnet.CustomizedResnet,
    46. Models.CUSTOM_RESNET50: resnet.CustomizedResnet50,
    47. Models.CUSTOM_RESNET_CIFAR: resnet.CustomizedResnetCifar,
    48. Models.CUSTOM_RESNET50_CIFAR: resnet.CustomizedResnet50Cifar,
    49. Models.MOBILENET_V2: MobileNetV2Base,
    50. Models.MOBILE_NET_V2_135: MobileNetV2_135,
    51. Models.CUSTOM_MOBILENET_V2: CustomMobileNetV2,
    52. Models.MOBILENET_V3_LARGE: mobilenetv3_large,
    53. Models.MOBILENET_V3_SMALL: mobilenetv3_small,
    54. Models.MOBILENET_V3_CUSTOM: mobilenetv3_custom,
    55. Models.CUSTOM_DENSENET: densenet.CustomizedDensnet,
    56. Models.DENSENET121: densenet.DenseNet121,
    57. Models.DENSENET161: densenet.DenseNet161,
    58. Models.DENSENET169: densenet.DenseNet169,
    59. Models.DENSENET201: densenet.DenseNet201,
    60. Models.SHELFNET18_LW: ShelfNet18_LW,
    61. Models.SHELFNET34_LW: ShelfNet34_LW,
    62. Models.SHELFNET50_3343: ShelfNet503343,
    63. Models.SHELFNET50: ShelfNet50,
    64. Models.SHELFNET101: ShelfNet101,
    65. Models.SHUFFLENET_V2_X0_5: ShufflenetV2_x0_5,
    66. Models.SHUFFLENET_V2_X1_0: ShufflenetV2_x1_0,
    67. Models.SHUFFLENET_V2_X1_5: ShufflenetV2_x1_5,
    68. Models.SHUFFLENET_V2_X2_0: ShufflenetV2_x2_0,
    69. Models.SHUFFLENET_V2_CUSTOM5: CustomizedShuffleNetV2,
    70. Models.DARKNET53: Darknet53,
    71. Models.CSP_DARKNET53: CSPDarknet53,
    72. Models.RESNEXT50: ResNeXt50,
    73. Models.RESNEXT101: ResNeXt101,
    74. Models.GOOGLENET_V1: GoogleNetV1,
    75. Models.EFFICIENTNET_B0: efficientnet.EfficientNetB0,
    76. Models.EFFICIENTNET_B1: efficientnet.EfficientNetB1,
    77. Models.EFFICIENTNET_B2: efficientnet.EfficientNetB2,
    78. Models.EFFICIENTNET_B3: efficientnet.EfficientNetB3,
    79. Models.EFFICIENTNET_B4: efficientnet.EfficientNetB4,
    80. Models.EFFICIENTNET_B5: efficientnet.EfficientNetB5,
    81. Models.EFFICIENTNET_B6: efficientnet.EfficientNetB6,
    82. Models.EFFICIENTNET_B7: efficientnet.EfficientNetB7,
    83. Models.EFFICIENTNET_B8: efficientnet.EfficientNetB8,
    84. Models.EFFICIENTNET_L2: efficientnet.EfficientNetL2,
    85. Models.CUSTOMIZEDEFFICIENTNET: efficientnet.CustomizedEfficientnet,
    86. Models.REGNETY200: regnet.RegNetY200,
    87. Models.REGNETY400: regnet.RegNetY400,
    88. Models.REGNETY600: regnet.RegNetY600,
    89. Models.REGNETY800: regnet.RegNetY800,
    90. Models.CUSTOM_REGNET: regnet.CustomRegNet,
    91. Models.NAS_REGNET: regnet.NASRegNet,
    92. Models.YOLOX_N: YoloX_N,
    93. Models.YOLOX_T: YoloX_T,
    94. Models.YOLOX_S: YoloX_S,
    95. Models.YOLOX_M: YoloX_M,
    96. Models.YOLOX_L: YoloX_L,
    97. Models.YOLOX_X: YoloX_X,
    98. Models.CUSTOM_YOLO_X: CustomYoloX,
    99. Models.SSD_MOBILENET_V1: SSDMobileNetV1,
    100. Models.SSD_LITE_MOBILENET_V2: SSDLiteMobileNetV2,
    101. Models.REPVGG_A0: repvgg.RepVggA0,
    102. Models.REPVGG_A1: repvgg.RepVggA1,
    103. Models.REPVGG_A2: repvgg.RepVggA2,
    104. Models.REPVGG_B0: repvgg.RepVggB0,
    105. Models.REPVGG_B1: repvgg.RepVggB1,
    106. Models.REPVGG_B2: repvgg.RepVggB2,
    107. Models.REPVGG_B3: repvgg.RepVggB3,
    108. Models.REPVGG_D2SE: repvgg.RepVggD2SE,
    109. Models.REPVGG_CUSTOM: repvgg.RepVggCustom,
    110. Models.DDRNET_23: DDRNet23,
    111. Models.DDRNET_23_SLIM: DDRNet23Slim,
    112. Models.DDRNET_39: DDRNet39,
    113. Models.CUSTOM_DDRNET_23: AnyBackBoneDDRNet23,
    114. Models.STDC1_CLASSIFICATION: STDC1Classification,
    115. Models.STDC2_CLASSIFICATION: STDC2Classification,
    116. Models.STDC1_SEG: STDC1Seg,
    117. Models.STDC1_SEG50: STDC1Seg,
    118. Models.STDC1_SEG75: STDC1Seg,
    119. Models.STDC2_SEG: STDC2Seg,
    120. Models.STDC2_SEG50: STDC2Seg,
    121. Models.STDC2_SEG75: STDC2Seg,
    122. Models.STDC_CUSTOM: CustomSTDCSegmentation,
    123. Models.STDC_CUSTOM_CLS: STDCClassification,
    124. Models.REGSEG48: RegSeg48,
    125. Models.KD_MODULE: KDModule,
    126. Models.VIT_BASE: ViTBase,
    127. Models.VIT_LARGE: ViTLarge,
    128. Models.VIT_HUGE: ViTHuge,
    129. Models.BEIT_BASE_PATCH16_224: BeitBasePatch16_224,
    130. Models.BEIT_LARGE_PATCH16_224: BeitLargePatch16_224,
    131. Models.PP_LITE_T_SEG: PPLiteSegT,
    132. Models.PP_LITE_T_SEG50: PPLiteSegT,
    133. Models.PP_LITE_T_SEG75: PPLiteSegT,
    134. Models.PP_LITE_B_SEG: PPLiteSegB,
    135. Models.PP_LITE_B_SEG50: PPLiteSegB,
    136. Models.PP_LITE_B_SEG75: PPLiteSegB,
    137. Models.CUSTOM_ANYNET: regnet.CustomAnyNet,
    138. Models.UNET_CUSTOM: UNetCustom,
    139. Models.UNET_CUSTOM_CLS: UnetClassificationCustom,
    140. Models.UNET: UNet,
    141. Models.PP_YOLOE_S: PPYoloE_S,
    142. Models.PP_YOLOE_M: PPYoloE_M,
    143. Models.PP_YOLOE_L: PPYoloE_L,
    144. Models.PP_YOLOE_X: PPYoloE_X,
    145. #
    146. Models.DEKR_CUSTOM: DEKRPoseEstimationModel,
    147. Models.DEKR_W32_NO_DC: DEKRW32,
    148. Models.POSE_PP_YOLO_L: PosePPYoloL,
    149. Models.POSE_DDRNET_39: PoseDDRNet39,
    150. }
    151. KD_ARCHITECTURES = {Models.KD_MODULE: KDModule}
    Discard
    @@ -27,6 +27,9 @@ import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
     from torch.utils.checkpoint import checkpoint
     from torch.utils.checkpoint import checkpoint
     from torch import Tensor
     from torch import Tensor
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.classification_models.vit import PatchEmbed
     from super_gradients.training.models.classification_models.vit import PatchEmbed
     from super_gradients.training.utils.regularization_utils import DropPath
     from super_gradients.training.utils.regularization_utils import DropPath
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
    @@ -451,6 +454,7 @@ class Beit(SgModule):
                 self.head = nn.Linear(self.head.in_features, new_num_classes)
                 self.head = nn.Linear(self.head.in_features, new_num_classes)
     
     
     
     
    +@register_model(Models.BEIT_BASE_PATCH16_224)
     class BeitBasePatch16_224(Beit):
     class BeitBasePatch16_224(Beit):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             model_kwargs = HpmStruct(
             model_kwargs = HpmStruct(
    @@ -460,6 +464,7 @@ class BeitBasePatch16_224(Beit):
             super(BeitBasePatch16_224, self).__init__(**model_kwargs.to_dict())
             super(BeitBasePatch16_224, self).__init__(**model_kwargs.to_dict())
     
     
     
     
    +@register_model(Models.BEIT_LARGE_PATCH16_224)
     class BeitLargePatch16_224(Beit):
     class BeitLargePatch16_224(Beit):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             model_kwargs = HpmStruct(
             model_kwargs = HpmStruct(
    Discard
    @@ -3,6 +3,9 @@ import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
     from torch import Tensor
     from torch import Tensor
     from collections import OrderedDict
     from collections import OrderedDict
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     
     
     """Densenet-BC model class, based on
     """Densenet-BC model class, based on
    @@ -130,6 +133,7 @@ class DenseNet(SgModule):
             return out
             return out
     
     
     
     
    +@register_model(Models.CUSTOM_DENSENET)
     class CustomizedDensnet(DenseNet):
     class CustomizedDensnet(DenseNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(
             super().__init__(
    @@ -142,21 +146,25 @@ class CustomizedDensnet(DenseNet):
             )
             )
     
     
     
     
    +@register_model(Models.DENSENET121)
     class DenseNet121(DenseNet):
     class DenseNet121(DenseNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(32, [6, 12, 24, 16], 64, 4, 0, arch_params.num_classes)
             super().__init__(32, [6, 12, 24, 16], 64, 4, 0, arch_params.num_classes)
     
     
     
     
    +@register_model(Models.DENSENET161)
     class DenseNet161(DenseNet):
     class DenseNet161(DenseNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(48, [6, 12, 36, 24], 96, 4, 0, arch_params.num_classes)
             super().__init__(48, [6, 12, 36, 24], 96, 4, 0, arch_params.num_classes)
     
     
     
     
    +@register_model(Models.DENSENET169)
     class DenseNet169(DenseNet):
     class DenseNet169(DenseNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(32, [6, 12, 32, 32], 64, 4, 0, arch_params.num_classes)
             super().__init__(32, [6, 12, 32, 32], 64, 4, 0, arch_params.num_classes)
     
     
     
     
    +@register_model(Models.DENSENET201)
     class DenseNet201(DenseNet):
     class DenseNet201(DenseNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(32, [6, 12, 48, 32], 64, 4, 0, arch_params.num_classes)
             super().__init__(32, [6, 12, 48, 32], 64, 4, 0, arch_params.num_classes)
    Discard
    @@ -22,6 +22,9 @@ import torch
     from torch import nn
     from torch import nn
     from torch.nn import functional as F
     from torch.nn import functional as F
     from collections import OrderedDict
     from collections import OrderedDict
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     
     
    @@ -592,66 +595,77 @@ def get_efficientnet_params(width: float, depth: float, res: float, dropout: flo
         return blocks_args, arch_params_new
         return blocks_args, arch_params_new
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B0)
     class EfficientNetB0(EfficientNet):
     class EfficientNetB0(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.0, res=224, dropout=0.2, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.0, res=224, dropout=0.2, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B1)
     class EfficientNetB1(EfficientNet):
     class EfficientNetB1(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.1, res=240, dropout=0.2, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.1, res=240, dropout=0.2, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B2)
     class EfficientNetB2(EfficientNet):
     class EfficientNetB2(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.1, depth=1.2, res=260, dropout=0.3, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.1, depth=1.2, res=260, dropout=0.3, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B3)
     class EfficientNetB3(EfficientNet):
     class EfficientNetB3(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.2, depth=1.4, res=300, dropout=0.3, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.2, depth=1.4, res=300, dropout=0.3, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B4)
     class EfficientNetB4(EfficientNet):
     class EfficientNetB4(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.4, depth=1.8, res=380, dropout=0.4, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.4, depth=1.8, res=380, dropout=0.4, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B5)
     class EfficientNetB5(EfficientNet):
     class EfficientNetB5(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.6, depth=2.2, res=456, dropout=0.4, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.6, depth=2.2, res=456, dropout=0.4, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B6)
     class EfficientNetB6(EfficientNet):
     class EfficientNetB6(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=1.8, depth=2.6, res=528, dropout=0.5, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=1.8, depth=2.6, res=528, dropout=0.5, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B7)
     class EfficientNetB7(EfficientNet):
     class EfficientNetB7(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=2.0, depth=3.1, res=600, dropout=0.5, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=2.0, depth=3.1, res=600, dropout=0.5, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_B8)
     class EfficientNetB8(EfficientNet):
     class EfficientNetB8(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=2.2, depth=3.6, res=672, dropout=0.5, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=2.2, depth=3.6, res=672, dropout=0.5, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.EFFICIENTNET_L2)
     class EfficientNetL2(EfficientNet):
     class EfficientNetL2(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(width=4.3, depth=5.3, res=800, dropout=0.5, arch_params=arch_params)
             blocks_args, arch_params = get_efficientnet_params(width=4.3, depth=5.3, res=800, dropout=0.5, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
             super().__init__(blocks_args=blocks_args, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.CUSTOMIZEDEFFICIENTNET)
     class CustomizedEfficientnet(EfficientNet):
     class CustomizedEfficientnet(EfficientNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             blocks_args, arch_params = get_efficientnet_params(
             blocks_args, arch_params = get_efficientnet_params(
    Discard
    @@ -7,6 +7,9 @@ import torch
     import torch.nn as nn
     import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
     from collections import OrderedDict
     from collections import OrderedDict
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     
     
     GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["log_", "aux_logits2", "aux_logits1"])
     GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["log_", "aux_logits2", "aux_logits1"])
    @@ -232,6 +235,7 @@ class BasicConv2d(nn.Module):
             return x
             return x
     
     
     
     
    +@register_model(Models.GOOGLENET_V1)
     class GoogleNetV1(GoogLeNet):
     class GoogleNetV1(GoogLeNet):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super(GoogleNetV1, self).__init__(aux_logits=False, num_classes=arch_params.num_classes, dropout=arch_params.dropout)
             super(GoogleNetV1, self).__init__(aux_logits=False, num_classes=arch_params.num_classes, dropout=arch_params.dropout)
    Discard
    @@ -12,6 +12,9 @@ import numpy as np
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     import math
     import math
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.utils.utils import get_param
     from super_gradients.training.utils.utils import get_param
     
     
    @@ -186,6 +189,7 @@ class MobileNetV2(MobileNetBase):
                     m.bias.data.zero_()
                     m.bias.data.zero_()
     
     
     
     
    +@register_model(Models.MOBILENET_V2)
     class MobileNetV2Base(MobileNetV2):
     class MobileNetV2Base(MobileNetV2):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """
             """
    @@ -201,6 +205,7 @@ class MobileNetV2Base(MobileNetV2):
             )
             )
     
     
     
     
    +@register_model(Models.MOBILE_NET_V2_135)
     class MobileNetV2_135(MobileNetV2):
     class MobileNetV2_135(MobileNetV2):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """
             """
    @@ -217,6 +222,7 @@ class MobileNetV2_135(MobileNetV2):
             )
             )
     
     
     
     
    +@register_model(Models.CUSTOM_MOBILENET_V2)
     class CustomMobileNetV2(MobileNetV2):
     class CustomMobileNetV2(MobileNetV2):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """
             """
    Discard
    @@ -8,6 +8,9 @@ arXiv preprint arXiv:1905.02244.
     
     
     import torch.nn as nn
     import torch.nn as nn
     import math
     import math
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.classification_models.mobilenetv2 import MobileNetBase
     from super_gradients.training.models.classification_models.mobilenetv2 import MobileNetBase
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils import get_param
     
     
    @@ -171,6 +174,7 @@ class MobileNetV3(MobileNetBase):
                     m.bias.data.zero_()
                     m.bias.data.zero_()
     
     
     
     
    +@register_model(Models.MOBILENET_V3_LARGE)
     class mobilenetv3_large(MobileNetV3):
     class mobilenetv3_large(MobileNetV3):
         """
         """
         Constructs a MobileNetV3-Large model
         Constructs a MobileNetV3-Large model
    @@ -199,6 +203,7 @@ class mobilenetv3_large(MobileNetV3):
             super().__init__(cfgs, mode="large", num_classes=arch_params.num_classes, width_mult=width_mult, in_channels=get_param(arch_params, "in_channels", 3))
             super().__init__(cfgs, mode="large", num_classes=arch_params.num_classes, width_mult=width_mult, in_channels=get_param(arch_params, "in_channels", 3))
     
     
     
     
    +@register_model(Models.MOBILENET_V3_SMALL)
     class mobilenetv3_small(MobileNetV3):
     class mobilenetv3_small(MobileNetV3):
         """
         """
         Constructs a MobileNetV3-Small model
         Constructs a MobileNetV3-Small model
    @@ -223,6 +228,7 @@ class mobilenetv3_small(MobileNetV3):
             super().__init__(cfgs, mode="small", num_classes=arch_params.num_classes, width_mult=width_mult, in_channels=get_param(arch_params, "in_channels", 3))
             super().__init__(cfgs, mode="small", num_classes=arch_params.num_classes, width_mult=width_mult, in_channels=get_param(arch_params, "in_channels", 3))
     
     
     
     
    +@register_model(Models.MOBILENET_V3_CUSTOM)
     class mobilenetv3_custom(MobileNetV3):
     class mobilenetv3_custom(MobileNetV3):
         """
         """
         Constructs a MobileNetV3-Customized model
         Constructs a MobileNetV3-Customized model
    Discard
    @@ -8,6 +8,8 @@ import numpy as np
     import torch.nn as nn
     import torch.nn as nn
     from math import sqrt
     from math import sqrt
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.modules import Residual
     from super_gradients.modules import Residual
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.utils.regularization_utils import DropPath
     from super_gradients.training.utils.regularization_utils import DropPath
    @@ -233,6 +235,7 @@ def verify_correctness_of_parameters(ls_num_blocks, ls_block_width, ls_bottlenec
             assert int(block_width // bottleneck_ratio) % group_width == 0
             assert int(block_width // bottleneck_ratio) % group_width == 0
     
     
     
     
    +@register_model(Models.CUSTOM_REGNET)
     class CustomRegNet(RegNetX):
     class CustomRegNet(RegNetX):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """All parameters must be provided in arch_params other than SE"""
             """All parameters must be provided in arch_params other than SE"""
    @@ -250,6 +253,7 @@ class CustomRegNet(RegNetX):
             )
             )
     
     
     
     
    +@register_model(Models.CUSTOM_ANYNET)
     class CustomAnyNet(AnyNetX):
     class CustomAnyNet(AnyNetX):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """All parameters must be provided in arch_params other than SE"""
             """All parameters must be provided in arch_params other than SE"""
    @@ -268,6 +272,7 @@ class CustomAnyNet(AnyNetX):
             )
             )
     
     
     
     
    +@register_model(Models.NAS_REGNET)
     class NASRegNet(RegNetX):
     class NASRegNet(RegNetX):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             """All parameters are provided as a single structure list: arch_params.structure"""
             """All parameters are provided as a single structure list: arch_params.structure"""
    @@ -285,21 +290,25 @@ class NASRegNet(RegNetX):
             )
             )
     
     
     
     
    +@register_model(Models.REGNETY200)
     class RegNetY200(RegNetY):
     class RegNetY200(RegNetY):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(24, 36, 2.5, 13, 1, 8, 2, arch_params, 4)
             super().__init__(24, 36, 2.5, 13, 1, 8, 2, arch_params, 4)
     
     
     
     
    +@register_model(Models.REGNETY400)
     class RegNetY400(RegNetY):
     class RegNetY400(RegNetY):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(48, 28, 2.1, 16, 1, 8, 2, arch_params, 4)
             super().__init__(48, 28, 2.1, 16, 1, 8, 2, arch_params, 4)
     
     
     
     
    +@register_model(Models.REGNETY600)
     class RegNetY600(RegNetY):
     class RegNetY600(RegNetY):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(48, 33, 2.3, 15, 1, 16, 2, arch_params, 4)
             super().__init__(48, 33, 2.3, 15, 1, 16, 2, arch_params, 4)
     
     
     
     
    +@register_model(Models.REGNETY800)
     class RegNetY800(RegNetY):
     class RegNetY800(RegNetY):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(56, 39, 2.4, 14, 1, 16, 2, arch_params, 4)
             super().__init__(56, 39, 2.4, 14, 1, 16, 2, arch_params, 4)
    Discard
    @@ -12,6 +12,8 @@ from typing import Union
     
     
     import torch.nn as nn
     import torch.nn as nn
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.modules import RepVGGBlock, SEBlock
     from super_gradients.modules import RepVGGBlock, SEBlock
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
     from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
    @@ -129,6 +131,7 @@ class RepVGG(SgModule):
                 self.linear = nn.Linear(int(512 * self.final_width_mult), new_num_classes)
                 self.linear = nn.Linear(int(512 * self.final_width_mult), new_num_classes)
     
     
     
     
    +@register_model(Models.REPVGG_CUSTOM)
     class RepVggCustom(RepVGG):
     class RepVggCustom(RepVGG):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super().__init__(
             super().__init__(
    @@ -142,48 +145,56 @@ class RepVggCustom(RepVGG):
             )
             )
     
     
     
     
    +@register_model(Models.REPVGG_A0)
     class RepVggA0(RepVggCustom):
     class RepVggA0(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[0.75, 0.75, 0.75, 2.5])
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[0.75, 0.75, 0.75, 2.5])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_A1)
     class RepVggA1(RepVggCustom):
     class RepVggA1(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1, 1, 1, 2.5])
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1, 1, 1, 2.5])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_A2)
     class RepVggA2(RepVggCustom):
     class RepVggA2(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1.5, 1.5, 1.5, 2.75])
             arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1.5, 1.5, 1.5, 2.75])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_B0)
     class RepVggB0(RepVggCustom):
     class RepVggB0(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[1, 1, 1, 2.5])
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[1, 1, 1, 2.5])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_B1)
     class RepVggB1(RepVggCustom):
     class RepVggB1(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2, 2, 2, 4])
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2, 2, 2, 4])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_B2)
     class RepVggB2(RepVggCustom):
     class RepVggB2(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_B3)
     class RepVggB3(RepVggCustom):
     class RepVggB3(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[3, 3, 3, 5])
             arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[3, 3, 3, 5])
             super().__init__(arch_params=arch_params)
             super().__init__(arch_params=arch_params)
     
     
     
     
    +@register_model(Models.REPVGG_D2SE)
     class RepVggD2SE(RepVggCustom):
     class RepVggD2SE(RepVggCustom):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             arch_params.override(struct=[8, 14, 24, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
             arch_params.override(struct=[8, 14, 24, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
    Discard
    @@ -17,6 +17,8 @@ from collections import OrderedDict
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils.regularization_utils import DropPath
     from super_gradients.training.utils.regularization_utils import DropPath
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     
     
     
     
     def width_multiplier(original, factor):
     def width_multiplier(original, factor):
    @@ -233,6 +235,7 @@ class ResNet(SgModule):
                 self.linear = nn.Linear(width_multiplier(512, self.width_mult) * self.expansion, new_num_classes)
                 self.linear = nn.Linear(width_multiplier(512, self.width_mult) * self.expansion, new_num_classes)
     
     
     
     
    +@register_model(Models.RESNET18)
     class ResNet18(ResNet):
     class ResNet18(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -244,11 +247,13 @@ class ResNet18(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.RESNET18_CIFAR)
     class ResNet18Cifar(CifarResNet):
     class ResNet18Cifar(CifarResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
             super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.RESNET34)
     class ResNet34(ResNet):
     class ResNet34(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -260,6 +265,7 @@ class ResNet34(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.RESNET50)
     class ResNet50(ResNet):
     class ResNet50(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -272,6 +278,7 @@ class ResNet50(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.RESNET50_3343)
     class ResNet50_3343(ResNet):
     class ResNet50_3343(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -284,6 +291,7 @@ class ResNet50_3343(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.RESNET101)
     class ResNet101(ResNet):
     class ResNet101(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -296,6 +304,7 @@ class ResNet101(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.RESNET152)
     class ResNet152(ResNet):
     class ResNet152(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -308,16 +317,19 @@ class ResNet152(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.CUSTOM_RESNET_CIFAR)
     class CustomizedResnetCifar(CifarResNet):
     class CustomizedResnetCifar(CifarResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes)
             super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.CUSTOM_RESNET50_CIFAR)
     class CustomizedResnet50Cifar(CifarResNet):
     class CustomizedResnet50Cifar(CifarResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes, expansion=4)
             super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes, expansion=4)
     
     
     
     
    +@register_model(Models.CUSTOM_RESNET)
     class CustomizedResnet(ResNet):
     class CustomizedResnet(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    @@ -330,6 +342,7 @@ class CustomizedResnet(ResNet):
             )
             )
     
     
     
     
    +@register_model(Models.CUSTOM_RESNET50)
     class CustomizedResnet50(ResNet):
     class CustomizedResnet50(ResNet):
         def __init__(self, arch_params, num_classes=None):
         def __init__(self, arch_params, num_classes=None):
             super().__init__(
             super().__init__(
    Discard
    @@ -6,6 +6,9 @@ Code adapted from https://github.com/pytorch/vision/blob/master/torchvision/mode
     """
     """
     import torch.nn as nn
     import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     
     
     
     
    @@ -142,11 +145,13 @@ class CustomizedResNeXt(ResNeXt):
             )
             )
     
     
     
     
    +@register_model(Models.RESNEXT50)
     class ResNeXt50(ResNeXt):
     class ResNeXt50(ResNeXt):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super(ResNeXt50, self).__init__(layers=[3, 4, 6, 3], cardinality=32, bottleneck_width=4, num_classes=arch_params.num_classes)
             super(ResNeXt50, self).__init__(layers=[3, 4, 6, 3], cardinality=32, bottleneck_width=4, num_classes=arch_params.num_classes)
     
     
     
     
    +@register_model(Models.RESNEXT101)
     class ResNeXt101(ResNeXt):
     class ResNeXt101(ResNeXt):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             super(ResNeXt101, self).__init__(layers=[3, 4, 23, 3], cardinality=32, bottleneck_width=8, num_classes=arch_params.num_classes)
             super(ResNeXt101, self).__init__(layers=[3, 4, 23, 3], cardinality=32, bottleneck_width=8, num_classes=arch_params.num_classes)
    Discard
    @@ -12,6 +12,8 @@ import torch
     from torch import Tensor
     from torch import Tensor
     import torch.nn as nn
     import torch.nn as nn
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     
     
    @@ -199,26 +201,31 @@ class ShuffleNetV2Base(SgModule):
             return x
             return x
     
     
     
     
    +@register_model(Models.SHUFFLENET_V2_X0_5)
     class ShufflenetV2_x0_5(ShuffleNetV2Base):
     class ShufflenetV2_x0_5(ShuffleNetV2Base):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
             super().__init__([4, 8, 4], [24, 48, 96, 192, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
             super().__init__([4, 8, 4], [24, 48, 96, 192, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.SHUFFLENET_V2_X1_0)
     class ShufflenetV2_x1_0(ShuffleNetV2Base):
     class ShufflenetV2_x1_0(ShuffleNetV2Base):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
             super().__init__([4, 8, 4], [24, 116, 232, 464, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
             super().__init__([4, 8, 4], [24, 116, 232, 464, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.SHUFFLENET_V2_X1_5)
     class ShufflenetV2_x1_5(ShuffleNetV2Base):
     class ShufflenetV2_x1_5(ShuffleNetV2Base):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
             super().__init__([4, 8, 4], [24, 176, 352, 704, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
             super().__init__([4, 8, 4], [24, 176, 352, 704, 1024], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.SHUFFLENET_V2_X2_0)
     class ShufflenetV2_x2_0(ShuffleNetV2Base):
     class ShufflenetV2_x2_0(ShuffleNetV2Base):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
             super().__init__([4, 8, 4], [24, 244, 488, 976, 2048], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
             super().__init__([4, 8, 4], [24, 244, 488, 976, 2048], backbone_mode=backbone_mode, num_classes=num_classes or arch_params.num_classes)
     
     
     
     
    +@register_model(Models.SHUFFLENET_V2_CUSTOM5)
     class CustomizedShuffleNetV2(ShuffleNetV2Base):
     class CustomizedShuffleNetV2(ShuffleNetV2Base):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
         def __init__(self, arch_params: HpmStruct, num_classes: int = 1000, backbone_mode: bool = False):
             super().__init__(
             super().__init__(
    Discard
    @@ -8,9 +8,12 @@ Code adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorc
     
     
     import torch
     import torch
     from torch import nn
     from torch import nn
    +from einops import repeat
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils import get_param
    -from einops import repeat
     
     
     
     
     class PatchEmbed(nn.Module):
     class PatchEmbed(nn.Module):
    @@ -195,6 +198,7 @@ class ViT(SgModule):
                 self.head = nn.Linear(self.head.in_features, new_num_classes)
                 self.head = nn.Linear(self.head.in_features, new_num_classes)
     
     
     
     
    +@register_model(Models.VIT_BASE)
     class ViTBase(ViT):
     class ViTBase(ViT):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
             super(ViTBase, self).__init__(
             super(ViTBase, self).__init__(
    @@ -212,6 +216,7 @@ class ViTBase(ViT):
             )
             )
     
     
     
     
    +@register_model(Models.VIT_LARGE)
     class ViTLarge(ViT):
     class ViTLarge(ViT):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
             super(ViTLarge, self).__init__(
             super(ViTLarge, self).__init__(
    @@ -229,6 +234,7 @@ class ViTLarge(ViT):
             )
             )
     
     
     
     
    +@register_model(Models.VIT_HUGE)
     class ViTHuge(ViT):
     class ViTHuge(ViT):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
         def __init__(self, arch_params, num_classes=None, backbone_mode=None):
             super(ViTHuge, self).__init__(
             super(ViTHuge, self).__init__(
    Discard
    @@ -8,6 +8,8 @@ from typing import Tuple, Type
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.modules import Residual
     from super_gradients.modules import Residual
     from super_gradients.training.utils.utils import get_param, HpmStruct
     from super_gradients.training.utils.utils import get_param, HpmStruct
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
    @@ -184,6 +186,7 @@ class ViewModule(nn.Module):
             return x.view(-1, self.features)
             return x.view(-1, self.features)
     
     
     
     
    +@register_model(Models.CSP_DARKNET53)
     class CSPDarknet53(SgModule):
     class CSPDarknet53(SgModule):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             super().__init__()
             super().__init__()
    Discard
    @@ -4,6 +4,7 @@ from pathlib import Path
     from typing import List, Type, Tuple, Union, Optional
     from typing import List, Type, Tuple, Union, Optional
     
     
     import torch
     import torch
    +from super_gradients.common.registry.registry import register_detection_module
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
     from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
     from torch import nn, Tensor
     from torch import nn, Tensor
    @@ -108,6 +109,7 @@ class CSPResStage(nn.Module):
             return y
             return y
     
     
     
     
    +@register_detection_module()
     class CSPResNetBackbone(nn.Module):
     class CSPResNetBackbone(nn.Module):
         """
         """
         CSPResNet backbone
         CSPResNet backbone
    Discard
    @@ -1,4 +1,7 @@
     from torch import nn
     from torch import nn
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils import get_param
     
     
    @@ -67,6 +70,7 @@ class Darknet53Base(SgModule):
             return nn.Sequential(*layers)
             return nn.Sequential(*layers)
     
     
     
     
    +@register_model(Models.DARKNET53)
     class Darknet53(Darknet53Base):
     class Darknet53(Darknet53Base):
         def __init__(self, arch_params=None, backbone_mode=True, num_classes=None):
         def __init__(self, arch_params=None, backbone_mode=True, num_classes=None):
             super(Darknet53, self).__init__()
             super(Darknet53, self).__init__()
    Discard
    @@ -2,9 +2,11 @@ import collections
     from typing import Type, Tuple, List
     from typing import Type, Tuple, List
     
     
     import torch
     import torch
    +from torch import nn, Tensor
    +
    +from super_gradients.common.registry.registry import register_detection_module
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
     from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
    -from torch import nn, Tensor
     from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBasicBlock
     from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBasicBlock
     from super_gradients.modules import ConvBNAct
     from super_gradients.modules import ConvBNAct
     
     
    @@ -65,6 +67,7 @@ class CSPStage(nn.Module):
             return y
             return y
     
     
     
     
    +@register_detection_module()
     class CustomCSPPAN(nn.Module):
     class CustomCSPPAN(nn.Module):
         @resolve_param("activation", ActivationsTypeFactory())
         @resolve_param("activation", ActivationsTypeFactory())
         def __init__(
         def __init__(
    Discard
    @@ -2,6 +2,8 @@ from typing import Union
     
     
     from torch import Tensor
     from torch import Tensor
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.modules import RepVGGBlock
     from super_gradients.modules import RepVGGBlock
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
     from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
    @@ -48,6 +50,7 @@ class PPYoloE(SgModule):
                 self.head.replace_num_classes(new_num_classes)
                 self.head.replace_num_classes(new_num_classes)
     
     
     
     
    +@register_model(Models.PP_YOLOE_S)
     class PPYoloE_S(PPYoloE):
     class PPYoloE_S(PPYoloE):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             if isinstance(arch_params, HpmStruct):
             if isinstance(arch_params, HpmStruct):
    @@ -56,6 +59,7 @@ class PPYoloE_S(PPYoloE):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.PP_YOLOE_M)
     class PPYoloE_M(PPYoloE):
     class PPYoloE_M(PPYoloE):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             if isinstance(arch_params, HpmStruct):
             if isinstance(arch_params, HpmStruct):
    @@ -64,6 +68,7 @@ class PPYoloE_M(PPYoloE):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.PP_YOLOE_L)
     class PPYoloE_L(PPYoloE):
     class PPYoloE_L(PPYoloE):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             if isinstance(arch_params, HpmStruct):
             if isinstance(arch_params, HpmStruct):
    @@ -72,6 +77,7 @@ class PPYoloE_L(PPYoloE):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.PP_YOLOE_X)
     class PPYoloE_X(PPYoloE):
     class PPYoloE_X(PPYoloE):
         def __init__(self, arch_params):
         def __init__(self, arch_params):
             if isinstance(arch_params, HpmStruct):
             if isinstance(arch_params, HpmStruct):
    Discard
    @@ -3,6 +3,8 @@ from typing import Union
     
     
     from omegaconf import DictConfig
     from omegaconf import DictConfig
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.utils.utils import HpmStruct
     from super_gradients.training.utils.utils import HpmStruct
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
    @@ -12,6 +14,7 @@ DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS = get_arch_params("ssd_mobilenetv1_arch_par
     DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS = get_arch_params("ssd_lite_mobilenetv2_arch_params")
     DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS = get_arch_params("ssd_lite_mobilenetv2_arch_params")
     
     
     
     
    +@register_model(Models.SSD_MOBILENET_V1)
     class SSDMobileNetV1(CustomizableDetector):
     class SSDMobileNetV1(CustomizableDetector):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
             merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS))
             merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_MOBILENET_V1_ARCH_PARAMS))
    @@ -19,6 +22,7 @@ class SSDMobileNetV1(CustomizableDetector):
             super().__init__(merged_arch_params, in_channels=in_channels)
             super().__init__(merged_arch_params, in_channels=in_channels)
     
     
     
     
    +@register_model(Models.SSD_LITE_MOBILENET_V2)
     class SSDLiteMobileNetV2(CustomizableDetector):
     class SSDLiteMobileNetV2(CustomizableDetector):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
             merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS))
             merged_arch_params = HpmStruct(**copy.deepcopy(DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS))
    Discard
    @@ -1,7 +1,10 @@
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloDarknetBackbone
     from super_gradients.training.models.detection_models.yolo_base import YoloBase, YoloDarknetBackbone
     from super_gradients.training.utils.utils import HpmStruct
     from super_gradients.training.utils.utils import HpmStruct
     
     
     
     
    +@register_model(Models.YOLOX_N)
     class YoloX_N(YoloBase):
     class YoloX_N(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 0.33
             arch_params.depth_mult_factor = 0.33
    @@ -11,6 +14,7 @@ class YoloX_N(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.YOLOX_T)
     class YoloX_T(YoloBase):
     class YoloX_T(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 0.33
             arch_params.depth_mult_factor = 0.33
    @@ -19,6 +23,7 @@ class YoloX_T(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.YOLOX_S)
     class YoloX_S(YoloBase):
     class YoloX_S(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 0.33
             arch_params.depth_mult_factor = 0.33
    @@ -27,6 +32,7 @@ class YoloX_S(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.YOLOX_M)
     class YoloX_M(YoloBase):
     class YoloX_M(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 0.67
             arch_params.depth_mult_factor = 0.67
    @@ -35,6 +41,7 @@ class YoloX_M(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.YOLOX_L)
     class YoloX_L(YoloBase):
     class YoloX_L(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 1.0
             arch_params.depth_mult_factor = 1.0
    @@ -43,6 +50,7 @@ class YoloX_L(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.YOLOX_X)
     class YoloX_X(YoloBase):
     class YoloX_X(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.depth_mult_factor = 1.33
             arch_params.depth_mult_factor = 1.33
    @@ -51,6 +59,7 @@ class YoloX_X(YoloBase):
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
             super().__init__(backbone=YoloDarknetBackbone, arch_params=arch_params)
     
     
     
     
    +@register_model(Models.CUSTOM_YOLO_X)
     class CustomYoloX(YoloBase):
     class CustomYoloX(YoloBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params.yolo_type = "yoloX"
             arch_params.yolo_type = "yoloX"
    Discard
    @@ -1,12 +1,17 @@
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from collections import namedtuple
     from collections import namedtuple
     import torch
     import torch
    +
    +from super_gradients.common.registry.registry import register_kd_model, register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.utils.utils import HpmStruct
     from super_gradients.training.utils.utils import HpmStruct
     from super_gradients.training.utils import get_param
     from super_gradients.training.utils import get_param
     
     
     KDOutput = namedtuple("KDOutput", "student_output teacher_output")
     KDOutput = namedtuple("KDOutput", "student_output teacher_output")
     
     
     
     
    +@register_model(Models.KD_MODULE)
    +@register_kd_model(Models.KD_MODULE)
     class KDModule(SgModule):
     class KDModule(SgModule):
         """
         """
         KDModule
         KDModule
    Discard
    @@ -9,7 +9,7 @@ from super_gradients.common.plugins.deci_client import DeciClient, client_enable
     from super_gradients.training import utils as core_utils
     from super_gradients.training import utils as core_utils
     from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
     from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
    -from super_gradients.training.models.all_architectures import ARCHITECTURES
    +from super_gradients.common.registry.registry import ARCHITECTURES
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.training.utils.checkpoint_utils import (
     from super_gradients.training.utils.checkpoint_utils import (
    Discard
    @@ -18,6 +18,8 @@ import torch.nn.functional as F
     import torchvision
     import torchvision
     from torch import nn
     from torch import nn
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.arch_params_factory import get_arch_params
    @@ -281,6 +283,7 @@ class HighResolutionModule(nn.Module):
     blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck, "ADAPTIVE": AdaptBlock}
     blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck, "ADAPTIVE": AdaptBlock}
     
     
     
     
    +@register_model(Models.DEKR_CUSTOM)
     class DEKRPoseEstimationModel(SgModule):
     class DEKRPoseEstimationModel(SgModule):
         """
         """
         Implementation of HRNet model from DEKR paper (https://arxiv.org/abs/2104.02300).
         Implementation of HRNet model from DEKR paper (https://arxiv.org/abs/2104.02300).
    @@ -521,6 +524,7 @@ class DEKRPoseEstimationModel(SgModule):
     POSE_DEKR_W32_NO_DC_ARCH_PARAMS = get_arch_params("pose_dekr_w32_no_dc_arch_params")
     POSE_DEKR_W32_NO_DC_ARCH_PARAMS = get_arch_params("pose_dekr_w32_no_dc_arch_params")
     
     
     
     
    +@register_model(Models.DEKR_W32_NO_DC)
     class DEKRW32(DEKRPoseEstimationModel):
     class DEKRW32(DEKRPoseEstimationModel):
         """
         """
         DEKR-W32 model for pose estimation without deformable convolutions.
         DEKR-W32 model for pose estimation without deformable convolutions.
    Discard
    @@ -3,6 +3,8 @@ from typing import Union
     
     
     from omegaconf import DictConfig
     from omegaconf import DictConfig
     
     
    +from super_gradients.common.object_names import Models
    +from super_gradients.common.registry.registry import register_model
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
    @@ -10,6 +12,7 @@ from super_gradients.training.utils import HpmStruct
     POSE_DDRNET39_ARCH_PARAMS = get_arch_params("pose_ddrnet39_arch_params")
     POSE_DDRNET39_ARCH_PARAMS = get_arch_params("pose_ddrnet39_arch_params")
     
     
     
     
    +@register_model(Models.POSE_DDRNET_39)
     class PoseDDRNet39(CustomizableDetector):
     class PoseDDRNet39(CustomizableDetector):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
             merged_arch_params = HpmStruct(**copy.deepcopy(POSE_DDRNET39_ARCH_PARAMS))
             merged_arch_params = HpmStruct(**copy.deepcopy(POSE_DDRNET39_ARCH_PARAMS))
    Discard
    @@ -3,6 +3,8 @@ from typing import Union
     
     
     from omegaconf import DictConfig
     from omegaconf import DictConfig
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.arch_params_factory import get_arch_params
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
    @@ -10,6 +12,7 @@ from super_gradients.training.utils import HpmStruct
     DEKR_PPPOSE_L_ARCH_PARAMS = get_arch_params("pose_pppose_l_arch_params")
     DEKR_PPPOSE_L_ARCH_PARAMS = get_arch_params("pose_pppose_l_arch_params")
     
     
     
     
    +@register_model(Models.POSE_PP_YOLO_L)
     class PosePPYoloL(CustomizableDetector):
     class PosePPYoloL(CustomizableDetector):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
         def __init__(self, arch_params: Union[HpmStruct, DictConfig], in_channels: int = 3):
             merged_arch_params = HpmStruct(**copy.deepcopy(DEKR_PPPOSE_L_ARCH_PARAMS))
             merged_arch_params = HpmStruct(**copy.deepcopy(DEKR_PPPOSE_L_ARCH_PARAMS))
    Discard
    @@ -1,9 +1,10 @@
    -from collections import OrderedDict
    -
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
    +from collections import OrderedDict
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.classification_models.resnet import BasicBlock, Bottleneck
     from super_gradients.training.models.classification_models.resnet import BasicBlock, Bottleneck
     from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
     from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
     from super_gradients.training.utils import get_param, HpmStruct
     from super_gradients.training.utils import get_param, HpmStruct
    @@ -558,6 +559,7 @@ DEFAULT_DDRNET_23_SLIM_PARAMS = {
     DEFAULT_DDRNET_39_PARAMS = {**DEFAULT_DDRNET_23_PARAMS, "layers": [3, 4, 3, 3, 1, 3, 3, 1], "head_planes": 256, "layer3_repeats": 2}
     DEFAULT_DDRNET_39_PARAMS = {**DEFAULT_DDRNET_23_PARAMS, "layers": [3, 4, 3, 3, 1, 3, 3, 1], "head_planes": 256, "layer3_repeats": 2}
     
     
     
     
    +@register_model(Models.DDRNET_39)
     class DDRNet39(DDRNetCustom):
     class DDRNet39(DDRNetCustom):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             _arch_params = HpmStruct(**DEFAULT_DDRNET_39_PARAMS)
             _arch_params = HpmStruct(**DEFAULT_DDRNET_39_PARAMS)
    @@ -574,6 +576,7 @@ class DDRNet39(DDRNetCustom):
             super().__init__(_arch_params)
             super().__init__(_arch_params)
     
     
     
     
    +@register_model(Models.DDRNET_23)
     class DDRNet23(DDRNetCustom):
     class DDRNet23(DDRNetCustom):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
    @@ -590,6 +593,7 @@ class DDRNet23(DDRNetCustom):
             super().__init__(_arch_params)
             super().__init__(_arch_params)
     
     
     
     
    +@register_model(Models.DDRNET_23_SLIM)
     class DDRNet23Slim(DDRNetCustom):
     class DDRNet23Slim(DDRNetCustom):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_SLIM_PARAMS)
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_SLIM_PARAMS)
    @@ -606,6 +610,7 @@ class DDRNet23Slim(DDRNetCustom):
             super().__init__(_arch_params)
             super().__init__(_arch_params)
     
     
     
     
    +@register_model(Models.CUSTOM_DDRNET_23)
     class AnyBackBoneDDRNet23(DDRNetCustom):
     class AnyBackBoneDDRNet23(DDRNetCustom):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
             _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
    Discard
    @@ -2,12 +2,14 @@ from typing import Tuple
     
     
     import torch
     import torch
     
     
    +from super_gradients.common.registry.registry import register_detection_module
     from super_gradients.training.models.segmentation_models.ddrnet import DDRNet39
     from super_gradients.training.models.segmentation_models.ddrnet import DDRNet39
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
     
     
     __all__ = ["DDRNet39Backbone"]
     __all__ = ["DDRNet39Backbone"]
     
     
     
     
    +@register_detection_module()
     class DDRNet39Backbone(DDRNet39):
     class DDRNet39Backbone(DDRNet39):
         """
         """
         A somewhat frankenstein version of the DDRNet39 model that tries to be a feature extractor module.
         A somewhat frankenstein version of the DDRNet39 model that tries to be a feature extractor module.
    Discard
    @@ -2,6 +2,9 @@ import torch
     import torch.nn as nn
     import torch.nn as nn
     from typing import Union, List
     from typing import Union, List
     from super_gradients.modules import ConvBNReLU
     from super_gradients.modules import ConvBNReLU
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.modules.sampling import make_upsample_module
     from super_gradients.modules.sampling import make_upsample_module
     from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
     from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
     from super_gradients.training.models.segmentation_models.stdc import AbstractSTDCBackbone, STDC1Backbone, STDC2Backbone
     from super_gradients.training.models.segmentation_models.stdc import AbstractSTDCBackbone, STDC1Backbone, STDC2Backbone
    @@ -291,6 +294,9 @@ class PPLiteSegBase(SegmentationModule):
                     module.replace_num_classes(new_num_classes)
                     module.replace_num_classes(new_num_classes)
     
     
     
     
    +@register_model(Models.PP_LITE_B_SEG)
    +@register_model(Models.PP_LITE_B_SEG50)
    +@register_model(Models.PP_LITE_B_SEG75)
     class PPLiteSegB(PPLiteSegBase):
     class PPLiteSegB(PPLiteSegBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
    @@ -316,6 +322,9 @@ class PPLiteSegB(PPLiteSegBase):
             )
             )
     
     
     
     
    +@register_model(Models.PP_LITE_T_SEG)
    +@register_model(Models.PP_LITE_T_SEG50)
    +@register_model(Models.PP_LITE_T_SEG75)
     class PPLiteSegT(PPLiteSegBase):
     class PPLiteSegT(PPLiteSegBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
    Discard
    @@ -6,6 +6,9 @@ from typing import List
     
     
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
    +
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.modules import ConvBNReLU
     from super_gradients.modules import ConvBNReLU
    @@ -291,6 +294,7 @@ class RegSeg(SgModule):
             self.head = RegSegHead(self.decoder.out_channels, new_num_classes, head_config)
             self.head = RegSegHead(self.decoder.out_channels, new_num_classes, head_config)
     
     
     
     
    +@register_model(Models.REGSEG48)
     class RegSeg48(RegSeg):
     class RegSeg48(RegSeg):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             num_classes = get_param(arch_params, "num_classes")
             num_classes = get_param(arch_params, "num_classes")
    Discard
    @@ -11,6 +11,8 @@ import torch.nn.functional as F
     
     
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.models.sg_module import SgModule
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.models.classification_models.resnet import BasicBlock, ResNet, Bottleneck
     from super_gradients.training.models.classification_models.resnet import BasicBlock, ResNet, Bottleneck
     
     
     
     
    @@ -628,6 +630,7 @@ class ShelfNetLW(ShelfNetBase):
             return params_list
             return params_list
     
     
     
     
    +@register_model(Models.SHELFNET18_LW)
     class ShelfNet18_LW(ShelfNetLW):
     class ShelfNet18_LW(ShelfNetLW):
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super().__init__(backbone=ShelfResNetBackBone18, planes=64, layers=3, *args, **kwargs)
             super().__init__(backbone=ShelfResNetBackBone18, planes=64, layers=3, *args, **kwargs)
    @@ -645,6 +648,7 @@ class ShelfNet18_LW(ShelfNetLW):
                 out_planes *= 2
                 out_planes *= 2
     
     
     
     
    +@register_model(Models.SHELFNET34_LW)
     class ShelfNet34_LW(ShelfNetLW):
     class ShelfNet34_LW(ShelfNetLW):
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super().__init__(backbone=ShelfResNetBackBone34, planes=128, layers=3, *args, **kwargs)
             super().__init__(backbone=ShelfResNetBackBone34, planes=128, layers=3, *args, **kwargs)
    @@ -659,16 +663,19 @@ class ShelfNet34_LW(ShelfNetLW):
                 net_out_planes *= 2
                 net_out_planes *= 2
     
     
     
     
    +@register_model(Models.SHELFNET50_3343)
     class ShelfNet503343(ShelfNetHW):
     class ShelfNet503343(ShelfNetHW):
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super().__init__(backbone=ShelfResNetBackBone503343, planes=256, layers=4, *args, **kwargs)
             super().__init__(backbone=ShelfResNetBackBone503343, planes=256, layers=4, *args, **kwargs)
     
     
     
     
    +@register_model(Models.SHELFNET50)
     class ShelfNet50(ShelfNetHW):
     class ShelfNet50(ShelfNetHW):
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super().__init__(backbone=ShelfResNetBackBone50, planes=256, layers=4, *args, **kwargs)
             super().__init__(backbone=ShelfResNetBackBone50, planes=256, layers=4, *args, **kwargs)
     
     
     
     
    +@register_model(Models.SHELFNET101)
     class ShelfNet101(ShelfNetHW):
     class ShelfNet101(ShelfNetHW):
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super().__init__(backbone=ShelfResNetBackBone101, planes=256, layers=4, *args, **kwargs)
             super().__init__(backbone=ShelfResNetBackBone101, planes=256, layers=4, *args, **kwargs)
    Discard
    @@ -9,6 +9,8 @@ import torch
     import torch.nn as nn
     import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.factories.base_factory import BaseFactory
     from super_gradients.common.factories.base_factory import BaseFactory
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
    @@ -242,6 +244,7 @@ class STDCClassificationBase(SgModule):
             return out
             return out
     
     
     
     
    +@register_model(Models.STDC_CUSTOM_CLS)
     class STDCClassification(STDCClassificationBase):
     class STDCClassification(STDCClassificationBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             super().__init__(
             super().__init__(
    @@ -580,6 +583,7 @@ class STDCSegmentationBase(SgModule):
             return multiply_lr_params.items(), no_multiply_params.items()
             return multiply_lr_params.items(), no_multiply_params.items()
     
     
     
     
    +@register_model(Models.STDC_CUSTOM)
     class CustomSTDCSegmentation(STDCSegmentationBase):
     class CustomSTDCSegmentation(STDCSegmentationBase):
         """
         """
         Fully customized STDC Segmentation factory module.
         Fully customized STDC Segmentation factory module.
    @@ -622,6 +626,7 @@ class STDC2Backbone(STDCBackbone):
             )
             )
     
     
     
     
    +@register_model(Models.STDC1_CLASSIFICATION)
     class STDC1Classification(STDCClassification):
     class STDC1Classification(STDCClassification):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
    @@ -629,6 +634,7 @@ class STDC1Classification(STDCClassification):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.STDC2_CLASSIFICATION)
     class STDC2Classification(STDCClassification):
     class STDC2Classification(STDCClassification):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
    @@ -636,6 +642,9 @@ class STDC2Classification(STDCClassification):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.STDC1_SEG)
    +@register_model(Models.STDC1_SEG50)
    +@register_model(Models.STDC1_SEG75)
     class STDC1Seg(CustomSTDCSegmentation):
     class STDC1Seg(CustomSTDCSegmentation):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
             backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
    @@ -645,6 +654,9 @@ class STDC1Seg(CustomSTDCSegmentation):
             super().__init__(arch_params)
             super().__init__(arch_params)
     
     
     
     
    +@register_model(Models.STDC2_SEG)
    +@register_model(Models.STDC2_SEG50)
    +@register_model(Models.STDC2_SEG75)
     class STDC2Seg(CustomSTDCSegmentation):
     class STDC2Seg(CustomSTDCSegmentation):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
             backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
    Discard
    @@ -1,6 +1,8 @@
     import torch.nn as nn
     import torch.nn as nn
     from typing import Optional, Union, List
     from typing import Optional, Union, List
     
     
    +from super_gradients.common.registry.registry import register_model
    +from super_gradients.common.object_names import Models
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.training.utils import HpmStruct, get_param
     from super_gradients.training import models
     from super_gradients.training import models
     from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
     from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
    @@ -200,6 +202,7 @@ class UNetBase(SegmentationModule):
                     module.replace_num_classes(new_num_classes)
                     module.replace_num_classes(new_num_classes)
     
     
     
     
    +@register_model(Models.UNET_CUSTOM)
     class UNetCustom(UNetBase):
     class UNetCustom(UNetBase):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params = HpmStruct(**models.get_arch_params("unet_default_arch_params.yaml", overriding_params=arch_params.to_dict()))
             arch_params = HpmStruct(**models.get_arch_params("unet_default_arch_params.yaml", overriding_params=arch_params.to_dict()))
    @@ -218,6 +221,7 @@ class UNetCustom(UNetBase):
             )
             )
     
     
     
     
    +@register_model(Models.UNET)
     class UNet(UNetCustom):
     class UNet(UNetCustom):
         """
         """
         implementation of:
         implementation of:
    Discard
    @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
     import torch.nn as nn
     import torch.nn as nn
     import torch
     import torch
     
     
    +from super_gradients.common.registry.registry import register_unet_up_block, UP_FUSE_BLOCKS
     from super_gradients.modules import ConvBNReLU, CrossModelSkipConnection, Residual
     from super_gradients.modules import ConvBNReLU, CrossModelSkipConnection, Residual
     from super_gradients.modules.sampling import make_upsample_module
     from super_gradients.modules.sampling import make_upsample_module
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.decorators.factory_decorator import resolve_param
    @@ -29,6 +30,7 @@ class AbstractUpFuseBlock(nn.Module, ABC):
             raise NotImplementedError()
             raise NotImplementedError()
     
     
     
     
    +@register_unet_up_block()
     class UpFactorBlock(AbstractUpFuseBlock):
     class UpFactorBlock(AbstractUpFuseBlock):
         """
         """
         Ignore Skip features, simply apply upsampling and ConvBNRelu layers.
         Ignore Skip features, simply apply upsampling and ConvBNRelu layers.
    @@ -48,6 +50,7 @@ class UpFactorBlock(AbstractUpFuseBlock):
             return self.last_convs(x)
             return self.last_convs(x)
     
     
     
     
    +@register_unet_up_block()
     class UpCatBlock(AbstractUpFuseBlock):
     class UpCatBlock(AbstractUpFuseBlock):
         """
         """
         Fuse features with concatenation and followed Convolutions.
         Fuse features with concatenation and followed Convolutions.
    @@ -67,6 +70,7 @@ class UpCatBlock(AbstractUpFuseBlock):
             return self.last_convs(x)
             return self.last_convs(x)
     
     
     
     
    +@register_unet_up_block()
     class UpSumBlock(AbstractUpFuseBlock):
     class UpSumBlock(AbstractUpFuseBlock):
         """
         """
         Fuse features with concatenation and followed Convolutions.
         Fuse features with concatenation and followed Convolutions.
    @@ -89,13 +93,6 @@ class UpSumBlock(AbstractUpFuseBlock):
             return self.last_convs(x)
             return self.last_convs(x)
     
     
     
     
    -UP_FUSE_BLOCKS = dict(
    -    UpCatBlock=UpCatBlock,
    -    UpFactorBlock=UpFactorBlock,
    -    UpSumBlock=UpSumBlock,
    -)
    -
    -
     class Decoder(nn.Module):
     class Decoder(nn.Module):
         @resolve_param("up_block_types", ListFactory(TypeFactory(UP_FUSE_BLOCKS)))
         @resolve_param("up_block_types", ListFactory(TypeFactory(UP_FUSE_BLOCKS)))
         def __init__(
         def __init__(
    Discard
    @@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     
     
    +from super_gradients.common.registry.registry import register_model, register_unet_backbone_stage, BACKBONE_STAGES
    +from super_gradients.common.object_names import Models
     from super_gradients.common.factories.context_modules_factory import ContextModulesFactory
     from super_gradients.common.factories.context_modules_factory import ContextModulesFactory
     from super_gradients.training.models.segmentation_models.context_modules import AbstractContextModule
     from super_gradients.training.models.segmentation_models.context_modules import AbstractContextModule
     from super_gradients.training.utils.utils import get_param, HpmStruct
     from super_gradients.training.utils.utils import get_param, HpmStruct
    @@ -68,6 +70,7 @@ class BackboneStage(nn.Module, ABC):
             return self.blocks(x)
             return self.blocks(x)
     
     
     
     
    +@register_unet_backbone_stage()
     class STDCStage(BackboneStage):
     class STDCStage(BackboneStage):
         """
         """
         STDC stage with STDCBlock as building block.
         STDC stage with STDCBlock as building block.
    @@ -149,6 +152,7 @@ class ConvBaseStage(BackboneStage, ABC):
             raise NotImplementedError()
             raise NotImplementedError()
     
     
     
     
    +@register_unet_backbone_stage()
     class RepVGGStage(ConvBaseStage):
     class RepVGGStage(ConvBaseStage):
         """
         """
         RepVGG stage with RepVGGBlock as building block.
         RepVGG stage with RepVGGBlock as building block.
    @@ -158,6 +162,7 @@ class RepVGGStage(ConvBaseStage):
             return RepVGGBlock(in_channels, out_channels, stride=stride)
             return RepVGGBlock(in_channels, out_channels, stride=stride)
     
     
     
     
    +@register_unet_backbone_stage()
     class QARepVGGStage(ConvBaseStage):
     class QARepVGGStage(ConvBaseStage):
         """
         """
         QARepVGG stage with QARepVGGBlock as building block.
         QARepVGG stage with QARepVGGBlock as building block.
    @@ -167,6 +172,7 @@ class QARepVGGStage(ConvBaseStage):
             return QARepVGGBlock(in_channels, out_channels, stride=stride, use_residual_connection=(out_channels == in_channels and stride == 1))
             return QARepVGGBlock(in_channels, out_channels, stride=stride, use_residual_connection=(out_channels == in_channels and stride == 1))
     
     
     
     
    +@register_unet_backbone_stage()
     class RegnetXStage(BackboneStage):
     class RegnetXStage(BackboneStage):
         """
         """
         RegNetX stage with XBlock as building block.
         RegNetX stage with XBlock as building block.
    @@ -206,6 +212,7 @@ class RegnetXStage(BackboneStage):
             return 1
             return 1
     
     
     
     
    +@register_unet_backbone_stage()
     class ConvStage(ConvBaseStage):
     class ConvStage(ConvBaseStage):
         """
         """
         Conv stage with ConvBNReLU as building block.
         Conv stage with ConvBNReLU as building block.
    @@ -215,15 +222,6 @@ class ConvStage(ConvBaseStage):
             return ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
             return ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
     
     
     
     
    -BACKBONE_STAGES = dict(
    -    RepVGGStage=RepVGGStage,
    -    QARepVGGStage=QARepVGGStage,
    -    STDCStage=STDCStage,
    -    RegnetXStage=RegnetXStage,
    -    ConvStage=ConvStage,
    -)
    -
    -
     class UNetBackboneBase(AbstractUNetBackbone):
     class UNetBackboneBase(AbstractUNetBackbone):
         @resolve_param("block_types_list", ListFactory(TypeFactory(BACKBONE_STAGES)))
         @resolve_param("block_types_list", ListFactory(TypeFactory(BACKBONE_STAGES)))
         def __init__(
         def __init__(
    @@ -329,6 +327,7 @@ class UnetClassification(SgModule):
             return self.classifier_head(x)
             return self.classifier_head(x)
     
     
     
     
    +@register_model(Models.UNET_CUSTOM_CLS)
     class UnetClassificationCustom(UnetClassification):
     class UnetClassificationCustom(UnetClassification):
         def __init__(self, arch_params: HpmStruct):
         def __init__(self, arch_params: HpmStruct):
             arch_params = HpmStruct(**models.get_arch_params("unet_default_arch_params.yaml", overriding_params=arch_params.to_dict()))
             arch_params = HpmStruct(**models.get_arch_params("unet_default_arch_params.yaml", overriding_params=arch_params.to_dict()))
    Discard
    @@ -3,10 +3,6 @@ from super_gradients.training.pre_launch_callbacks.pre_launch_callbacks import (
         AutoTrainBatchSizeSelectionCallback,
         AutoTrainBatchSizeSelectionCallback,
         QATRecipeModificationCallback,
         QATRecipeModificationCallback,
     )
     )
    -
    -ALL_PRE_LAUNCH_CALLBACKS = {
    -    "AutoTrainBatchSizeSelectionCallback": AutoTrainBatchSizeSelectionCallback,
    -    "QATRecipeModificationCallback": QATRecipeModificationCallback,
    -}
    +from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS
     
     
     __all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
     __all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
    Discard
    @@ -5,6 +5,7 @@ from typing import Union
     from omegaconf import DictConfig
     from omegaconf import DictConfig
     import torch
     import torch
     
     
    +from super_gradients.common.registry.registry import register_pre_launch_callback
     from super_gradients import is_distributed
     from super_gradients import is_distributed
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.training import models
     from super_gradients.training import models
    @@ -28,6 +29,7 @@ class PreLaunchCallback:
             raise NotImplementedError
             raise NotImplementedError
     
     
     
     
    +@register_pre_launch_callback()
     class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
     class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
         """
         """
         AutoTrainBatchSizeSelectionCallback
         AutoTrainBatchSizeSelectionCallback
    @@ -178,6 +180,7 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
                 barrier()
                 barrier()
     
     
     
     
    +@register_pre_launch_callback()
     class QATRecipeModificationCallback(PreLaunchCallback):
     class QATRecipeModificationCallback(PreLaunchCallback):
         """
         """
          QATRecipeModificationCallback(PreLaunchCallback)
          QATRecipeModificationCallback(PreLaunchCallback)
    Discard
    @@ -20,15 +20,15 @@ from tqdm import tqdm
     from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_ckpt_local_path
     from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_ckpt_local_path
     
     
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
    +from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
    +from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
     from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
     from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.decorators.factory_decorator import resolve_param
     from super_gradients.common.factories.callbacks_factory import CallbacksFactory
     from super_gradients.common.factories.callbacks_factory import CallbacksFactory
     from super_gradients.common.factories.list_factory import ListFactory
     from super_gradients.common.factories.list_factory import ListFactory
     from super_gradients.common.factories.losses_factory import LossesFactory
     from super_gradients.common.factories.losses_factory import LossesFactory
     from super_gradients.common.factories.metrics_factory import MetricsFactory
     from super_gradients.common.factories.metrics_factory import MetricsFactory
    -from super_gradients.common.sg_loggers import SG_LOGGERS
    -from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
    -from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
    +
     from super_gradients.training import utils as core_utils, models, dataloaders
     from super_gradients.training import utils as core_utils, models, dataloaders
     from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
     from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
    @@ -40,8 +40,7 @@ from super_gradients.training.metrics.metric_utils import (
         get_train_loop_description_dict,
         get_train_loop_description_dict,
     )
     )
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
    -from super_gradients.training.models.all_architectures import ARCHITECTURES
    -from super_gradients.training.params import TrainingParams
    +from super_gradients.common.registry.registry import ARCHITECTURES, SG_LOGGERS
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.utils import sg_trainer_utils, get_param
     from super_gradients.training.utils import sg_trainer_utils, get_param
     from super_gradients.training.utils.distributed_training_utils import (
     from super_gradients.training.utils.distributed_training_utils import (
    @@ -74,17 +73,17 @@ from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTe
     from super_gradients.training.utils.callbacks import (
     from super_gradients.training.utils.callbacks import (
         CallbackHandler,
         CallbackHandler,
         Phase,
         Phase,
    -    LR_SCHEDULERS_CLS_DICT,
         PhaseContext,
         PhaseContext,
         MetricsUpdateCallback,
         MetricsUpdateCallback,
    -    LR_WARMUP_CLS_DICT,
         ContextSgMethods,
         ContextSgMethods,
         LRCallbackBase,
         LRCallbackBase,
     )
     )
    +from super_gradients.common.registry.registry import LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT
     from super_gradients.common.environment.device_utils import device_config
     from super_gradients.common.environment.device_utils import device_config
     from super_gradients.training.utils import HpmStruct
     from super_gradients.training.utils import HpmStruct
     from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg
     from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg
     from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
     from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
    +from super_gradients.training.params import TrainingParams
     
     
     logger = get_logger(__name__)
     logger = get_logger(__name__)
     
     
    Discard

    Some files were not shown because too many files changed in this diff