#344 Feature/SG-255: add class for supported strings

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-255-add_class_for_supported_strings
36 changed files with 830 additions and 374 deletions
  1. src/super_gradients/__init__.py (+2 −2)
  2. src/super_gradients/common/factories/__init__.py (+10 −0)
  3. src/super_gradients/common/factories/callbacks_factory.py (+2 −18)
  4. src/super_gradients/common/factories/metrics_factory.py (+2 −13)
  5. src/super_gradients/common/factories/optimizers_type_factory.py (+2 −10)
  6. src/super_gradients/common/factories/samplers_factory.py (+2 −8)
  7. src/super_gradients/common/factories/transforms_factory.py (+2 −40)
  8. src/super_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py (+10 −10)
  9. src/super_gradients/recipes/dataset_params/cityscapes_ddrnet_dataset_params.yaml (+5 −5)
  10. src/super_gradients/recipes/dataset_params/cityscapes_regseg48_dataset_params.yaml (+5 −5)
  11. src/super_gradients/recipes/dataset_params/cityscapes_stdc_seg50_dataset_params.yaml (+6 −6)
  12. src/super_gradients/recipes/dataset_params/cityscapes_stdc_seg75_dataset_params.yaml (+6 −6)
  13. src/super_gradients/recipes/dataset_params/coco_segmentation_dataset_params.yaml (+7 −7)
  14. src/super_gradients/recipes/dataset_params/pascal_aug_segmentation_dataset_params.yaml (+7 −7)
  15. src/super_gradients/recipes/dataset_params/pascal_voc_segmentation_dataset_params.yaml (+7 −7)
  16. src/super_gradients/recipes/dataset_params/supervisely_persons_dataset_params.yaml (+6 −6)
  17. src/super_gradients/training/datasets/all_datasets.py (+0 −0)
  18. src/super_gradients/training/datasets/samplers/__init__.py (+6 −0)
  19. src/super_gradients/training/datasets/samplers/all_samplers.py (+11 −0)
  20. src/super_gradients/training/losses/__init__.py (+3 −3)
  21. src/super_gradients/training/losses/all_losses.py (+14 −14)
  22. src/super_gradients/training/metrics/__init__.py (+3 −2)
  23. src/super_gradients/training/metrics/all_metrics.py (+15 −0)
  24. src/super_gradients/training/models/__init__.py (+1 −1)
  25. src/super_gradients/training/models/all_architectures.py (+197 −99)
  26. src/super_gradients/training/object_names.py (+131 −0)
  27. src/super_gradients/training/transforms/__init__.py (+2 −1)
  28. src/super_gradients/training/transforms/all_transforms.py (+74 −0)
  29. src/super_gradients/training/transforms/transforms.py (+14 −14)
  30. src/super_gradients/training/utils/callbacks/__init__.py (+15 −0)
  31. src/super_gradients/training/utils/callbacks/all_callbacks.py (+31 −0)
  32. src/super_gradients/training/utils/callbacks/callbacks.py (+0 −12)
  33. src/super_gradients/training/utils/optimizers/__init__.py (+6 −0)
  34. src/super_gradients/training/utils/optimizers/all_optimizers.py (+13 −0)
  35. tests/unit_tests/segmentation_transforms_test.py (+35 −35)
  36. tutorials/SG_transfer_learning_semantic_segmentation.ipynb (+178 −43)
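The change set replaces hard-coded string keys in the registries (ARCHITECTURES, LOSSES, METRICS, TRANSFORMS, OPTIMIZERS, CALLBACKS, SAMPLERS) with constants defined on static name classes in src/super_gradients/training/object_names.py. A minimal sketch of the resulting pattern, using only modules added or touched by this PR (illustrative only, not part of the diff):

from super_gradients.training.losses.all_losses import LOSSES, Losses
from super_gradients.training.metrics.all_metrics import METRICS, Metrics

# The constants are plain strings, so existing string-based configs keep working,
# but code can reference the constant and avoid typos.
criterion_cls = LOSSES[Losses.CROSS_ENTROPY]   # -> LabelSmoothingCrossEntropyLoss
metric_cls = METRICS[Metrics.IOU]              # -> IoU

print(Losses.CROSS_ENTROPY)   # "cross_entropy"
print(metric_cls.__name__)    # "IoU"

The per-file diffs follow, in the order of the list above.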
src/super_gradients/__init__.py
@@ -1,11 +1,11 @@
-from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer
+from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, object_names
 from super_gradients.common import init_trainer, is_distributed
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.sanity_check import env_sanity_check


 __all__ = ['ARCHITECTURES', 'losses', 'utils', 'datasets_utils', 'DataAugmentation',
-           'Trainer', 'KDTrainer',
+           'Trainer', 'KDTrainer', 'object_names',
            'init_trainer', 'is_distributed', 'train_from_recipe', 'train_from_kd_recipe',
            'env_sanity_check']
src/super_gradients/common/factories/__init__.py
@@ -0,0 +1,10 @@
+from super_gradients.common.factories.callbacks_factory import CallbacksFactory
+from super_gradients.common.factories.list_factory import ListFactory
+from super_gradients.common.factories.losses_factory import LossesFactory
+from super_gradients.common.factories.metrics_factory import MetricsFactory
+from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
+from super_gradients.common.factories.samplers_factory import SamplersFactory
+from super_gradients.common.factories.transforms_factory import TransformsFactory
+
+
+__all__ = ["CallbacksFactory", "ListFactory", "LossesFactory", "MetricsFactory", "OptimizersTypeFactory", "SamplersFactory", "TransformsFactory"]
src/super_gradients/common/factories/callbacks_factory.py
@@ -1,24 +1,8 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback
-from super_gradients.training.utils.callbacks import DeciLabUploadCallback, LRCallbackBase, LRSchedulerCallback, \
-    MetricsUpdateCallback, \
-    ModelConversionCheckCallback, YoloXTrainingStageSwitchCallback
-from super_gradients.training.utils.early_stopping import EarlyStop
+from super_gradients.training.utils.callbacks import CALLBACKS


 class CallbacksFactory(BaseFactory):

     def __init__(self):
-        type_dict = {
-            'DeciLabUploadCallback': DeciLabUploadCallback,
-            'LRCallbackBase': LRCallbackBase,
-            'LRSchedulerCallback': LRSchedulerCallback,
-            'MetricsUpdateCallback': MetricsUpdateCallback,
-            'ModelConversionCheckCallback': ModelConversionCheckCallback,
-            'EarlyStop': EarlyStop,
-            'DetectionMultiscalePrePredictionCallback': DetectionMultiscalePrePredictionCallback,
-            'YoloXTrainingStageSwitchCallback': YoloXTrainingStageSwitchCallback
-
-
-        }
-        super().__init__(type_dict)
+        super().__init__(CALLBACKS)
src/super_gradients/common/factories/metrics_factory.py
@@ -1,19 +1,8 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,\
-    BinaryDice
+from super_gradients.training.metrics import METRICS


 class MetricsFactory(BaseFactory):

     def __init__(self):
-        type_dict = {
-            'Accuracy': Accuracy,
-            'Top5': Top5,
-            'DetectionMetrics': DetectionMetrics,
-            'IoU': IoU,
-            "BinaryIOU": BinaryIOU,
-            "Dice": Dice,
-            "BinaryDice": BinaryDice,
-            'PixelAccuracy': PixelAccuracy,
-        }
-        super().__init__(type_dict)
+        super().__init__(METRICS)
src/super_gradients/common/factories/optimizers_type_factory.py
@@ -1,11 +1,9 @@
 import importlib
 from typing import Union

-from torch import optim

 from super_gradients.common.factories.base_factory import AbstractFactory
-from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
-from super_gradients.training.utils.optimizers.lamb import Lamb
+from super_gradients.training.utils.optimizers import OPTIMIZERS


 class OptimizersTypeFactory(AbstractFactory):
@@ -17,13 +15,7 @@ class OptimizersTypeFactory(AbstractFactory):

     def __init__(self):

-        self.type_dict = {
-            "SGD": optim.SGD,
-            "Adam": optim.Adam,
-            "RMSprop": optim.RMSprop,
-            "RMSpropTF": RMSpropTF,
-            "Lamb": Lamb
-        }
+        self.type_dict = OPTIMIZERS

     def get(self, conf: Union[str]):
         """
src/super_gradients/common/factories/samplers_factory.py
@@ -1,14 +1,8 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
-from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
-from torch.utils.data.distributed import DistributedSampler
+from super_gradients.training.datasets.samplers import SAMPLERS


 class SamplersFactory(BaseFactory):

     def __init__(self):
-        type_dict = {"InfiniteSampler": InfiniteSampler,
-                     "RepeatAugSampler": RepeatAugSampler,
-                     "DistributedSampler": DistributedSampler
-                     }
-        super().__init__(type_dict)
+        super().__init__(SAMPLERS)
src/super_gradients/common/factories/transforms_factory.py
@@ -1,54 +1,16 @@
-import inspect
 from typing import Union, Mapping

 from omegaconf import ListConfig
-from torchvision import transforms

 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.list_factory import ListFactory
-from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
-from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, rand_augment_transform
-from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, RandomRotate, \
-    CropImageAndMask, RandomGaussianBlur, PadShortToCropSize, ResizeSeg, ColorJitterSeg, DetectionMosaic, DetectionRandomAffine, \
-    DetectionMixup, DetectionHSV, \
-    DetectionHorizontalFlip, DetectionTargetsFormat, DetectionPaddedRescale, \
-    DetectionTargetsFormatTransform
+from super_gradients.training.transforms import TRANSFORMS


 class TransformsFactory(BaseFactory):

     def __init__(self):
-        type_dict = {
-            'RandomFlipSeg': RandomFlip,
-            'ResizeSeg': ResizeSeg,
-            'RescaleSeg': Rescale,
-            'RandomRescaleSeg': RandomRescale,
-            'RandomRotateSeg': RandomRotate,
-            'CropImageAndMaskSeg': CropImageAndMask,
-            'RandomGaussianBlurSeg': RandomGaussianBlur,
-            'PadShortToCropSizeSeg': PadShortToCropSize,
-            'ColorJitterSeg': ColorJitterSeg,
-            "DetectionMosaic": DetectionMosaic,
-            "DetectionRandomAffine": DetectionRandomAffine,
-            "DetectionMixup": DetectionMixup,
-            "DetectionHSV": DetectionHSV,
-            "DetectionHorizontalFlip": DetectionHorizontalFlip,
-            "DetectionPaddedRescale": DetectionPaddedRescale,
-            "DetectionTargetsFormat": DetectionTargetsFormat,
-            "DetectionTargetsFormatTransform": DetectionTargetsFormatTransform,
-
-            'RandomResizedCropAndInterpolation': RandomResizedCropAndInterpolation,
-            'RandAugmentTransform': rand_augment_transform,
-            'Lighting': Lighting,
-            'RandomErase': RandomErase
-        }
-        for name, obj in inspect.getmembers(transforms, inspect.isclass):
-            if name in type_dict:
-                raise RuntimeError(f'key {name} already exists in dictionary')
-
-            type_dict[name] = obj
-
-        super().__init__(type_dict)
+        super().__init__(TRANSFORMS)

     def get(self, conf: Union[str, dict]):
src/super_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py
@@ -2,21 +2,21 @@ from super_gradients.training import models, dataloaders

 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.metrics import BinaryIOU
-from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
-    PadShortToCropSize, ColorJitterSeg
+from super_gradients.training.transforms.transforms import SegResize, SegRandomFlip, SegRandomRescale, SegCropImageAndMask, \
+    SegPadShortToCropSize, SegColorJitter
 from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase

 # DEFINE DATA TRANSFORMATIONS

 dl_train = dataloaders.supervisely_persons_train(
-    dataset_params={"transforms": [ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),
-                                   RandomFlip(),
-                                   RandomRescale(scales=[0.25, 1.]),
-                                   PadShortToCropSize([320, 480]),
-                                   CropImageAndMask(crop_size=[320, 480],
-                                                    mode="random")]})
-
-dl_val = dataloaders.supervisely_persons_val(dataset_params={"transforms": [ResizeSeg(h=480, w=320)]})
+    dataset_params={"transforms": [SegColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
+                                   SegRandomFlip(),
+                                   SegRandomRescale(scales=[0.25, 1.]),
+                                   SegPadShortToCropSize([320, 480]),
+                                   SegCropImageAndMask(crop_size=[320, 480],
+                                                       mode="random")]})
+
+dl_val = dataloaders.supervisely_persons_val(dataset_params={"transforms": [SegResize(h=480, w=320)]})

 trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")
src/super_gradients/recipes/dataset_params/cityscapes_ddrnet_dataset_params.yaml
@@ -4,22 +4,22 @@ defaults:

 train_dataset_params:
   transforms:
-    - ColorJitterSeg:
+    - SegColorJitter:
         brightness: 0.5
         contrast: 0.5
         saturation: 0.5

-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.5, 2. ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: [ 1024, 1024 ]
         fill_mask: 19

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: [ 1024, 1024 ]
         mode: random
src/super_gradients/recipes/dataset_params/cityscapes_regseg48_dataset_params.yaml
@@ -5,23 +5,23 @@ defaults:
 train_dataset_params:
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - ColorJitterSeg:
+    - SegColorJitter:
         brightness: 0.1
         contrast: 0.1
         saturation: 0.1

-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.4, 1.6 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: 1024
         fill_image: [ 19, 0, 0 ]
         fill_mask: 19                     # ignored label idx

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 1024
         mode: random
src/super_gradients/recipes/dataset_params/cityscapes_stdc_seg50_dataset_params.yaml
@@ -5,28 +5,28 @@ defaults:
 train_dataset_params:
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - ColorJitterSeg:
+    - SegColorJitter:
         brightness: 0.5
         contrast: 0.5
         saturation: 0.5

-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.125, 1.5 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: [ 1024, 512 ]
         fill_mask: 19                  # ignored label idx

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: [ 1024, 512 ]
         mode: random

 val_dataset_params:
   transforms:
-    - RescaleSeg:
+    - SegRescale:
         scale_factor: 0.5

 train_dataloader_params:
src/super_gradients/recipes/dataset_params/cityscapes_stdc_seg75_dataset_params.yaml
@@ -5,28 +5,28 @@ defaults:
 train_dataset_params:
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - ColorJitterSeg:
+    - SegColorJitter:
         brightness: 0.5
         contrast: 0.5
         saturation: 0.5

-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.125, 1.5 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: [ 1536, 768 ]
         fill_mask: 19                  # ignored label idx

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: [ 1536, 768 ]
         mode: random

 val_dataset_params:
   transforms:
-    - RescaleSeg:
+    - SegRescale:
         scale_factor: 0.75

 train_dataloader_params:
src/super_gradients/recipes/dataset_params/coco_segmentation_dataset_params.yaml
@@ -9,19 +9,19 @@ train_dataset_params:
   cache_images: False
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RescaleSeg: # consider removing this step
+    - SegRescale: # consider removing this step
         long_size: 608

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.5, 2.0 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: random

@@ -35,10 +35,10 @@ val_dataset_params:
   cache_labels: False
   cache_images: False
   transforms:
-    - RescaleSeg:
+    - SegRescale:
         short_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: center
src/super_gradients/recipes/dataset_params/pascal_aug_segmentation_dataset_params.yaml
@@ -7,19 +7,19 @@ train_dataset_params:
   cache_images: False
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RescaleSeg: # consider removing this step
+    - SegRescale: # consider removing this step
         long_size: 608

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.5, 2.0 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: random

@@ -31,10 +31,10 @@ val_dataset_params:
   cache_labels: False
   cache_images: False
   transforms:
-    - RescaleSeg:
+    - SegRescale:
         short_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: center
src/super_gradients/recipes/dataset_params/pascal_voc_segmentation_dataset_params.yaml
@@ -10,19 +10,19 @@ train_dataset_params:
   cache_images: False
   transforms:
     # for more options see common.factories.transforms_factory.py
-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5

-    - RescaleSeg: # consider removing this step
+    - SegRescale: # consider removing this step
         long_size: 608

-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.5, 2.0 ]

-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: random

@@ -34,10 +34,10 @@ val_dataset_params:
   cache_labels: False
   cache_images: False
   transforms:
-    - RescaleSeg:
+    - SegRescale:
         short_size: 512

-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: 512
         mode: center
src/super_gradients/recipes/dataset_params/supervisely_persons_dataset_params.yaml
@@ -4,18 +4,18 @@ train_dataset_params:
   cache_labels: False
   cache_images: False
   transforms:
-    - RandomRescaleSeg:
+    - SegRandomRescale:
         scales: [ 0.25, 1. ]
-    - ColorJitterSeg:
+    - SegColorJitter:
         brightness: 0.5
         contrast: 0.5
         saturation: 0.5
-    - RandomFlipSeg:
+    - SegRandomFlip:
         prob: 0.5
-    - PadShortToCropSizeSeg:
+    - SegPadShortToCropSize:
         crop_size: [ 320, 480 ]
         fill_mask: 0
-    - CropImageAndMaskSeg:
+    - SegCropImageAndMask:
         crop_size: [ 320, 480 ]
         mode: random

@@ -25,7 +25,7 @@ val_dataset_params:
   cache_labels: False
   cache_images: False
   transforms:
-    - ResizeSeg:
+    - SegResize:
         h: 480
         w: 320
src/super_gradients/training/datasets/samplers/__init__.py
@@ -0,0 +1,6 @@
+from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
+from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
+from super_gradients.training.datasets.samplers.all_samplers import SAMPLERS, Samplers
+
+
+__all__ = ['SAMPLERS', 'Samplers', 'InfiniteSampler', 'RepeatAugSampler']
src/super_gradients/training/datasets/samplers/all_samplers.py (new file)
from super_gradients.training.object_names import Samplers
from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
from torch.utils.data.distributed import DistributedSampler


SAMPLERS = {
    Samplers.INFINITE: InfiniteSampler,
    Samplers.REPEAT_AUG: RepeatAugSampler,
    Samplers.DISTRIBUTED: DistributedSampler
}
src/super_gradients/training/losses/__init__.py
@@ -8,7 +8,7 @@ from super_gradients.training.losses.yolox_loss import YoloXDetectionLoss, YoloX
 from super_gradients.training.losses.ssd_loss import SSDLoss
 from super_gradients.training.losses.bce_dice_loss import BCEDiceLoss
 from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss
-from super_gradients.training.losses.all_losses import LOSSES
+from super_gradients.training.losses.all_losses import LOSSES, Losses

-__all__ = ['FocalLoss', 'LabelSmoothingCrossEntropyLoss', 'ShelfNetOHEMLoss', 'ShelfNetSemanticEncodingLoss',
-           'YoloXDetectionLoss', 'YoloXFastDetectionLoss', 'RSquaredLoss', 'SSDLoss', 'LOSSES', 'BCEDiceLoss', 'KDLogitsLoss', 'DiceCEEdgeLoss']
+__all__ = ['LOSSES', 'Losses', 'FocalLoss', 'LabelSmoothingCrossEntropyLoss', 'ShelfNetOHEMLoss', 'ShelfNetSemanticEncodingLoss',
+           'YoloXDetectionLoss', 'YoloXFastDetectionLoss', 'RSquaredLoss', 'SSDLoss', 'BCEDiceLoss', 'KDLogitsLoss', 'DiceCEEdgeLoss']
src/super_gradients/training/losses/all_losses.py
@@ -1,19 +1,19 @@
 from torch import nn
-
+from super_gradients.training.object_names import Losses
 from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss, ShelfNetOHEMLoss, \
     ShelfNetSemanticEncodingLoss, RSquaredLoss, SSDLoss, BCEDiceLoss, YoloXDetectionLoss, YoloXFastDetectionLoss, KDLogitsLoss, DiceCEEdgeLoss
 from super_gradients.training.losses.stdc_loss import STDCLoss

-LOSSES = {"cross_entropy": LabelSmoothingCrossEntropyLoss,
-          "mse": nn.MSELoss,
-          "r_squared_loss": RSquaredLoss,
-          "shelfnet_ohem_loss": ShelfNetOHEMLoss,
-          "shelfnet_se_loss": ShelfNetSemanticEncodingLoss,
-          "yolox_loss": YoloXDetectionLoss,
-          "yolox_fast_loss": YoloXFastDetectionLoss,
-          "ssd_loss": SSDLoss,
-          "stdc_loss": STDCLoss,
-          "bce_dice_loss": BCEDiceLoss,
-          "kd_loss": KDLogitsLoss,
-          "dice_ce_edge_loss": DiceCEEdgeLoss,
-          }
+
+LOSSES = {Losses.CROSS_ENTROPY: LabelSmoothingCrossEntropyLoss,
+          Losses.MSE: nn.MSELoss,
+          Losses.R_SQUARED_LOSS: RSquaredLoss,
+          Losses.SHELFNET_OHEM_LOSS: ShelfNetOHEMLoss,
+          Losses.SHELFNET_SE_LOSS: ShelfNetSemanticEncodingLoss,
+          Losses.YOLOX_LOSS: YoloXDetectionLoss,
+          Losses.YOLOX_FAST_LOSS: YoloXFastDetectionLoss,
+          Losses.SSD_LOSS: SSDLoss,
+          Losses.STDC_LOSS: STDCLoss,
+          Losses.BCE_DICE_LOSS: BCEDiceLoss,
+          Losses.KD_LOSS: KDLogitsLoss,
+          Losses.DICE_CE_EDGE_LOSS: DiceCEEdgeLoss}
src/super_gradients/training/metrics/__init__.py
@@ -3,7 +3,8 @@
 from super_gradients.training.metrics.classification_metrics import accuracy, Accuracy, Top5, ToyTestClassificationMetric
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PreprocessSegmentationMetricsArgs, PixelAccuracy, IoU, Dice, BinaryIOU, BinaryDice
+from super_gradients.training.metrics.all_metrics import METRICS, Metrics


-__all__ = ['accuracy', 'Accuracy', 'Top5', 'ToyTestClassificationMetric', 'DetectionMetrics', 'PreprocessSegmentationMetricsArgs', 'PixelAccuracy', 'IoU',
-           'Dice', 'BinaryIOU', 'BinaryDice']
+__all__ = ['METRICS', 'Metrics', 'accuracy', 'Accuracy', 'Top5', 'ToyTestClassificationMetric', 'DetectionMetrics', 'PreprocessSegmentationMetricsArgs',
+           'PixelAccuracy', 'IoU', 'Dice', 'BinaryIOU', 'BinaryDice']
src/super_gradients/training/metrics/all_metrics.py (new file)
from super_gradients.training.object_names import Metrics
from super_gradients.training.metrics import Accuracy, Top5, DetectionMetrics, IoU, PixelAccuracy, BinaryIOU, Dice,\
    BinaryDice


METRICS = {
    Metrics.ACCURACY: Accuracy,
    Metrics.TOP5: Top5,
    Metrics.DETECTION_METRICS: DetectionMetrics,
    Metrics.IOU: IoU,
    Metrics.BINARY_IOU: BinaryIOU,
    Metrics.DICE: Dice,
    Metrics.BINARY_DICE: BinaryDice,
    Metrics.PIXEL_ACCURACY: PixelAccuracy,
}
src/super_gradients/training/models/__init__.py
@@ -17,6 +17,6 @@ from super_gradients.training.models.classification_models.vgg import *
 from super_gradients.training.models.classification_models.vit import *
 from super_gradients.training.models.segmentation_models.shelfnet import *
 from super_gradients.training.models.classification_models.efficientnet import *
-from super_gradients.training.models.all_architectures import ARCHITECTURES
+from super_gradients.training.models.all_architectures import ARCHITECTURES, ModelNames
 from super_gradients.training.models.user_models import *
 from super_gradients.training.models.model_factory import get
src/super_gradients/training/models/all_architectures.py
@@ -22,108 +22,206 @@ from super_gradients.training.models.segmentation_models.stdc import STDC1Classi
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.models.classification_models.beit import beit_base_patch16_224, beit_large_patch16_224
 from super_gradients.training.models.segmentation_models.ppliteseg import PPLiteSegT, PPLiteSegB
    -# IF YOU UPDATE THE ARCHITECTURE DICT PLEASE ALSO UPDATE THE ENUM CLASS DOWN BELOW.
     
     
     
     
    -ARCHITECTURES = {"resnet18": resnet.ResNet18,
    -                 "resnet34": resnet.ResNet34,
    -                 "resnet50_3343": resnet.ResNet50_3343,
    -                 "resnet50": resnet.ResNet50,
    -                 "resnet101": resnet.ResNet101,
    -                 "resnet152": resnet.ResNet152,
    -                 "resnet18_cifar": resnet.ResNet18Cifar,
    -                 "custom_resnet": resnet.CustomizedResnet,
    -                 "custom_resnet50": resnet.CustomizedResnet50,
    -                 "custom_resnet_cifar": resnet.CustomizedResnetCifar,
    -                 "custom_resnet50_cifar": resnet.CustomizedResnet50Cifar,
    -                 "mobilenet_v2": mobile_net_v2,
    -                 "mobile_net_v2_135": mobile_net_v2_135,
    -                 "custom_mobilenet_v2": custom_mobile_net_v2,
    -                 "mobilenet_v3_large": mobilenetv3_large,
    -                 "mobilenet_v3_small": mobilenetv3_small,
    -                 "mobilenet_v3_custom": mobilenetv3_custom,
    -                 "custom_densenet": densenet.CustomizedDensnet,
    -                 "densenet121": densenet.densenet121,
    -                 "densenet161": densenet.densenet161,
    -                 "densenet169": densenet.densenet169,
    -                 "densenet201": densenet.densenet201,
    -                 "shelfnet18_lw": ShelfNet18_LW,
    -                 "shelfnet34_lw": ShelfNet34_LW,
    -                 "shelfnet50_3343": ShelfNet503343,
    -                 "shelfnet50": ShelfNet50,
    -                 "shelfnet101": ShelfNet101,
    -                 "shufflenet_v2_x0_5": ShufflenetV2_x0_5,
    -                 "shufflenet_v2_x1_0": ShufflenetV2_x1_0,
    -                 "shufflenet_v2_x1_5": ShufflenetV2_x1_5,
    -                 "shufflenet_v2_x2_0": ShufflenetV2_x2_0,
    -                 "shufflenet_v2_custom5": CustomizedShuffleNetV2,
    -                 'darknet53': Darknet53,
    -                 'csp_darknet53': CSPDarknet53,
    -                 "resnext50": ResNeXt50,
    -                 "resnext101": ResNeXt101,
    -                 "googlenet_v1": googlenet_v1,
    -                 "efficientnet_b0": efficientnet.b0,
    -                 "efficientnet_b1": efficientnet.b1,
    -                 "efficientnet_b2": efficientnet.b2,
    -                 "efficientnet_b3": efficientnet.b3,
    -                 "efficientnet_b4": efficientnet.b4,
    -                 "efficientnet_b5": efficientnet.b5,
    -                 "efficientnet_b6": efficientnet.b6,
    -                 "efficientnet_b7": efficientnet.b7,
    -                 "efficientnet_b8": efficientnet.b8,
    -                 "efficientnet_l2": efficientnet.l2,
    -                 "CustomizedEfficientnet": efficientnet.CustomizedEfficientnet,
    -                 "regnetY200": regnet.RegNetY200,
    -                 "regnetY400": regnet.RegNetY400,
    -                 "regnetY600": regnet.RegNetY600,
    -                 "regnetY800": regnet.RegNetY800,
    -                 "custom_regnet": regnet.CustomRegNet,
    -                 "custom_anynet": regnet.CustomAnyNet,
    -                 "nas_regnet": regnet.NASRegNet,
    -                 "yolox_n": YoloX_N,
    -                 "yolox_t": YoloX_T,
    -                 "yolox_s": YoloX_S,
    -                 "yolox_m": YoloX_M,
    -                 "yolox_l": YoloX_L,
    -                 "yolox_x": YoloX_X,
    -                 "custom_yolox": CustomYoloX,
    -                 "ssd_mobilenet_v1": SSDMobileNetV1,
    -                 "ssd_lite_mobilenet_v2": SSDLiteMobileNetV2,
    -                 "repvgg_a0": repvgg.RepVggA0,
    -                 "repvgg_a1": repvgg.RepVggA1,
    -                 "repvgg_a2": repvgg.RepVggA2,
    -                 "repvgg_b0": repvgg.RepVggB0,
    -                 "repvgg_b1": repvgg.RepVggB1,
    -                 "repvgg_b2": repvgg.RepVggB2,
    -                 "repvgg_b3": repvgg.RepVggB3,
    -                 "repvgg_d2se": repvgg.RepVggD2SE,
    -                 "repvgg_custom": repvgg.RepVggCustom,
    -                 "ddrnet_23": DDRNet23,
    -                 "ddrnet_23_slim": DDRNet23Slim,
    -                 "custom_ddrnet_23": AnyBackBoneDDRNet23,
    -                 "stdc1_classification": STDC1Classification,
    -                 "stdc2_classification": STDC2Classification,
    -                 "stdc1_seg": STDC1Seg,
    -                 "stdc1_seg50": STDC1Seg,
    -                 "stdc1_seg75": STDC1Seg,
    -                 "stdc2_seg": STDC2Seg,
    -                 "stdc2_seg50": STDC2Seg,
    -                 "stdc2_seg75": STDC2Seg,
    -                 "regseg48": RegSeg48,
    -                 "kd_module": KDModule,
    -                 "vit_base": vit_base,
    -                 "vit_large": vit_large,
    -                 "vit_huge": vit_huge,
    -                 "beit_base_patch16_224": beit_base_patch16_224,
    -                 "beit_large_patch16_224": beit_large_patch16_224,
    -                 "pp_lite_t_seg": PPLiteSegT,
    -                 "pp_lite_t_seg50": PPLiteSegT,
    -                 "pp_lite_t_seg75": PPLiteSegT,
    -                 "pp_lite_b_seg": PPLiteSegB,
    -                 "pp_lite_b_seg50": PPLiteSegB,
    -                 "pp_lite_b_seg75": PPLiteSegB,
    +class ModelNames:
    +    """Static class to hold all the available model names"""""
    +    RESNET18 = "resnet18"
    +    RESNET34 = "resnet34"
    +    RESNET50_3343 = "resnet50_3343"
    +    RESNET50 = "resnet50"
    +    RESNET101 = "resnet101"
    +    RESNET152 = "resnet152"
    +    RESNET18_CIFAR = "resnet18_cifar"
    +    CUSTOM_RESNET = "custom_resnet"
    +    CUSTOM_RESNET50 = "custom_resnet50"
    +    CUSTOM_RESNET_CIFAR = "custom_resnet_cifar"
    +    CUSTOM_RESNET50_CIFAR = "custom_resnet50_cifar"
    +    MOBILENET_V2 = "mobilenet_v2"
    +    MOBILE_NET_V2_135 = "mobile_net_v2_135"
    +    CUSTOM_MOBILENET_V2 = "custom_mobilenet_v2"
    +    MOBILENET_V3_LARGE = "mobilenet_v3_large"
    +    MOBILENET_V3_SMALL = "mobilenet_v3_small"
    +    MOBILENET_V3_CUSTOM = "mobilenet_v3_custom"
    +    CUSTOM_DENSENET = "custom_densenet"
    +    DENSENET121 = "densenet121"
    +    DENSENET161 = "densenet161"
    +    DENSENET169 = "densenet169"
    +    DENSENET201 = "densenet201"
    +    SHELFNET18_LW = "shelfnet18_lw"
    +    SHELFNET34_LW = "shelfnet34_lw"
    +    SHELFNET50_3343 = "shelfnet50_3343"
    +    SHELFNET50 = "shelfnet50"
    +    SHELFNET101 = "shelfnet101"
    +    SHUFFLENET_V2_X0_5 = "shufflenet_v2_x0_5"
    +    SHUFFLENET_V2_X1_0 = "shufflenet_v2_x1_0"
    +    SHUFFLENET_V2_X1_5 = "shufflenet_v2_x1_5"
    +    SHUFFLENET_V2_X2_0 = "shufflenet_v2_x2_0"
    +    SHUFFLENET_V2_CUSTOM5 = "shufflenet_v2_custom5"
    +    DARKNET53 = "darknet53"
    +    CSP_DARKNET53 = "csp_darknet53"
    +    RESNEXT50 = "resnext50"
    +    RESNEXT101 = "resnext101"
    +    GOOGLENET_V1 = "googlenet_v1"
    +    EFFICIENTNET_B0 = "efficientnet_b0"
    +    EFFICIENTNET_B1 = "efficientnet_b1"
    +    EFFICIENTNET_B2 = "efficientnet_b2"
    +    EFFICIENTNET_B3 = "efficientnet_b3"
    +    EFFICIENTNET_B4 = "efficientnet_b4"
    +    EFFICIENTNET_B5 = "efficientnet_b5"
    +    EFFICIENTNET_B6 = "efficientnet_b6"
    +    EFFICIENTNET_B7 = "efficientnet_b7"
    +    EFFICIENTNET_B8 = "efficientnet_b8"
    +    EFFICIENTNET_L2 = "efficientnet_l2"
    +    CUSTOMIZEDEFFICIENTNET = "CustomizedEfficientnet"
    +    REGNETY200 = "regnetY200"
    +    REGNETY400 = "regnetY400"
    +    REGNETY600 = "regnetY600"
    +    REGNETY800 = "regnetY800"
    +    CUSTOM_REGNET = "custom_regnet"
    +    NAS_REGNET = "nas_regnet"
    +    YOLOX_N = "yolox_n"
    +    YOLOX_T = "yolox_t"
    +    YOLOX_S = "yolox_s"
    +    YOLOX_M = "yolox_m"
    +    YOLOX_L = "yolox_l"
    +    YOLOX_X = "yolox_x"
    +    CUSTOM_YOLO_X = "CustomYoloX"
    +    SSD_MOBILENET_V1 = "ssd_mobilenet_v1"
    +    SSD_LITE_MOBILENET_V2 = "ssd_lite_mobilenet_v2"
    +    REPVGG_A0 = "repvgg_a0"
    +    REPVGG_A1 = "repvgg_a1"
    +    REPVGG_A2 = "repvgg_a2"
    +    REPVGG_B0 = "repvgg_b0"
    +    REPVGG_B1 = "repvgg_b1"
    +    REPVGG_B2 = "repvgg_b2"
    +    REPVGG_B3 = "repvgg_b3"
    +    REPVGG_D2SE = "repvgg_d2se"
    +    REPVGG_CUSTOM = "repvgg_custom"
    +    DDRNET_23 = "ddrnet_23"
    +    DDRNET_23_SLIM = "ddrnet_23_slim"
    +    CUSTOM_DDRNET_23 = "custom_ddrnet_23"
    +    STDC1_CLASSIFICATION = "stdc1_classification"
    +    STDC2_CLASSIFICATION = "stdc2_classification"
    +    STDC1_SEG = "stdc1_seg"
    +    STDC1_SEG50 = "stdc1_seg50"
    +    STDC1_SEG75 = "stdc1_seg75"
    +    STDC2_SEG = "stdc2_seg"
    +    STDC2_SEG50 = "stdc2_seg50"
    +    STDC2_SEG75 = "stdc2_seg75"
    +    REGSEG48 = "regseg48"
    +    KD_MODULE = "kd_module"
    +    VIT_BASE = "vit_base"
    +    VIT_LARGE = "vit_large"
    +    VIT_HUGE = "vit_huge"
    +    BEIT_BASE_PATCH16_224 = "beit_base_patch16_224"
    +    BEIT_LARGE_PATCH16_224 = "beit_large_patch16_224"
    +    PP_LITE_T_SEG = "pp_lite_t_seg"
    +    PP_LITE_T_SEG50 = "pp_lite_t_seg50"
    +    PP_LITE_T_SEG75 = "pp_lite_t_seg75"
    +    PP_LITE_B_SEG = "pp_lite_b_seg"
    +    PP_LITE_B_SEG50 = "pp_lite_b_seg50"
    +    PP_LITE_B_SEG75 = "pp_lite_b_seg75"
    +
    +
    +ARCHITECTURES = {ModelNames.RESNET18: resnet.ResNet18,
    +                 ModelNames.RESNET34: resnet.ResNet34,
    +                 ModelNames.RESNET50_3343: resnet.ResNet50_3343,
    +                 ModelNames.RESNET50: resnet.ResNet50,
    +                 ModelNames.RESNET101: resnet.ResNet101,
    +                 ModelNames.RESNET152: resnet.ResNet152,
    +                 ModelNames.RESNET18_CIFAR: resnet.ResNet18Cifar,
    +                 ModelNames.CUSTOM_RESNET: resnet.CustomizedResnet,
    +                 ModelNames.CUSTOM_RESNET50: resnet.CustomizedResnet50,
    +                 ModelNames.CUSTOM_RESNET_CIFAR: resnet.CustomizedResnetCifar,
    +                 ModelNames.CUSTOM_RESNET50_CIFAR: resnet.CustomizedResnet50Cifar,
    +                 ModelNames.MOBILENET_V2: mobile_net_v2,
    +                 ModelNames.MOBILE_NET_V2_135: mobile_net_v2_135,
    +                 ModelNames.CUSTOM_MOBILENET_V2: custom_mobile_net_v2,
    +                 ModelNames.MOBILENET_V3_LARGE: mobilenetv3_large,
    +                 ModelNames.MOBILENET_V3_SMALL: mobilenetv3_small,
    +                 ModelNames.MOBILENET_V3_CUSTOM: mobilenetv3_custom,
    +                 ModelNames.CUSTOM_DENSENET: densenet.CustomizedDensnet,
    +                 ModelNames.DENSENET121: densenet.densenet121,
    +                 ModelNames.DENSENET161: densenet.densenet161,
    +                 ModelNames.DENSENET169: densenet.densenet169,
    +                 ModelNames.DENSENET201: densenet.densenet201,
    +                 ModelNames.SHELFNET18_LW: ShelfNet18_LW,
    +                 ModelNames.SHELFNET34_LW: ShelfNet34_LW,
    +                 ModelNames.SHELFNET50_3343: ShelfNet503343,
    +                 ModelNames.SHELFNET50: ShelfNet50,
    +                 ModelNames.SHELFNET101: ShelfNet101,
    +                 ModelNames.SHUFFLENET_V2_X0_5: ShufflenetV2_x0_5,
    +                 ModelNames.SHUFFLENET_V2_X1_0: ShufflenetV2_x1_0,
    +                 ModelNames.SHUFFLENET_V2_X1_5: ShufflenetV2_x1_5,
    +                 ModelNames.SHUFFLENET_V2_X2_0: ShufflenetV2_x2_0,
    +                 ModelNames.SHUFFLENET_V2_CUSTOM5: CustomizedShuffleNetV2,
    +                 ModelNames.DARKNET53: Darknet53,
    +                 ModelNames.CSP_DARKNET53: CSPDarknet53,
    +                 ModelNames.RESNEXT50: ResNeXt50,
    +                 ModelNames.RESNEXT101: ResNeXt101,
    +                 ModelNames.GOOGLENET_V1: googlenet_v1,
    +                 ModelNames.EFFICIENTNET_B0: efficientnet.b0,
    +                 ModelNames.EFFICIENTNET_B1: efficientnet.b1,
    +                 ModelNames.EFFICIENTNET_B2: efficientnet.b2,
    +                 ModelNames.EFFICIENTNET_B3: efficientnet.b3,
    +                 ModelNames.EFFICIENTNET_B4: efficientnet.b4,
    +                 ModelNames.EFFICIENTNET_B5: efficientnet.b5,
    +                 ModelNames.EFFICIENTNET_B6: efficientnet.b6,
    +                 ModelNames.EFFICIENTNET_B7: efficientnet.b7,
    +                 ModelNames.EFFICIENTNET_B8: efficientnet.b8,
    +                 ModelNames.EFFICIENTNET_L2: efficientnet.l2,
    +                 ModelNames.CUSTOMIZEDEFFICIENTNET: efficientnet.CustomizedEfficientnet,
    +                 ModelNames.REGNETY200: regnet.RegNetY200,
    +                 ModelNames.REGNETY400: regnet.RegNetY400,
    +                 ModelNames.REGNETY600: regnet.RegNetY600,
    +                 ModelNames.REGNETY800: regnet.RegNetY800,
    +                 ModelNames.CUSTOM_REGNET: regnet.CustomRegNet,
    +                 ModelNames.NAS_REGNET: regnet.NASRegNet,
    +                 ModelNames.YOLOX_N: YoloX_N,
    +                 ModelNames.YOLOX_T: YoloX_T,
    +                 ModelNames.YOLOX_S: YoloX_S,
    +                 ModelNames.YOLOX_M: YoloX_M,
    +                 ModelNames.YOLOX_L: YoloX_L,
    +                 ModelNames.YOLOX_X: YoloX_X,
    +                 ModelNames.CUSTOM_YOLO_X: CustomYoloX,
    +                 ModelNames.SSD_MOBILENET_V1: SSDMobileNetV1,
    +                 ModelNames.SSD_LITE_MOBILENET_V2: SSDLiteMobileNetV2,
    +                 ModelNames.REPVGG_A0: repvgg.RepVggA0,
    +                 ModelNames.REPVGG_A1: repvgg.RepVggA1,
    +                 ModelNames.REPVGG_A2: repvgg.RepVggA2,
    +                 ModelNames.REPVGG_B0: repvgg.RepVggB0,
    +                 ModelNames.REPVGG_B1: repvgg.RepVggB1,
    +                 ModelNames.REPVGG_B2: repvgg.RepVggB2,
    +                 ModelNames.REPVGG_B3: repvgg.RepVggB3,
    +                 ModelNames.REPVGG_D2SE: repvgg.RepVggD2SE,
    +                 ModelNames.REPVGG_CUSTOM: repvgg.RepVggCustom,
    +                 ModelNames.DDRNET_23: DDRNet23,
    +                 ModelNames.DDRNET_23_SLIM: DDRNet23Slim,
    +                 ModelNames.CUSTOM_DDRNET_23: AnyBackBoneDDRNet23,
    +                 ModelNames.STDC1_CLASSIFICATION: STDC1Classification,
    +                 ModelNames.STDC2_CLASSIFICATION: STDC2Classification,
    +                 ModelNames.STDC1_SEG: STDC1Seg,
    +                 ModelNames.STDC1_SEG50: STDC1Seg,
    +                 ModelNames.STDC1_SEG75: STDC1Seg,
    +                 ModelNames.STDC2_SEG: STDC2Seg,
    +                 ModelNames.STDC2_SEG50: STDC2Seg,
    +                 ModelNames.STDC2_SEG75: STDC2Seg,
    +                 ModelNames.REGSEG48: RegSeg48,
    +                 ModelNames.KD_MODULE: KDModule,
    +                 ModelNames.VIT_BASE: vit_base,
    +                 ModelNames.VIT_LARGE: vit_large,
    +                 ModelNames.VIT_HUGE: vit_huge,
    +                 ModelNames.BEIT_BASE_PATCH16_224: beit_base_patch16_224,
    +                 ModelNames.BEIT_LARGE_PATCH16_224: beit_large_patch16_224,
    +                 ModelNames.PP_LITE_T_SEG: PPLiteSegT,
    +                 ModelNames.PP_LITE_T_SEG50: PPLiteSegT,
    +                 ModelNames.PP_LITE_T_SEG75: PPLiteSegT,
    +                 ModelNames.PP_LITE_B_SEG: PPLiteSegB,
    +                 ModelNames.PP_LITE_B_SEG50: PPLiteSegB,
    +                 ModelNames.PP_LITE_B_SEG75: PPLiteSegB,
                      }
     
     
     KD_ARCHITECTURES = {
    -    "kd_module": KDModule
    +    ModelNames.KD_MODULE: KDModule
     }
src/super_gradients/training/object_names.py (new file)
class Losses:
    """Static class holding all the supported loss names"""
    CROSS_ENTROPY = "cross_entropy"
    MSE = "mse"
    R_SQUARED_LOSS = "r_squared_loss"
    SHELFNET_OHEM_LOSS = "shelfnet_ohem_loss"
    SHELFNET_SE_LOSS = "shelfnet_se_loss"
    YOLOX_LOSS = "yolox_loss"
    YOLOX_FAST_LOSS = "yolox_fast_loss"
    SSD_LOSS = "ssd_loss"
    STDC_LOSS = "stdc_loss"
    BCE_DICE_LOSS = "bce_dice_loss"
    KD_LOSS = "kd_loss"
    DICE_CE_EDGE_LOSS = "dice_ce_edge_loss"


class Metrics:
    """Static class holding all the supported metric names"""
    ACCURACY = 'Accuracy'
    TOP5 = 'Top5'
    DETECTION_METRICS = 'DetectionMetrics'
    IOU = 'IoU'
    BINARY_IOU = "BinaryIOU"
    DICE = "Dice"
    BINARY_DICE = "BinaryDice"
    PIXEL_ACCURACY = 'PixelAccuracy'


class Transforms:
    """Static class holding all the supported transform names"""
    # From SG
    SegRandomFlip = "SegRandomFlip"
    SegResize = "SegResize"
    SegRescale = "SegRescale"
    SegRandomRescale = "SegRandomRescale"
    SegRandomRotate = "SegRandomRotate"
    SegCropImageAndMask = "SegCropImageAndMask"
    SegRandomGaussianBlur = "SegRandomGaussianBlur"
    SegPadShortToCropSize = "SegPadShortToCropSize"
    SegColorJitter = "SegColorJitter"
    DetectionMosaic = "DetectionMosaic"
    DetectionRandomAffine = "DetectionRandomAffine"
    DetectionMixup = "DetectionMixup"
    DetectionHSV = "DetectionHSV"
    DetectionHorizontalFlip = "DetectionHorizontalFlip"
    DetectionPaddedRescale = "DetectionPaddedRescale"
    DetectionTargetsFormat = "DetectionTargetsFormat"
    DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
    RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
    RandAugmentTransform = "RandAugmentTransform"
    Lighting = "Lighting"
    RandomErase = "RandomErase"
    # From torch
    Compose = "Compose"
    ToTensor = "ToTensor"
    PILToTensor = "PILToTensor"
    ConvertImageDtype = "ConvertImageDtype"
    ToPILImage = "ToPILImage"
    Normalize = "Normalize"
    Resize = "Resize"
    CenterCrop = "CenterCrop"
    Pad = "Pad"
    Lambda = "Lambda"
    RandomApply = "RandomApply"
    RandomChoice = "RandomChoice"
    RandomOrder = "RandomOrder"
    RandomCrop = "RandomCrop"
    RandomHorizontalFlip = "RandomHorizontalFlip"
    RandomVerticalFlip = "RandomVerticalFlip"
    RandomResizedCrop = "RandomResizedCrop"
    FiveCrop = "FiveCrop"
    TenCrop = "TenCrop"
    LinearTransformation = "LinearTransformation"
    ColorJitter = "ColorJitter"
    RandomRotation = "RandomRotation"
    RandomAffine = "RandomAffine"
    Grayscale = "Grayscale"
    RandomGrayscale = "RandomGrayscale"
    RandomPerspective = "RandomPerspective"
    RandomErasing = "RandomErasing"
    GaussianBlur = "GaussianBlur"
    InterpolationMode = "InterpolationMode"
    RandomInvert = "RandomInvert"
    RandomPosterize = "RandomPosterize"
    RandomSolarize = "RandomSolarize"
    RandomAdjustSharpness = "RandomAdjustSharpness"
    RandomAutocontrast = "RandomAutocontrast"
    RandomEqualize = "RandomEqualize"


class Optimizers:
    """Static class holding all the supported optimizer names"""
    SGD = "SGD"
    ADAM = "Adam"
    RMS_PROP = "RMSprop"
    RMS_PROP_TF = "RMSpropTF"
    LAMB = "Lamb"


class Callbacks:
    """Static class holding all the supported callback names"""
    DECI_LAB_UPLOAD = 'DeciLabUploadCallback'
    LR_CALLBACK_BASE = 'LRCallbackBase'
    LR_SCHEDULER = 'LRSchedulerCallback'
    METRICS_UPDATE = 'MetricsUpdateCallback'
    MODEL_CONVERSION_CHECK = 'ModelConversionCheckCallback'
    EARLY_STOP = 'EarlyStop'
    DETECTION_MULTISCALE_PREPREDICTION = 'DetectionMultiscalePrePredictionCallback'
    YOLOX_TRAINING_STAGE_SWITCH = 'YoloXTrainingStageSwitchCallback'


class LRSchedulers:
    """Static class to hold all the supported LR Scheduler names"""
    STEP = "step"
    POLY = "poly"
    COSINE = "cosine"
    EXP = "exp"
    FUNCTION = "function"


class LRWarmups:
    """Static class to hold all the supported LR Warmup names"""
    LINEAR_STEP = "linear_step"


class Samplers:
    """Static class to hold all the supported Samplers names"""
    INFINITE = "InfiniteSampler"
    REPEAT_AUG = "RepeatAugSampler"
    DISTRIBUTED = "DistributedSampler"
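For reference, a small usage sketch (illustrative only, not part of the change set) showing that the new name classes are plain string constants and can be used to index the TRANSFORMS registry defined in all_transforms.py further below:

from super_gradients.training.object_names import Transforms, Samplers
from super_gradients.training.transforms.all_transforms import TRANSFORMS

# Look up a transform class through the registry with a constant instead of a raw string.
seg_resize_cls = TRANSFORMS[Transforms.SegResize]
print(seg_resize_cls.__name__)              # "SegResize"

# The constants still evaluate to the strings used in the YAML recipes above.
assert Transforms.SegColorJitter == "SegColorJitter"
assert Samplers.DISTRIBUTED == "DistributedSampler"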
src/super_gradients/training/transforms/__init__.py
@@ -2,8 +2,9 @@
 import cv2
 from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionRandomAffine, DetectionHSV,\
     DetectionPaddedRescale, DetectionTargetsFormatTransform
+from super_gradients.training.transforms.all_transforms import TRANSFORMS, Transforms

-__all__ = ['DetectionMosaic', 'DetectionRandomAffine', 'DetectionHSV', 'DetectionPaddedRescale',
+__all__ = ['TRANSFORMS', 'Transforms', 'DetectionMosaic', 'DetectionRandomAffine', 'DetectionHSV', 'DetectionPaddedRescale',
            'DetectionTargetsFormatTransform']

 cv2.setNumThreads(0)
from super_gradients.training.object_names import Transforms
from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, rand_augment_transform
from super_gradients.training.transforms.transforms import SegRandomFlip, SegRescale, SegRandomRescale, SegRandomRotate, \
    SegCropImageAndMask, SegRandomGaussianBlur, SegPadShortToCropSize, SegResize, SegColorJitter, DetectionMosaic, DetectionRandomAffine, \
    DetectionMixup, DetectionHSV, \
    DetectionHorizontalFlip, DetectionTargetsFormat, DetectionPaddedRescale, \
    DetectionTargetsFormatTransform
from torchvision.transforms import Compose, ToTensor, PILToTensor, ConvertImageDtype, ToPILImage, Normalize, Resize, CenterCrop, Pad, Lambda, RandomApply,\
    RandomChoice, RandomOrder, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomResizedCrop, FiveCrop, TenCrop, LinearTransformation, ColorJitter,\
    RandomRotation, RandomAffine, Grayscale, RandomGrayscale, RandomPerspective, RandomErasing, GaussianBlur, InterpolationMode, RandomInvert, RandomPosterize,\
    RandomSolarize, RandomAdjustSharpness, RandomAutocontrast, RandomEqualize


TRANSFORMS = {
    Transforms.SegRandomFlip: SegRandomFlip,
    Transforms.SegResize: SegResize,
    Transforms.SegRescale: SegRescale,
    Transforms.SegRandomRescale: SegRandomRescale,
    Transforms.SegRandomRotate: SegRandomRotate,
    Transforms.SegCropImageAndMask: SegCropImageAndMask,
    Transforms.SegRandomGaussianBlur: SegRandomGaussianBlur,
    Transforms.SegPadShortToCropSize: SegPadShortToCropSize,
    Transforms.SegColorJitter: SegColorJitter,
    Transforms.DetectionMosaic: DetectionMosaic,
    Transforms.DetectionRandomAffine: DetectionRandomAffine,
    Transforms.DetectionMixup: DetectionMixup,
    Transforms.DetectionHSV: DetectionHSV,
    Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
    Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
    Transforms.DetectionTargetsFormat: DetectionTargetsFormat,
    Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
    Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
    Transforms.RandAugmentTransform: rand_augment_transform,
    Transforms.Lighting: Lighting,
    Transforms.RandomErase: RandomErase,
    # From torch
    Transforms.Compose: Compose,
    Transforms.ToTensor: ToTensor,
    Transforms.PILToTensor: PILToTensor,
    Transforms.ConvertImageDtype: ConvertImageDtype,
    Transforms.ToPILImage: ToPILImage,
    Transforms.Normalize: Normalize,
    Transforms.Resize: Resize,
    Transforms.CenterCrop: CenterCrop,
    Transforms.Pad: Pad,
    Transforms.Lambda: Lambda,
    Transforms.RandomApply: RandomApply,
    Transforms.RandomChoice: RandomChoice,
    Transforms.RandomOrder: RandomOrder,
    Transforms.RandomCrop: RandomCrop,
    Transforms.RandomHorizontalFlip: RandomHorizontalFlip,
    Transforms.RandomVerticalFlip: RandomVerticalFlip,
    Transforms.RandomResizedCrop: RandomResizedCrop,
    Transforms.FiveCrop: FiveCrop,
    Transforms.TenCrop: TenCrop,
    Transforms.LinearTransformation: LinearTransformation,
    Transforms.ColorJitter: ColorJitter,
    Transforms.RandomRotation: RandomRotation,
    Transforms.RandomAffine: RandomAffine,
    Transforms.Grayscale: Grayscale,
    Transforms.RandomGrayscale: RandomGrayscale,
    Transforms.RandomPerspective: RandomPerspective,
    Transforms.RandomErasing: RandomErasing,
    Transforms.GaussianBlur: GaussianBlur,
    Transforms.InterpolationMode: InterpolationMode,
    Transforms.RandomInvert: RandomInvert,
    Transforms.RandomPosterize: RandomPosterize,
    Transforms.RandomSolarize: RandomSolarize,
    Transforms.RandomAdjustSharpness: RandomAdjustSharpness,
    Transforms.RandomAutocontrast: RandomAutocontrast,
    Transforms.RandomEqualize: RandomEqualize,
}
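The registry maps each name from `object_names.Transforms` to the class that implements it, so factories can instantiate transforms from config strings. An illustrative lookup only (the actual resolution is done by `TransformsFactory`; the `prob` argument is assumed from the transform's docstring):

    from super_gradients.training.transforms.all_transforms import TRANSFORMS
    from super_gradients.training.object_names import Transforms

    # Resolve the registered name to its class, then instantiate it as a factory would.
    flip_cls = TRANSFORMS[Transforms.SegRandomFlip]
    flip = flip_cls(prob=0.5)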
@@ -25,7 +25,7 @@ class SegmentationTransform:
         return self.__class__.__name__ + str(self.__dict__).replace('{', '(').replace('}', ')')
 
 
-class ResizeSeg(SegmentationTransform):
+class SegResize(SegmentationTransform):
     def __init__(self, h, w):
         self.h = h
         self.w = w
@@ -38,7 +38,7 @@ class ResizeSeg(SegmentationTransform):
         return sample
 
 
-class RandomFlip(SegmentationTransform):
+class SegRandomFlip(SegmentationTransform):
     """
     Randomly flips the image and mask (synchronously) with probability 'prob'.
     """
@@ -59,7 +59,7 @@ class RandomFlip(SegmentationTransform):
         return sample
 
 
-class Rescale(SegmentationTransform):
+class SegRescale(SegmentationTransform):
     """
     Rescales the image and mask (synchronously) while preserving aspect ratio.
     The rescaling can be done according to scale_factor, short_size or long_size.
@@ -118,7 +118,7 @@ class Rescale(SegmentationTransform):
             raise ValueError(f"Long size must be a positive number, found: {self.long_size}")
 
 
-class RandomRescale:
+class SegRandomRescale:
     """
     Random rescale the image and mask (synchronously) while preserving aspect ratio.
     Scale factor is randomly picked between scales [min, max]
@@ -159,13 +159,13 @@ class RandomRescale:
             self.scales = (1, self.scales)
 
         if self.scales[0] < 0 or self.scales[1] < 0:
-            raise ValueError(f"RandomRescale scale values must be positive numbers, found: {self.scales}")
+            raise ValueError(f"SegRandomRescale scale values must be positive numbers, found: {self.scales}")
         if self.scales[0] > self.scales[1]:
             self.scales = (self.scales[1], self.scales[0])
         return self.scales
 
 
-class RandomRotate(SegmentationTransform):
+class SegRandomRotate(SegmentationTransform):
     """
     Randomly rotates image and mask (synchronously) between 'min_deg' and 'max_deg'.
     """
@@ -197,7 +197,7 @@ class RandomRotate(SegmentationTransform):
         self.fill_mask, self.fill_image = _validate_fill_values_arguments(self.fill_mask, self.fill_image)
 
 
-class CropImageAndMask(SegmentationTransform):
+class SegCropImageAndMask(SegmentationTransform):
     """
     Crops image and mask (synchronously).
     In "center" mode a center crop is performed while, in "random" mode the drop will be positioned around
@@ -248,7 +248,7 @@ class CropImageAndMask(SegmentationTransform):
             raise ValueError(f"Crop size must be positive numbers, found: {self.crop_size}")
 
 
-class RandomGaussianBlur(SegmentationTransform):
+class SegRandomGaussianBlur(SegmentationTransform):
     """
     Adds random Gaussian Blur to image with probability 'prob'.
     """
@@ -271,10 +271,10 @@ class RandomGaussianBlur(SegmentationTransform):
         return sample
 
 
-class PadShortToCropSize(SegmentationTransform):
+class SegPadShortToCropSize(SegmentationTransform):
     """
     Pads image to 'crop_size'.
-    Should be called only after "Rescale" or "RandomRescale" in augmentations pipeline.
+    Should be called only after "SegRescale" or "SegRandomRescale" in augmentations pipeline.
     """
 
     def __init__(self, crop_size: Union[float, Tuple, List], fill_mask: int = 0,
@@ -321,9 +321,9 @@ class PadShortToCropSize(SegmentationTransform):
         self.fill_mask, self.fill_image = _validate_fill_values_arguments(self.fill_mask, self.fill_image)
 
 
-class ColorJitterSeg(transforms.ColorJitter):
+class SegColorJitter(transforms.ColorJitter):
     def __call__(self, sample):
-        sample["image"] = super(ColorJitterSeg, self).__call__(sample["image"])
+        sample["image"] = super(SegColorJitter, self).__call__(sample["image"])
         return sample
 
 
@@ -663,11 +663,11 @@ class DetectionPaddedRescale(DetectionTransform):
         return sample
 
     def _rescale_target(self, targets: np.array, r: float) -> np.array:
-        """Rescale the target according to a coefficient used to rescale the image.
+        """SegRescale the target according to a coefficient used to rescale the image.
         This is done to have images and targets at the same scale.
 
         :param targets:  Targets to rescale, shape (batch_size, 6)
-        :param r:        Rescale coefficient that was applied to the image
+        :param r:        SegRescale coefficient that was applied to the image
 
         :return:         Rescaled targets, shape (batch_size, 6)
         """
from super_gradients.training.utils.callbacks.callbacks import Phase, ContextSgMethods, PhaseContext, PhaseCallback, ModelConversionCheckCallback,\
    DeciLabUploadCallback, LRCallbackBase, WarmupLRCallback, StepLRCallback, ExponentialLRCallback, PolyLRCallback, CosineLRCallback, FunctionLRCallback,\
    IllegalLRSchedulerMetric, LRSchedulerCallback, MetricsUpdateCallback, KDModelMetricsUpdateCallback, PhaseContextTestCallback,\
    DetectionVisualizationCallback, BinarySegmentationVisualizationCallback, TrainingStageSwitchCallbackBase, YoloXTrainingStageSwitchCallback,\
    CallbackHandler, TestLRCallback
from super_gradients.training.utils.callbacks.all_callbacks import Callbacks, CALLBACKS, LRSchedulers, LR_SCHEDULERS_CLS_DICT, LRWarmups, LR_WARMUP_CLS_DICT


__all__ = ["Callbacks", "CALLBACKS", "LRSchedulers", "LR_SCHEDULERS_CLS_DICT", "LRWarmups", "LR_WARMUP_CLS_DICT", "Phase", "ContextSgMethods",
           "PhaseContext", "PhaseCallback", "ModelConversionCheckCallback", "DeciLabUploadCallback", "LRCallbackBase", "WarmupLRCallback", "StepLRCallback",
           "ExponentialLRCallback", "PolyLRCallback", "CosineLRCallback", "FunctionLRCallback", "IllegalLRSchedulerMetric", "LRSchedulerCallback",
           "MetricsUpdateCallback", "KDModelMetricsUpdateCallback", "PhaseContextTestCallback", "DetectionVisualizationCallback",
           "BinarySegmentationVisualizationCallback", "TrainingStageSwitchCallbackBase", "YoloXTrainingStageSwitchCallback", "CallbackHandler",
           "TestLRCallback"]
from super_gradients.training.object_names import Callbacks, LRSchedulers, LRWarmups
from super_gradients.training.utils.callbacks.callbacks import DeciLabUploadCallback, LRCallbackBase, LRSchedulerCallback, MetricsUpdateCallback, \
    ModelConversionCheckCallback, YoloXTrainingStageSwitchCallback, StepLRCallback, PolyLRCallback,\
    CosineLRCallback, ExponentialLRCallback, FunctionLRCallback, WarmupLRCallback
from super_gradients.training.utils.early_stopping import EarlyStop
from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback


CALLBACKS = {
    Callbacks.DECI_LAB_UPLOAD: DeciLabUploadCallback,
    Callbacks.LR_CALLBACK_BASE: LRCallbackBase,
    Callbacks.LR_SCHEDULER: LRSchedulerCallback,
    Callbacks.METRICS_UPDATE: MetricsUpdateCallback,
    Callbacks.MODEL_CONVERSION_CHECK: ModelConversionCheckCallback,
    Callbacks.EARLY_STOP: EarlyStop,
    Callbacks.DETECTION_MULTISCALE_PREPREDICTION: DetectionMultiscalePrePredictionCallback,
    Callbacks.YOLOX_TRAINING_STAGE_SWITCH: YoloXTrainingStageSwitchCallback
}


LR_SCHEDULERS_CLS_DICT = {
    LRSchedulers.STEP: StepLRCallback,
    LRSchedulers.POLY: PolyLRCallback,
    LRSchedulers.COSINE: CosineLRCallback,
    LRSchedulers.EXP: ExponentialLRCallback,
    LRSchedulers.FUNCTION: FunctionLRCallback,
}


LR_WARMUP_CLS_DICT = {LRWarmups.LINEAR_STEP: WarmupLRCallback}
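As with transforms, callback resolution becomes a dictionary lookup keyed by the names in `object_names.Callbacks`; the wiring from config strings is expected to live in `CallbacksFactory`. An illustrative lookup:

    from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
    from super_gradients.training.object_names import Callbacks

    # Resolve the registered name to the EarlyStop class, as the factory would.
    early_stop_cls = CALLBACKS[Callbacks.EARLY_STOP]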
@@ -739,18 +739,6 @@ class CallbackHandler:
                 callback(context)
 
 
-# DICT FOR LEGACY LR HARD-CODED REGIMES, WILL BE DELETED IN THE FUTURE
-LR_SCHEDULERS_CLS_DICT = {
-    "step": StepLRCallback,
-    "poly": PolyLRCallback,
-    "cosine": CosineLRCallback,
-    "exp": ExponentialLRCallback,
-    "function": FunctionLRCallback,
-}
-
-LR_WARMUP_CLS_DICT = {"linear_step": WarmupLRCallback}
-
-
 class TestLRCallback(PhaseCallback):
     """
     Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
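The hard-coded dictionaries removed here stay importable: they are now defined in `all_callbacks.py` and re-exported from the callbacks package `__init__.py` shown above, so callers only need a path change. A minimal sketch:

    # New import location for the scheduler/warmup lookup tables.
    from super_gradients.training.utils.callbacks import LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT

    cosine_scheduler_cls = LR_SCHEDULERS_CLS_DICT["cosine"]  # same keys as before, e.g. LRSchedulers.COSINE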
    @@ -0,0 +1,6 @@
    +from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
    +from super_gradients.training.utils.optimizers.lamb import Lamb
    +
    +from super_gradients.training.utils.optimizers.all_optimizers import OPTIMIZERS, Optimizers
    +
    +__all__ = ['OPTIMIZERS', 'Optimizers', 'RMSpropTF', 'Lamb']
from torch import optim
from super_gradients.training.object_names import Optimizers
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
from super_gradients.training.utils.optimizers.lamb import Lamb


OPTIMIZERS = {
    Optimizers.SGD: optim.SGD,
    Optimizers.ADAM: optim.Adam,
    Optimizers.RMS_PROP: optim.RMSprop,
    Optimizers.RMS_PROP_TF: RMSpropTF,
    Optimizers.LAMB: Lamb
}
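The optimizer registry follows the same pattern. A sketch of resolving an optimizer class from its registered name:

    from super_gradients.training.utils.optimizers.all_optimizers import OPTIMIZERS
    from super_gradients.training.object_names import Optimizers

    # Resolve the registered name to the optimizer class; it is then constructed
    # like any torch optimizer, e.g. optimizer_cls(model.parameters(), lr=1e-3).
    optimizer_cls = OPTIMIZERS[Optimizers.RMS_PROP_TF]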
@@ -2,7 +2,7 @@ import unittest
 
 import torch
 from torchvision.transforms import Compose, ToTensor
-from super_gradients.training.transforms.transforms import Rescale, RandomRescale, CropImageAndMask, PadShortToCropSize
+from super_gradients.training.transforms.transforms import SegRescale, SegRandomRescale, SegCropImageAndMask, SegPadShortToCropSize
 from PIL import Image
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 
@@ -23,79 +23,79 @@ class SegmentationTransformsTest(unittest.TestCase):
     def test_rescale_with_scale_factor(self):
         # test raise exception for negative and zero scale factor
         kwargs = {"scale_factor": -2}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
         kwargs = {"scale_factor": 0}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
 
         # test scale down
         sample = self.create_sample((1024, 512))
-        rescale_scale05 = Rescale(scale_factor=0.5)
+        rescale_scale05 = SegRescale(scale_factor=0.5)
         out = rescale_scale05(sample)
         self.assertEqual((512, 256), out["image"].size)
 
         # test scale up
         sample = self.create_sample((1024, 512))
-        rescale_scale2 = Rescale(scale_factor=2.0)
+        rescale_scale2 = SegRescale(scale_factor=2.0)
         out = rescale_scale2(sample)
         self.assertEqual((2048, 1024), out["image"].size)
 
         # test scale_factor is stronger than other params
         sample = self.create_sample((1024, 512))
-        rescale_scale05 = Rescale(scale_factor=0.5, short_size=300, long_size=600)
+        rescale_scale05 = SegRescale(scale_factor=0.5, short_size=300, long_size=600)
         out = rescale_scale05(sample)
         self.assertEqual((512, 256), out["image"].size)
 
     def test_rescale_with_short_size(self):
         # test raise exception for negative and zero short_size
         kwargs = {"short_size": 0}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
         kwargs = {"short_size": -200}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
 
         # test scale by short size
         sample = self.create_sample((1024, 512))
-        rescale_short256 = Rescale(short_size=256)
+        rescale_short256 = SegRescale(short_size=256)
         out = rescale_short256(sample)
         self.assertEqual((512, 256), out["image"].size)
 
         # test short_size is stronger than long_size
         sample = self.create_sample((1024, 512))
-        rescale_scale05 = Rescale(short_size=301, long_size=301)
+        rescale_scale05 = SegRescale(short_size=301, long_size=301)
         out = rescale_scale05(sample)
         self.assertEqual((602, 301), out["image"].size)
 
     def test_rescale_with_long_size(self):
         # test raise exception for negative and zero short_size
         kwargs = {"long_size": 0}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
         kwargs = {"long_size": -200}
-        self.failUnlessRaises(ValueError, Rescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRescale, **kwargs)
 
         # test scale by long size
         sample = self.create_sample((1024, 512))
-        rescale_long256 = Rescale(long_size=256)
+        rescale_long256 = SegRescale(long_size=256)
         out = rescale_long256(sample)
         self.assertEqual((256, 128), out["image"].size)
 
     def test_random_rescale(self):
         # test passing scales argument
-        random_rescale = RandomRescale(scales=0.1)
+        random_rescale = SegRandomRescale(scales=0.1)
         self.assertEqual((0.1, 1), random_rescale.scales)
 
-        random_rescale = RandomRescale(scales=1.2)
+        random_rescale = SegRandomRescale(scales=1.2)
         self.assertEqual((1, 1.2), random_rescale.scales)
 
-        random_rescale = RandomRescale(scales=(0.5, 1.2))
+        random_rescale = SegRandomRescale(scales=(0.5, 1.2))
         self.assertEqual((0.5, 1.2), random_rescale.scales)
 
         kwargs = {"scales": -0.5}
-        self.failUnlessRaises(ValueError, RandomRescale, **kwargs)
+        self.failUnlessRaises(ValueError, SegRandomRescale, **kwargs)
 
         # test random rescale
         size = [1024, 512]
         scales = [0.8, 1.2]
         sample = self.create_sample(size)
-        random_rescale = RandomRescale(scales=(0.8, 1.2))
+        random_rescale = SegRandomRescale(scales=(0.8, 1.2))
         min_size = [scales[0] * s for s in size]
         max_size = [scales[1] * s for s in size]
 
@@ -106,30 +106,30 @@ class SegmentationTransformsTest(unittest.TestCase):
 
     def test_padding(self):
         # test arguments are valid
-        pad = PadShortToCropSize(crop_size=200)
+        pad = SegPadShortToCropSize(crop_size=200)
         self.assertEqual((200, 200), pad.crop_size)
 
         kwargs = {"crop_size": (0, 200)}
-        self.failUnlessRaises(ValueError, PadShortToCropSize, **kwargs)
+        self.failUnlessRaises(ValueError, SegPadShortToCropSize, **kwargs)
 
         kwargs = {"crop_size": 200, "fill_image": 256}
-        self.failUnlessRaises(ValueError, PadShortToCropSize, **kwargs)
+        self.failUnlessRaises(ValueError, SegPadShortToCropSize, **kwargs)
 
         kwargs = {"crop_size": 200, "fill_mask": 256}
-        self.failUnlessRaises(ValueError, PadShortToCropSize, **kwargs)
+        self.failUnlessRaises(ValueError, SegPadShortToCropSize, **kwargs)
 
         in_size = (512, 256)
 
         out_size = (512, 512)
         sample = self.create_sample(in_size)
-        padding = PadShortToCropSize(crop_size=out_size)
+        padding = SegPadShortToCropSize(crop_size=out_size)
         out = padding(sample)
         self.assertEqual(out_size, out["image"].size)
 
         # pad to odd size
         out_size = (512, 501)
         sample = self.create_sample(in_size)
-        padding = PadShortToCropSize(crop_size=out_size)
+        padding = SegPadShortToCropSize(crop_size=out_size)
         out = padding(sample)
         self.assertEqual(out_size, out["image"].size)
 
@@ -143,7 +143,7 @@ class SegmentationTransformsTest(unittest.TestCase):
         fill_image_value = 127
 
         sample = self.create_sample(in_size)
-        padding = PadShortToCropSize(crop_size=out_size, fill_mask=fill_mask_value, fill_image=fill_image_value)
+        padding = SegPadShortToCropSize(crop_size=out_size, fill_mask=fill_mask_value, fill_image=fill_image_value)
         out = padding(sample)
 
         out_mask = SegmentationDataSet.target_transform(out["mask"])
@@ -172,20 +172,20 @@ class SegmentationTransformsTest(unittest.TestCase):
 
     def test_crop(self):
         # test arguments are valid
-        pad = CropImageAndMask(crop_size=200, mode="center")
+        pad = SegCropImageAndMask(crop_size=200, mode="center")
         self.assertEqual((200, 200), pad.crop_size)
 
         kwargs = {"crop_size": (0, 200), "mode": "random"}
-        self.failUnlessRaises(ValueError, CropImageAndMask, **kwargs)
+        self.failUnlessRaises(ValueError, SegCropImageAndMask, **kwargs)
         # test unsupported mode
         kwargs = {"crop_size": (200, 200), "mode": "deci"}
-        self.failUnlessRaises(ValueError, CropImageAndMask, **kwargs)
+        self.failUnlessRaises(ValueError, SegCropImageAndMask, **kwargs)
 
         in_size = (1024, 512)
         out_size = (128, 256)
 
-        crop_center = CropImageAndMask(crop_size=out_size, mode="center")
-        crop_random = CropImageAndMask(crop_size=out_size, mode="random")
+        crop_center = SegCropImageAndMask(crop_size=out_size, mode="center")
+        crop_random = SegCropImageAndMask(crop_size=out_size, mode="random")
 
         sample = self.create_sample(in_size)
         out_center = crop_center(sample)
@@ -202,8 +202,8 @@ class SegmentationTransformsTest(unittest.TestCase):
         sample = self.create_sample(in_size)
 
         transform = Compose([
-            Rescale(long_size=out_size[0]),         # rescale to (512, 256)
-            PadShortToCropSize(crop_size=out_size)  # pad to (512, 512)
+            SegRescale(long_size=out_size[0]),         # rescale to (512, 256)
+            SegPadShortToCropSize(crop_size=out_size)  # pad to (512, 512)
         ])
         out = transform(sample)
         self.assertEqual(out_size, out["image"].size)
@@ -214,9 +214,9 @@ class SegmentationTransformsTest(unittest.TestCase):
         sample = self.create_sample(img_size)
 
         transform = Compose([
-            RandomRescale(scales=(0.1, 2.0)),
-            PadShortToCropSize(crop_size=crop_size),
-            CropImageAndMask(crop_size=crop_size, mode="random")
+            SegRandomRescale(scales=(0.1, 2.0)),
+            SegPadShortToCropSize(crop_size=crop_size),
+            SegCropImageAndMask(crop_size=crop_size, mode="random")
         ])
 
         out = transform(sample)