Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#740 Feature/sg 691 register losses dynamically

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-691-register_dynamically_0
@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.losses import LOSSES
+from super_gradients.common.registry.registry import LOSSES
 
 
 
 
 class LossesFactory(BaseFactory):
 class LossesFactory(BaseFactory):
Discard
@@ -1,12 +1,14 @@
 import inspect
 import inspect
 from typing import Callable, Dict, Optional
 from typing import Callable, Dict, Optional
 
 
+from torch import nn
+
+from super_gradients.common import object_names
 from super_gradients.training.utils.callbacks import LR_SCHEDULERS_CLS_DICT
 from super_gradients.training.utils.callbacks import LR_SCHEDULERS_CLS_DICT
 from super_gradients.common.sg_loggers import SG_LOGGERS
 from super_gradients.common.sg_loggers import SG_LOGGERS
 from super_gradients.training.dataloaders.dataloaders import ALL_DATALOADERS
 from super_gradients.training.dataloaders.dataloaders import ALL_DATALOADERS
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.metrics.all_metrics import METRICS
 from super_gradients.training.metrics.all_metrics import METRICS
-from super_gradients.training.losses.all_losses import LOSSES
 from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
 from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
 from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
 from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
 from super_gradients.training.transforms.all_transforms import TRANSFORMS
 from super_gradients.training.transforms.all_transforms import TRANSFORMS
@@ -24,8 +26,8 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
     """
     """
     Create a decorator that registers object of specified type (model, metric, ...)
     Create a decorator that registers object of specified type (model, metric, ...)
 
 
-    :param registry: The registry (maps name to object that you register)
-    :return:         Register function
+    :param registry:    Dict including registered objects (maps name to object that you register)
+    :return:            Register function
     """
     """
 
 
     def register(name: Optional[str] = None) -> Callable:
     def register(name: Optional[str] = None) -> Callable:
@@ -55,7 +57,10 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 register_model = create_register_decorator(registry=ARCHITECTURES)
 register_model = create_register_decorator(registry=ARCHITECTURES)
 register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
 register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
 register_metric = create_register_decorator(registry=METRICS)
 register_metric = create_register_decorator(registry=METRICS)
+
+LOSSES = {object_names.Losses.MSE: nn.MSELoss}
 register_loss = create_register_decorator(registry=LOSSES)
 register_loss = create_register_decorator(registry=LOSSES)
+
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
 register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
 register_callback = create_register_decorator(registry=CALLBACKS)
 register_callback = create_register_decorator(registry=CALLBACKS)
 register_transform = create_register_decorator(registry=TRANSFORMS)
 register_transform = create_register_decorator(registry=TRANSFORMS)
Discard
    Discard
    @@ -1,3 +1,6 @@
    +from super_gradients.common.registry.registry import LOSSES
    +from super_gradients.common.object_names import Losses
    +
     from super_gradients.training.losses.focal_loss import FocalLoss
     from super_gradients.training.losses.focal_loss import FocalLoss
     from super_gradients.training.losses.kd_losses import KDLogitsLoss
     from super_gradients.training.losses.kd_losses import KDLogitsLoss
     from super_gradients.training.losses.label_smoothing_cross_entropy_loss import LabelSmoothingCrossEntropyLoss
     from super_gradients.training.losses.label_smoothing_cross_entropy_loss import LabelSmoothingCrossEntropyLoss
    @@ -8,7 +11,6 @@ from super_gradients.training.losses.yolox_loss import YoloXDetectionLoss, YoloX
     from super_gradients.training.losses.ssd_loss import SSDLoss
     from super_gradients.training.losses.ssd_loss import SSDLoss
     from super_gradients.training.losses.bce_dice_loss import BCEDiceLoss
     from super_gradients.training.losses.bce_dice_loss import BCEDiceLoss
     from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss
     from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss
    -from super_gradients.training.losses.all_losses import LOSSES, Losses
     from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
     from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
     
     
     __all__ = [
     __all__ = [
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    1. from torch import nn
    2. from super_gradients.common.object_names import Losses
    3. from super_gradients.training.losses import (
    4. LabelSmoothingCrossEntropyLoss,
    5. ShelfNetOHEMLoss,
    6. ShelfNetSemanticEncodingLoss,
    7. RSquaredLoss,
    8. SSDLoss,
    9. BCEDiceLoss,
    10. YoloXDetectionLoss,
    11. YoloXFastDetectionLoss,
    12. KDLogitsLoss,
    13. DiceCEEdgeLoss,
    14. )
    15. from super_gradients.training.losses.stdc_loss import STDCLoss
    16. from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
    17. from super_gradients.training.losses.dekr_loss import DEKRLoss
    18. LOSSES = {
    19. Losses.CROSS_ENTROPY: LabelSmoothingCrossEntropyLoss,
    20. Losses.MSE: nn.MSELoss,
    21. Losses.R_SQUARED_LOSS: RSquaredLoss,
    22. Losses.SHELFNET_OHEM_LOSS: ShelfNetOHEMLoss,
    23. Losses.SHELFNET_SE_LOSS: ShelfNetSemanticEncodingLoss,
    24. Losses.YOLOX_LOSS: YoloXDetectionLoss,
    25. Losses.YOLOX_FAST_LOSS: YoloXFastDetectionLoss,
    26. Losses.SSD_LOSS: SSDLoss,
    27. Losses.STDC_LOSS: STDCLoss,
    28. Losses.BCE_DICE_LOSS: BCEDiceLoss,
    29. Losses.KD_LOSS: KDLogitsLoss,
    30. Losses.DICE_CE_EDGE_LOSS: DiceCEEdgeLoss,
    31. Losses.PPYOLOE_LOSS: PPYoloELoss,
    32. Losses.DEKR_LOSS: DEKRLoss,
    33. }
    Discard
    @@ -1,9 +1,12 @@
     import torch
     import torch
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.losses.bce_loss import BCE
     from super_gradients.training.losses.bce_loss import BCE
     from super_gradients.training.losses.dice_loss import BinaryDiceLoss
     from super_gradients.training.losses.dice_loss import BinaryDiceLoss
     
     
     
     
    +@register_loss(Losses.BCE_DICE_LOSS)
     class BCEDiceLoss(torch.nn.Module):
     class BCEDiceLoss(torch.nn.Module):
         """
         """
         Binary Cross Entropy + Dice Loss
         Binary Cross Entropy + Dice Loss
    Discard
    @@ -3,7 +3,11 @@ from typing import Tuple
     import torch
     import torch
     from torch import Tensor, nn
     from torch import Tensor, nn
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     
     
    +
    +@register_loss(Losses.DEKR_LOSS)
     class DEKRLoss(nn.Module):
     class DEKRLoss(nn.Module):
         """
         """
         Implementation of the loss function from the "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression"
         Implementation of the loss function from the "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression"
    Discard
    @@ -4,9 +4,13 @@ from super_gradients.training.losses.dice_loss import DiceLoss, BinaryDiceLoss
     from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
     from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
     from typing import Union, Tuple
     from typing import Union, Tuple
    +
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.losses.mask_loss import MaskAttentionLoss
     from super_gradients.training.losses.mask_loss import MaskAttentionLoss
     
     
     
     
    +@register_loss(Losses.DICE_CE_EDGE_LOSS)
     class DiceCEEdgeLoss(_Loss):
     class DiceCEEdgeLoss(_Loss):
         def __init__(
         def __init__(
             self,
             self,
    Discard
    @@ -1,6 +1,9 @@
     from torch.nn.modules.loss import _Loss, KLDivLoss
     from torch.nn.modules.loss import _Loss, KLDivLoss
     import torch
     import torch
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
    +
     
     
     class KDklDivLoss(KLDivLoss):
     class KDklDivLoss(KLDivLoss):
         """KL divergence wrapper for knowledge distillation"""
         """KL divergence wrapper for knowledge distillation"""
    @@ -12,6 +15,7 @@ class KDklDivLoss(KLDivLoss):
             return super(KDklDivLoss, self).forward(torch.log_softmax(student_output, dim=1), torch.softmax(teacher_output, dim=1))
             return super(KDklDivLoss, self).forward(torch.log_softmax(student_output, dim=1), torch.softmax(teacher_output, dim=1))
     
     
     
     
    +@register_loss(Losses.KD_LOSS)
     class KDLogitsLoss(_Loss):
     class KDLogitsLoss(_Loss):
         """Knowledge distillation loss, wraps the task loss and distillation loss"""
         """Knowledge distillation loss, wraps the task loss and distillation loss"""
     
     
    Discard
    @@ -2,6 +2,9 @@ import torch
     from torch import nn
     from torch import nn
     import torch.nn.functional as F
     import torch.nn.functional as F
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
    +
     
     
     def onehot(indexes, N=None, ignore_index=None):
     def onehot(indexes, N=None, ignore_index=None):
         """
         """
    @@ -80,6 +83,7 @@ def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction="mea
         return loss
         return loss
     
     
     
     
    +@register_loss(Losses.CROSS_ENTROPY)
     class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
     class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
         """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""
         """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""
     
     
    Discard
    @@ -6,6 +6,8 @@ import torch.nn.functional as F
     from torch import nn, Tensor
     from torch import nn, Tensor
     
     
     import super_gradients
     import super_gradients
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy
     from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy
     from super_gradients.training.utils.bbox_utils import batch_distance2bbox
     from super_gradients.training.utils.bbox_utils import batch_distance2bbox
     from super_gradients.training.utils.distributed_training_utils import (
     from super_gradients.training.utils.distributed_training_utils import (
    @@ -640,6 +642,7 @@ class GIoULoss(object):
             return loss * self.loss_weight
             return loss * self.loss_weight
     
     
     
     
    +@register_loss(Losses.PPYOLOE_LOSS)
     class PPYoloELoss(nn.Module):
     class PPYoloELoss(nn.Module):
         def __init__(
         def __init__(
             self,
             self,
    Discard
    @@ -7,6 +7,11 @@ from torch.nn.modules.loss import _Loss
     from super_gradients.training.utils import convert_to_tensor
     from super_gradients.training.utils import convert_to_tensor
     
     
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
    +
    +
    +@register_loss(Losses.R_SQUARED_LOSS)
     class RSquaredLoss(_Loss):
     class RSquaredLoss(_Loss):
         def forward(self, output, target):
         def forward(self, output, target):
             # FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)
             # FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)
    Discard
    @@ -1,8 +1,11 @@
     import torch
     import torch
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.losses.ohem_ce_loss import OhemCELoss
     from super_gradients.training.losses.ohem_ce_loss import OhemCELoss
     
     
     
     
    +@register_loss(Losses.SHELFNET_OHEM_LOSS)
     class ShelfNetOHEMLoss(OhemCELoss):
     class ShelfNetOHEMLoss(OhemCELoss):
         def __init__(self, threshold: float = 0.7, mining_percent: float = 1e-4, ignore_lb: int = 255):
         def __init__(self, threshold: float = 0.7, mining_percent: float = 1e-4, ignore_lb: int = 255):
             """
             """
    Discard
    @@ -3,6 +3,11 @@ from torch import nn
     from torch.autograd import Variable
     from torch.autograd import Variable
     
     
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
    +
    +
    +@register_loss(Losses.SHELFNET_SE_LOSS)
     class ShelfNetSemanticEncodingLoss(nn.CrossEntropyLoss):
     class ShelfNetSemanticEncodingLoss(nn.CrossEntropyLoss):
         """2D Cross Entropy Loss with Auxilary Loss"""
         """2D Cross Entropy Loss with Auxilary Loss"""
     
     
    Discard
    @@ -4,6 +4,8 @@ import torch
     from torch import nn
     from torch import nn
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
     
     
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
     from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
     from super_gradients.training.utils.ssd_utils import DefaultBoxes
     from super_gradients.training.utils.ssd_utils import DefaultBoxes
     
     
    @@ -50,6 +52,7 @@ class HardMiningCrossEntropyLoss(_Loss):
             return closs
             return closs
     
     
     
     
    +@register_loss(Losses.SSD_LOSS)
     class SSDLoss(_Loss):
     class SSDLoss(_Loss):
         """
         """
             Implements the loss as the sum of the followings:
             Implements the loss as the sum of the followings:
    Discard
    @@ -1,11 +1,15 @@
    +from typing import Union, Tuple
    +
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     import torch.nn.functional as F
     import torch.nn.functional as F
    -from super_gradients.training.utils.segmentation_utils import to_one_hot
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
    +
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
    +from super_gradients.training.utils.segmentation_utils import to_one_hot
     from super_gradients.training.losses.ohem_ce_loss import OhemCELoss, OhemBCELoss, OhemLoss
     from super_gradients.training.losses.ohem_ce_loss import OhemCELoss, OhemBCELoss, OhemLoss
     from super_gradients.training.losses.dice_loss import BinaryDiceLoss
     from super_gradients.training.losses.dice_loss import BinaryDiceLoss
    -from typing import Union, Tuple
     
     
     
     
     class DetailAggregateModule(nn.Module):
     class DetailAggregateModule(nn.Module):
    @@ -107,6 +111,7 @@ class DetailLoss(_Loss):
             return self.weights[0] * bce_loss + self.weights[1] * dice_loss
             return self.weights[0] * bce_loss + self.weights[1] * dice_loss
     
     
     
     
    +@register_loss(Losses.STDC_LOSS)
     class STDCLoss(_Loss):
     class STDCLoss(_Loss):
         """
         """
         Loss class of STDC-Seg training.
         Loss class of STDC-Seg training.
    Discard
    @@ -12,6 +12,9 @@ from torch import nn
     from torch.nn.modules.loss import _Loss
     from torch.nn.modules.loss import _Loss
     import torch.nn.functional as F
     import torch.nn.functional as F
     
     
    +
    +from super_gradients.common.object_names import Losses
    +from super_gradients.common.registry.registry import register_loss
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.training.utils import torch_version_is_greater_or_equal
     from super_gradients.training.utils import torch_version_is_greater_or_equal
     from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
     from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
    @@ -79,6 +82,7 @@ class IOUloss(nn.Module):
             return loss
             return loss
     
     
     
     
    +@register_loss(Losses.YOLOX_LOSS)
     class YoloXDetectionLoss(_Loss):
     class YoloXDetectionLoss(_Loss):
         """
         """
         Calculate YOLOX loss:
         Calculate YOLOX loss:
    @@ -607,6 +611,7 @@ class YoloXDetectionLoss(_Loss):
             return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
             return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
     
     
     
     
    +@register_loss(Losses.YOLOX_FAST_LOSS)
     class YoloXFastDetectionLoss(YoloXDetectionLoss):
     class YoloXFastDetectionLoss(YoloXDetectionLoss):
         """
         """
         A completely new implementation of YOLOX loss.
         A completely new implementation of YOLOX loss.
    Discard