#235 Add iou loss

Merged
GitHub User merged 1 commit into Deci-AI:master from deci-ai:feature/ALG-277_iou_loss
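This PR extracts the shared segmentation structure-loss logic out of DiceLoss into a new abstract base class, AbstarctSegmentationStructureLoss, which owns the softmax, one-hot conversion, ignore-index masking, generalized (volume-normalized) weighting, class weighting, and reduction; concrete losses now implement only _calc_numerator_denominator and _calc_loss. On top of that base it adds IoULoss, BinaryIoULoss, and GeneralizedIoULoss, plus a unit-test module wired into the core test suite. Along the way, sum_over_batches is renamed to reduce_over_batches and generalized_dice to generalized_metric.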
@@ -1,99 +1,43 @@
+from typing import Union
+
 import torch
-from typing import Union, Optional
-from torch.nn.modules.loss import _Loss
-from super_gradients.training.utils.segmentation_utils import to_one_hot
-from super_gradients.training.losses.loss_utils import apply_reduce, LossReduction
+
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training.losses.loss_utils import LossReduction
+from super_gradients.training.losses.structure_loss import AbstarctSegmentationStructureLoss
 
 logger = get_logger(__name__)
 
 
-class DiceLoss(_Loss):
+class DiceLoss(AbstarctSegmentationStructureLoss):
     """
     Compute average Dice loss between two tensors, It can support both multi-classes and binary tasks.
     Defined in the paper: "V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation"
     """
-    def __init__(self,
-                 apply_softmax: bool = True,
-                 ignore_index: int = None,
-                 smooth: float = 1.,
-                 eps: float = 1e-5,
-                 sum_over_batches: bool = False,
-                 generalized_dice: bool = False,
-                 weight: Optional[torch.Tensor] = None,
-                 reduction: Union[LossReduction, str] = "mean"):
-        """
-        :param apply_softmax: Whether to apply softmax to the predictions.
-        :param smooth: laplace smoothing, also known as additive smoothing. The larger smooth value is, closer the dice
-            coefficient is to 1, which can be used as a regularization effect.
-            As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
-        :param eps: epsilon value to avoid inf.
-        :param sum_over_batches: Whether to average dice over the batch axis if set True,
-         default is `False` to average over the classes axis.
-        :param generalized_dice: Whether to apply normalization by the volume of each class.
-        :param weight: a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`.
-        :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
-            `none`: no reduction will be applied.
-            `mean`: the sum of the output will be divided by the number of elements in the output.
-            `sum`: the output will be summed.
-            Default: `mean`
+    def _calc_numerator_denominator(self, labels_one_hot, predict):
         """
-        super().__init__(reduction=reduction)
-        self.ignore_index = ignore_index
-        self.apply_softmax = apply_softmax
-        self.eps = eps
-        self.smooth = smooth
-        self.sum_over_batches = sum_over_batches
-        self.generalized_dice = generalized_dice
-        self.weight = weight
-        if self.generalized_dice:
-            assert self.weight is None, "Cannot use Dice Loss with weight classes and generalized dice normalization"
-            if self.eps > 1e-12:
-                logger.warning("When using GeneralizedDiceLoss, it is recommended to use eps below 1e-12, to not affect"
-                               "small values normalized terms.")
-            if self.smooth != 0:
-                logger.warning("When using GeneralizedDiceLoss, it is recommended to set smooth value as 0.")
+        Calculate dice metric's numerator and denominator.
 
-    def forward(self, predict, target):
-        if self.apply_softmax:
-            predict = torch.softmax(predict, dim=1)
-        # target to one hot format
-        if target.size() == predict.size():
-            labels_one_hot = target
-        elif len(target.size()) == 3:       # if target tensor is in class indexes format.
-            if predict.size(1) == 1 and self.ignore_index is None:    # if one class prediction task
-                labels_one_hot = target.unsqueeze(1)
-            else:
-                labels_one_hot = to_one_hot(target, num_classes=predict.shape[1], ignore_index=self.ignore_index)
-        else:
-            raise AssertionError(f"Mismatch of target shape: {target.size()} and prediction shape: {predict.size()},"
-                                 f" target must be [NxWxH] tensor for to_one_hot conversion"
-                                 f" or to have the same num of channels like prediction tensor")
-
-        reduce_spatial_dims = list(range(2, len(predict.shape)))
-        reduce_dims = [1] + reduce_spatial_dims if self.sum_over_batches else [0] + reduce_spatial_dims
-
-        intersection = torch.sum(labels_one_hot * predict, dim=reduce_dims)
+        :param labels_one_hot: target in one hot format.   shape: [BS, num_classes, img_width, img_height]
+        :param predict: predictions tensor.                shape: [BS, num_classes, img_width, img_height]
+        :return:
+            numerator = intersection between predictions and target. shape: [BS, num_classes, img_width, img_height]
+            denominator = sum of predictions and target areas.       shape: [BS, num_classes, img_width, img_height]
+        """
+        numerator = labels_one_hot * predict
         denominator = labels_one_hot + predict
-        # exclude ignore labels from denominator, false positive predicted on ignore samples are not included in
-        # total denominator.
-        if self.ignore_index is not None:
-            valid_mask = target.ne(self.ignore_index).unsqueeze(1).expand_as(denominator)
-            denominator *= valid_mask
-        denominator = torch.sum(denominator, dim=reduce_dims)
+        return numerator, denominator
 
-        if self.generalized_dice:
-            weights = 1. / (torch.sum(labels_one_hot, dim=reduce_dims) ** 2)
-            # if some classes are not in batch, weights will be inf.
-            infs = torch.isinf(weights)
-            weights[infs] = 0.0
-            intersection *= weights
-            denominator *= weights
+    def _calc_loss(self, numerator, denominator):
+        """
+        Calculate dice loss.
+        All tensors are of shape [BS] if self.reduce_over_batches else [num_classes].
 
-        dices = 1. - ((2. * intersection + self.smooth) / (denominator + self.eps + self.smooth))
-        if self.weight is not None:
-            dices *= self.weight
-        return apply_reduce(dices, reduction=self.reduction)
+        :param numerator: intersection between predictions and target.
+        :param denominator: total number of pixels of prediction and target.
+        """
+        loss = 1. - ((2. * numerator + self.smooth) / (denominator + self.eps + self.smooth))
+        return loss
 
 
 class BinaryDiceLoss(DiceLoss):
@@ -112,7 +56,7 @@ class BinaryDiceLoss(DiceLoss):
             As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
         :param eps: epsilon value to avoid inf.
         """
-        super().__init__(apply_softmax=False, ignore_index=None, smooth=smooth, eps=eps, sum_over_batches=False)
+        super().__init__(apply_softmax=False, ignore_index=None, smooth=smooth, eps=eps, reduce_over_batches=False)
         self.apply_sigmoid = apply_sigmoid
 
     def forward(self, predict, target):
@@ -138,7 +82,7 @@ class GeneralizedDiceLoss(DiceLoss):
                  ignore_index: int = None,
                  smooth: float = 0.0,
                  eps: float = 1e-17,
-                 sum_over_batches: bool = False,
+                 reduce_over_batches: bool = False,
                  reduction: Union[LossReduction, str] = "mean"
                  ):
         """
@@ -147,7 +91,7 @@ class GeneralizedDiceLoss(DiceLoss):
             coefficient is to 1, which can be used as a regularization effect.
             As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
        :param eps: epsilon value to avoid inf.
-        :param sum_over_batches: Whether to average dice over the batch axis if set True,
+        :param reduce_over_batches: Whether to apply reduction over the batch axis if set True,
          default is `False` to average over the classes axis.
         :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
             `none`: no reduction will be applied.
@@ -156,4 +100,5 @@ class GeneralizedDiceLoss(DiceLoss):
             Default: `mean`
         """
         super().__init__(apply_softmax=apply_softmax, ignore_index=ignore_index, smooth=smooth, eps=eps,
-                         sum_over_batches=sum_over_batches, generalized_dice=True, weight=None, reduction=reduction)
+                         reduce_over_batches=reduce_over_batches, generalized_metric=True, weight=None,
+                         reduction=reduction)
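For intuition on the split, here is a minimal sketch (not part of the PR) of the Dice arithmetic that the two hooks now produce, on a toy [BS, num_classes, H, W] tensor with smooth=0 and eps=1e-5:

import torch

# Toy batch: 1 image, 2 classes, 2x2 pixels; "predict" is already softmax-like.
predict = torch.tensor([[[[1., 1.], [0., 0.]],
                         [[0., 0.], [1., 1.]]]])
# Ground truth is class 0 everywhere, in one-hot form.
labels_one_hot = torch.ones_like(predict) * torch.tensor([1., 0.]).view(1, 2, 1, 1)

numerator = (labels_one_hot * predict).sum(dim=[0, 2, 3])    # per-class intersection, as in _calc_numerator_denominator
denominator = (labels_one_hot + predict).sum(dim=[0, 2, 3])  # per-class sum of areas
dice_loss = 1. - (2. * numerator) / (denominator + 1e-5)     # as in _calc_loss
print(dice_loss)  # approximately tensor([0.3333, 1.0000])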
super_gradients/training/losses/iou_loss.py (new file):

from typing import Union

import torch

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.losses.loss_utils import LossReduction
from super_gradients.training.losses.structure_loss import AbstarctSegmentationStructureLoss


logger = get_logger(__name__)


class IoULoss(AbstarctSegmentationStructureLoss):
    """
    Compute average IoU loss between two tensors. It can support both multi-class and binary tasks.
    """
    def _calc_numerator_denominator(self, labels_one_hot, predict):
        """
        Calculate iou metric's numerator and denominator.

        :param labels_one_hot: target in one hot format.   shape: [BS, num_classes, img_width, img_height]
        :param predict: predictions tensor.                shape: [BS, num_classes, img_width, img_height]
        :return:
            numerator = intersection between predictions and target.    shape: [BS, num_classes, img_width, img_height]
            denominator = area of union between predictions and target. shape: [BS, num_classes, img_width, img_height]
        """
        numerator = labels_one_hot * predict
        denominator = labels_one_hot + predict - numerator
        return numerator, denominator

    def _calc_loss(self, numerator, denominator):
        """
        Calculate iou loss.
        All tensors are of shape [BS] if self.reduce_over_batches else [num_classes].

        :param numerator: intersection between predictions and target.
        :param denominator: area of union between prediction pixels and target pixels.
        """
        loss = 1. - ((numerator + self.smooth) / (denominator + self.eps + self.smooth))
        return loss


class BinaryIoULoss(IoULoss):
    """
    Compute IoU Loss for binary class tasks (1 class only).
    Expects target to be a binary map with 0 and 1 values.
    """
    def __init__(self,
                 apply_sigmoid: bool = True,
                 smooth: float = 1.,
                 eps: float = 1e-5):
        """
        :param apply_sigmoid: Whether to apply sigmoid to the predictions.
        :param smooth: laplace smoothing, also known as additive smoothing. The larger the smooth value, the closer
            the IoU coefficient is to 1, which can be used as a regularization effect.
            As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
        :param eps: epsilon value to avoid inf.
        """
        super().__init__(apply_softmax=False, ignore_index=None, smooth=smooth, eps=eps, reduce_over_batches=False)
        self.apply_sigmoid = apply_sigmoid

    def forward(self, predict, target):
        if self.apply_sigmoid:
            predict = torch.sigmoid(predict)
        return super().forward(predict=predict, target=target)


class GeneralizedIoULoss(IoULoss):
    """
    Compute the Generalized IoU loss: the contribution of each label is normalized by the inverse of its volume,
    in order to deal with class imbalance.

    Args:
        smooth (float): default value is 0; laplace smoothing is not recommended with GeneralizedIoULoss,
            because the weighted values to be added are very small.
        eps (float): default value is 1e-17; must be a very small value, because the weighted `intersection` and
            `denominator` are very small after multiplication with `1 / counts ** 2`.
    """
    def __init__(self,
                 apply_softmax: bool = True,
                 ignore_index: int = None,
                 smooth: float = 0.0,
                 eps: float = 1e-17,
                 reduce_over_batches: bool = False,
                 reduction: Union[LossReduction, str] = "mean"
                 ):
        """
        :param apply_softmax: Whether to apply softmax to the predictions.
        :param smooth: laplace smoothing, also known as additive smoothing. The larger the smooth value, the closer
            the iou coefficient is to 1, which can be used as a regularization effect.
            As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
        :param eps: epsilon value to avoid inf.
        :param reduce_over_batches: Whether to apply reduction over the batch axis if set True,
            default is `False` to average over the classes axis.
        :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
            `none`: no reduction will be applied.
            `mean`: the sum of the output will be divided by the number of elements in the output.
            `sum`: the output will be summed.
            Default: `mean`
        """
        super().__init__(apply_softmax=apply_softmax, ignore_index=ignore_index, smooth=smooth, eps=eps,
                         reduce_over_batches=reduce_over_batches, generalized_metric=True, weight=None,
                         reduction=reduction)
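A minimal usage sketch, assuming raw logits of shape [BS, num_classes, H, W] and integer masks of shape [BS, H, W] as documented in the docstrings above (the tensor sizes here are illustrative):

import torch

from super_gradients.training.losses.iou_loss import IoULoss

criterion = IoULoss(apply_softmax=True, smooth=0., eps=1e-5)

logits = torch.randn(4, 3, 64, 64, requires_grad=True)  # raw network outputs: [BS, num_classes, H, W]
masks = torch.randint(0, 3, (4, 64, 64))                 # class-index targets: [BS, H, W]

loss = criterion(logits, masks)  # softmax -> one-hot -> soft IoU per class, then mean reduction
loss.backward()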
super_gradients/training/losses/structure_loss.py (new file):

from abc import ABC, abstractmethod
from typing import Union, Optional

import torch
from torch.nn.modules.loss import _Loss

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.losses.loss_utils import apply_reduce, LossReduction
from super_gradients.training.utils.segmentation_utils import to_one_hot


logger = get_logger(__name__)


class AbstarctSegmentationStructureLoss(_Loss, ABC):
    """
    Abstract computation of a structure loss between two tensors. It can support both multi-class and binary tasks.
    """
    def __init__(self,
                 apply_softmax: bool = True,
                 ignore_index: int = None,
                 smooth: float = 1.,
                 eps: float = 1e-5,
                 reduce_over_batches: bool = False,
                 generalized_metric: bool = False,
                 weight: Optional[torch.Tensor] = None,
                 reduction: Union[LossReduction, str] = "mean"):
        """
        :param apply_softmax: Whether to apply softmax to the predictions.
        :param smooth: laplace smoothing, also known as additive smoothing. The larger the smooth value, the closer
            the metric coefficient is to 1, which can be used as a regularization effect.
            As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
        :param eps: epsilon value to avoid inf.
        :param reduce_over_batches: Whether to average the metric over the batch axis if set True,
            default is `False` to average over the classes axis.
        :param generalized_metric: Whether to apply normalization by the volume of each class.
        :param weight: a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`.
        :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
            `none`: no reduction will be applied.
            `mean`: the sum of the output will be divided by the number of elements in the output.
            `sum`: the output will be summed.
            Default: `mean`
        """
        super().__init__(reduction=reduction)
        self.ignore_index = ignore_index
        self.apply_softmax = apply_softmax
        self.eps = eps
        self.smooth = smooth
        self.reduce_over_batches = reduce_over_batches
        self.generalized_metric = generalized_metric
        self.weight = weight
        if self.generalized_metric:
            assert self.weight is None, "Cannot use structured Loss with weight classes and generalized normalization"
            if self.eps > 1e-12:
                logger.warning("When using GeneralizedLoss, it is recommended to use eps below 1e-12, to not affect"
                               " small values normalized terms.")
            if self.smooth != 0:
                logger.warning("When using GeneralizedLoss, it is recommended to set smooth value as 0.")

    @abstractmethod
    def _calc_numerator_denominator(self, labels_one_hot, predict) -> (torch.Tensor, torch.Tensor):
        """
        All subclasses must implement this function.
        Return: 2 tensors of shape [BS, num_classes, img_width, img_height].
        """
        raise NotImplementedError()

    @abstractmethod
    def _calc_loss(self, numerator, denominator) -> torch.Tensor:
        """
        All subclasses must implement this function.
        Return a tensor of shape [BS] if self.reduce_over_batches else [num_classes].
        """
        raise NotImplementedError()

    def forward(self, predict, target):
        if self.apply_softmax:
            predict = torch.softmax(predict, dim=1)
        # target to one-hot format
        if target.size() == predict.size():
            labels_one_hot = target
        elif target.dim() == 3:  # if target tensor is in class-index format.
            if predict.size(1) == 1 and self.ignore_index is None:  # if one-class prediction task
                labels_one_hot = target.unsqueeze(1)
            else:
                labels_one_hot = to_one_hot(target, num_classes=predict.shape[1], ignore_index=self.ignore_index)
        else:
            raise AssertionError(f"Mismatch of target shape: {target.size()} and prediction shape: {predict.size()},"
                                 f" target must be a [NxWxH] tensor for to_one_hot conversion"
                                 f" or have the same number of channels as the prediction tensor")

        reduce_spatial_dims = list(range(2, len(predict.shape)))
        reduce_dims = [1] + reduce_spatial_dims if self.reduce_over_batches else [0] + reduce_spatial_dims

        # Calculate the numerator and denominator of the chosen metric
        numerator, denominator = self._calc_numerator_denominator(labels_one_hot, predict)

        # Exclude ignore labels from the numerator and denominator; false positives predicted on ignore samples
        # are not included in the total calculation.
        if self.ignore_index is not None:
            valid_mask = target.ne(self.ignore_index).unsqueeze(1).expand_as(denominator)
            numerator *= valid_mask
            denominator *= valid_mask

        numerator = torch.sum(numerator, dim=reduce_dims)
        denominator = torch.sum(denominator, dim=reduce_dims)

        if self.generalized_metric:
            weights = 1. / (torch.sum(labels_one_hot, dim=reduce_dims) ** 2)
            # if some classes are not present in the batch, their weights will be inf.
            infs = torch.isinf(weights)
            weights[infs] = 0.0
            numerator *= weights
            denominator *= weights

        # Calculate the loss of the chosen metric
        losses = self._calc_loss(numerator, denominator)
        if self.weight is not None:
            losses *= self.weight
        return apply_reduce(losses, reduction=self.reduction)
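Since forward handles the one-hot conversion, ignore masking, generalized weighting, and reduction, a new structure loss only has to fill in the two abstract hooks. As a hypothetical illustration (not part of the PR), a Tversky-style loss with fixed alpha=0.3 and beta=0.7 could be sketched like this:

from super_gradients.training.losses.structure_loss import AbstarctSegmentationStructureLoss


class TverskyLoss(AbstarctSegmentationStructureLoss):
    """
    Hypothetical example only: a Tversky-style loss, showing that a subclass
    needs to supply just the per-pixel numerator/denominator and the final ratio.
    """
    def _calc_numerator_denominator(self, labels_one_hot, predict):
        tp = labels_one_hot * predict         # true positives
        fp = (1. - labels_one_hot) * predict  # false positives
        fn = labels_one_hot * (1. - predict)  # false negatives
        return tp, tp + 0.3 * fp + 0.7 * fn

    def _calc_loss(self, numerator, denominator):
        return 1. - ((numerator + self.smooth) / (denominator + self.eps + self.smooth))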
@@ -16,6 +16,7 @@ from tests.unit_tests.lr_warmup_test import LRWarmupTest
 from tests.unit_tests.kd_ema_test import KDEMATest
 from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
+from tests.unit_tests.iou_loss_test import IoULossTest
 from tests.unit_tests.update_param_groups_unit_test import UpdateParamGroupsTest
 from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.yolox_unit_test import TestYOLOX
@@ -68,6 +69,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ContextMethodsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(UpdateParamGroupsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaskAttentionLossTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
tests/unit_tests/iou_loss_test.py (new file):

import torch
import unittest

from super_gradients.training.losses.iou_loss import IoULoss, GeneralizedIoULoss, BinaryIoULoss


class IoULossTest(unittest.TestCase):
    def setUp(self) -> None:
        self.img_size = 32
        self.eps = 1e-5
        self.num_classes = 2

    def _get_default_predictions_tensor(self, fill_value: float):
        return torch.empty(3, self.num_classes, self.img_size, self.img_size).fill_(fill_value)

    def _get_default_target_zeroes_tensor(self):
        return torch.zeros((3, self.img_size, self.img_size)).long()

    def _assertion_iou_torch_values(self, expected_value: torch.Tensor, found_value: torch.Tensor, rtol: float = 1e-5):
        self.assertTrue(
            torch.allclose(found_value, expected_value, rtol=rtol),
            msg=f"Unequal iou loss: expected: {expected_value}, found: {found_value}"
        )

    def test_iou(self):
        predictions = self._get_default_predictions_tensor(0.)
        # only label 0 is predicted as positive.
        predictions[:, 0] = 1.
        target = self._get_default_target_zeroes_tensor()
        # half target with label 0, the other half with 1.
        target[:, :self.img_size // 2] = 1

        intersection = torch.tensor([0.5, 0.])
        union = torch.tensor([1., 0.5])
        expected_iou_loss = 1. - (intersection / (union + self.eps))
        expected_iou_loss = expected_iou_loss.mean()

        criterion = IoULoss(smooth=0, eps=self.eps, apply_softmax=False)
        iou_loss = criterion(predictions, target)
        self._assertion_iou_torch_values(expected_iou_loss, iou_loss)

    def test_iou_binary(self):
        # all predictions are 0.6
        predictions = torch.ones((1, 1, self.img_size, self.img_size)) * 0.6
        target = self._get_default_target_zeroes_tensor()
        # half target with label 0, the other half with 1.
        target[:, :self.img_size // 2] = 1

        intersection = torch.tensor([0.6 * 0.5])
        union = torch.tensor([0.6 + 0.5 - 0.6 * 0.5])
        expected_iou_loss = 1. - (intersection / (union + self.eps))
        expected_iou_loss = expected_iou_loss.mean()

        criterion = BinaryIoULoss(smooth=0, eps=self.eps, apply_sigmoid=False)
        iou_loss = criterion(predictions, target)
        self._assertion_iou_torch_values(expected_iou_loss, iou_loss, rtol=1e-3)

    def test_iou_weight_classes(self):
        weight = torch.tensor([0.25, 0.66])
        predictions = self._get_default_predictions_tensor(0.)
        # only label 0 is predicted as positive.
        predictions[:, 0] = 1.
        target = self._get_default_target_zeroes_tensor()
        # half target with label 0, the other half with 1.
        target[:, :self.img_size // 2] = 1

        intersection = torch.tensor([0.5, 0.])
        union = torch.tensor([1., 0.5])
        expected_iou_loss = 1. - (intersection / (union + self.eps))
        expected_iou_loss *= weight
        expected_iou_loss = expected_iou_loss.mean()

        criterion = IoULoss(smooth=0, eps=self.eps, apply_softmax=False, weight=weight)
        iou_loss = criterion(predictions, target)
        self._assertion_iou_torch_values(expected_iou_loss, iou_loss)

    def test_iou_with_ignore(self):
        ignore_index = 2
        predictions = self._get_default_predictions_tensor(0.)
        # only label 0 is predicted as positive.
        predictions[:, 0] = 1.
        target = self._get_default_target_zeroes_tensor()
        # half target with label 0, a quarter with 1 and a quarter with ignore.
        target[:, :self.img_size // 2, :self.img_size // 2] = 1
        target[:, :self.img_size // 2, self.img_size // 2:] = ignore_index

        # ignore samples are excluded from both intersection and union.
        intersection = torch.tensor([0.5, 0.])
        union = torch.tensor([0.75, 0.25])
        expected_iou_loss = 1. - (intersection / (union + self.eps))
        expected_iou_loss = expected_iou_loss.mean()

        criterion = IoULoss(smooth=0, eps=self.eps, apply_softmax=False, ignore_index=ignore_index)
        iou_loss = criterion(predictions, target)
        self._assertion_iou_torch_values(expected_iou_loss, iou_loss)

    def test_generalized_iou(self):
        predictions = self._get_default_predictions_tensor(0.)
        # half of the predictions are class 0, the other half class 1.
        predictions[:, 0, :self.img_size // 2] = 1.
        predictions[:, 1, self.img_size // 2:] = 1.
        # only class 0 in target.
        target = self._get_default_target_zeroes_tensor()

        intersection = torch.tensor([0.5, 0.])
        union = torch.tensor([1., 0.5])
        counts = torch.tensor([target.numel(), 0.])
        weights = 1 / (counts ** 2)
        weights[1] = 0.0  # instead of inf
        eps = 1e-17
        expected_iou_loss = 1. - ((weights * intersection) / (weights * union + eps))
        expected_iou_loss = expected_iou_loss.mean()

        criterion = GeneralizedIoULoss(smooth=0, eps=eps, apply_softmax=False)
        iou_loss = criterion(predictions, target)
        self._assertion_iou_torch_values(expected_iou_loss, iou_loss)


if __name__ == '__main__':
    unittest.main()
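The expected values in test_iou follow directly from the IoU definition; a quick sanity check of the arithmetic (a toy computation, not part of the test):

import torch

# class 0: predicted everywhere (area 1.0), true on half the image (area 0.5) -> IoU = 0.5 / 1.0
# class 1: never predicted (area 0.0), true on half the image (area 0.5)      -> IoU = 0.0 / 0.5
intersection = torch.tensor([0.5, 0.])
union = torch.tensor([1., 0.5])
print((1. - intersection / union).mean())  # tensor(0.7500)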