#780 New Transforms: DetectionPadToSize, DetectionImagePermute

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-746-deci-yolo-integrate-dataset
@@ -57,6 +57,8 @@ class Transforms:
     DetectionRandomRotate90 = "DetectionRandomRotate90"
     DetectionHorizontalFlip = "DetectionHorizontalFlip"
     DetectionRescale = "DetectionRescale"
+    DetectionPadToSize = "DetectionPadToSize"
+    DetectionImagePermute = "DetectionImagePermute"
     DetectionPaddedRescale = "DetectionPaddedRescale"
     DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
     DetectionNormalize = "DetectionNormalize"
The new recipe file (presumably coco_detection_deci_yolo_dataset_params.yaml, matching the config_name used by the dataloader factories below):

train_dataset_params:
  data_dir: /data/coco                  # root path to coco data
  subdir: images/train2017              # sub directory of data_dir containing the train data
  json_file: instances_train2017.json   # path to coco train json file, data_dir/annotations/train_json_file
  input_dim: [640, 640]
  cache_dir:
  cache: False
  transforms:
    - DetectionRandomAffine:
        degrees: 0                      # rotation degrees, randomly sampled from [-degrees, degrees]
        translate: 0.25                 # image translation fraction
        scales: [0.5, 1.5]              # random rescale range (keeps size by padding/cropping) after mosaic transform
        shear: 0.0                      # shear degrees, randomly sampled from [-degrees, degrees]
        target_size:
        filter_box_candidates: True     # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio
        wh_thr: 2                       # edge size threshold when filter_box_candidates = True (pixels)
        area_thr: 0.1                   # threshold for area ratio between original and transformed image when filter_box_candidates = True
        ar_thr: 20                      # aspect ratio threshold when filter_box_candidates = True
    - DetectionRGB2BGR:
        prob: 0.5
    - DetectionHSV:
        prob: 0.5                       # probability to apply HSV transform
        hgain: 18                       # HSV transform hue gain (randomly sampled from [-hgain, hgain])
        sgain: 30                       # HSV transform saturation gain (randomly sampled from [-sgain, sgain])
        vgain: 30                       # HSV transform value gain (randomly sampled from [-vgain, vgain])
    - DetectionHorizontalFlip:
        prob: 0.5                       # probability to apply horizontal flip
    - DetectionMixup:
        input_dim:
        mixup_scale: [0.5, 1.5]         # random rescale range for the additional sample in mixup
        prob: 0.5                       # probability to apply per-sample mixup
        flip_prob: 0.5                  # probability to apply horizontal flip
    - DetectionStandardizeImage:
        max_value: 255.
    - DetectionPaddedRescale:
        input_dim: [640, 640]
        max_targets: 120
        pad_value: 114
    - DetectionTargetsFormatTransform:
        max_targets: 256
        output_format: LABEL_NORMALIZED_CXCYWH
  tight_box_rotation: False
  class_inclusion_list:
  max_num_samples:
  with_crowd: False

train_dataloader_params:
  batch_size: 25
  num_workers: 8
  shuffle: True
  drop_last: True
  pin_memory: True
  collate_fn:
    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN

val_dataset_params:
  data_dir: /data/coco                  # root path to coco data
  subdir: images/val2017                # sub directory of data_dir containing the val data
  json_file: instances_val2017.json     # path to coco val json file, data_dir/annotations/val_json_file
  input_dim: [636, 636]
  cache_dir:
  cache: False
  transforms:
    - DetectionRGB2BGR:
        prob: 1
    - DetectionPadToSize:
        output_size: [640, 640]
        pad_value: 114
    - DetectionStandardizeImage:
        max_value: 255.
    - DetectionImagePermute:
    - DetectionTargetsFormatTransform:
        max_targets: 50
        input_dim: [640, 640]
        output_format: LABEL_NORMALIZED_CXCYWH
  tight_box_rotation: False
  class_inclusion_list:
  max_num_samples:
  with_crowd: True

val_dataloader_params:
  batch_size: 25
  num_workers: 8
  drop_last: False
  pin_memory: True
  collate_fn:
    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN

_convert_: all
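For context, a minimal usage sketch of this recipe (assuming it ships under the config name coco_detection_deci_yolo_dataset_params, as the new dataloader factories below indicate; the override values are illustrative):

    from super_gradients.training import dataloaders

    # Recipe values can be overridden per call; unspecified keys fall back to the YAML.
    train_loader = dataloaders.coco2017_train_deci_yolo(
        dataset_params={"data_dir": "/data/coco"},
        dataloader_params={"batch_size": 16},
    )

    # DetectionCollateFN yields (images, targets) batches; after DetectionPaddedRescale
    # images should come out as [batch, 3, 640, 640].
    images, targets = next(iter(train_loader))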
@@ -9,6 +9,8 @@ from .dataloaders import (
     coco2017_val_ppyoloe,
     coco2017_pose_train,
     coco2017_pose_val,
+    coco2017_train_deci_yolo,
+    coco2017_val_deci_yolo,
     imagenet_train,
     imagenet_val,
     imagenet_efficientnet_train,
@@ -66,6 +68,8 @@ __all__ = [
     "coco2017_val_ppyoloe",
     "coco2017_val_ppyoloe",
     "coco2017_pose_train",
     "coco2017_pose_train",
     "coco2017_pose_val",
     "coco2017_pose_val",
+    "coco2017_train_deci_yolo",
+    "coco2017_val_deci_yolo",
     "imagenet_train",
     "imagenet_train",
     "imagenet_val",
     "imagenet_val",
     "imagenet_efficientnet_train",
     "imagenet_efficientnet_train",
@@ -151,6 +151,26 @@ def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     )


+def coco2017_train_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
+    return get_data_loader(
+        config_name="coco_detection_deci_yolo_dataset_params",
+        dataset_cls=COCODetectionDataset,
+        train=True,
+        dataset_params=dataset_params,
+        dataloader_params=dataloader_params,
+    )
+
+
+def coco2017_val_deci_yolo(dataset_params: Dict = None, dataloader_params: Dict = None):
+    return get_data_loader(
+        config_name="coco_detection_deci_yolo_dataset_params",
+        dataset_cls=COCODetectionDataset,
+        train=False,
+        dataset_params=dataset_params,
+        dataloader_params=dataloader_params,
+    )
+
+
 def coco2017_train_ppyoloe(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(
         config_name="coco_detection_ppyoloe_dataset_params",
@@ -660,6 +680,8 @@ ALL_DATALOADERS = {
     "coco2017_val_ssd_lite_mobilenet_v2": coco2017_val_ssd_lite_mobilenet_v2,
     "coco2017_val_ssd_lite_mobilenet_v2": coco2017_val_ssd_lite_mobilenet_v2,
     "coco2017_pose_train": coco2017_pose_train,
     "coco2017_pose_train": coco2017_pose_train,
     "coco2017_pose_val": coco2017_pose_val,
     "coco2017_pose_val": coco2017_pose_val,
+    "coco2017_train_deci_yolo": coco2017_train_deci_yolo,
+    "coco2017_val_deci_yolo": coco2017_val_deci_yolo,
     "imagenet_train": imagenet_train,
     "imagenet_train": imagenet_train,
     "imagenet_val": imagenet_val,
     "imagenet_val": imagenet_val,
     "imagenet_efficientnet_train": imagenet_efficientnet_train,
     "imagenet_efficientnet_train": imagenet_efficientnet_train,
@@ -30,6 +30,8 @@ from super_gradients.training.transforms.transforms import (
     DetectionTargetsFormatTransform,
     DetectionNormalize,
     Standardize,
+    DetectionPadToSize,
+    DetectionImagePermute,
 )
 from torchvision.transforms import (
     Compose,
@@ -99,6 +101,8 @@ TRANSFORMS = {
     Transforms.DetectionRandomRotate90: DetectionRandomRotate90,
     Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
     Transforms.DetectionRescale: DetectionRescale,
+    Transforms.DetectionImagePermute: DetectionImagePermute,
+    Transforms.DetectionPadToSize: DetectionPadToSize,
     Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
     Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
     Transforms.DetectionNormalize: DetectionNormalize,
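With these entries, the new transforms resolve from their string names just like the existing ones, which is what lets the YAML recipe above reference them directly. A minimal sketch (the module paths are assumptions based on the hunk context, not part of this diff):

    from super_gradients.common.object_names import Transforms
    from super_gradients.training.transforms.all_transforms import TRANSFORMS

    # Name -> class lookup, then construction with the recipe's arguments.
    pad = TRANSFORMS[Transforms.DetectionPadToSize](output_size=(640, 640), pad_value=114)
    permute = TRANSFORMS[Transforms.DetectionImagePermute]()  # default dims=(2, 0, 1): HWC -> CHW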
@@ -418,7 +418,7 @@ class DetectionStandardize(DetectionTransform):
         self.max_value = max_value

     def __call__(self, sample: dict) -> dict:
-        sample["image"] = sample["image"] / self.max_value
+        sample["image"] = (sample["image"] / self.max_value).astype(np.float32)
         return sample


@@ -697,6 +697,85 @@ class DetectionMixup(DetectionTransform):
         return sample


+class DetectionImagePermute(DetectionTransform):
+    """
+    Permute image dims. Useful for converting an image from HWC to CHW format.
+    """
+
+    def __init__(self, dims=(2, 0, 1)):
+        """
+        :param dims: New order of dims; the default (2, 0, 1) converts HWC to CHW.
+        """
+        super().__init__()
+        self.dims = tuple(dims)
+
+    def __call__(self, sample: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        sample["image"] = np.ascontiguousarray(sample["image"].transpose(*self.dims))
+        return sample
+
+
+class DetectionPadToSize(DetectionTransform):
+    """
+    Preprocessing transform to pad image and bboxes to `output_size` shape (rows, cols).
+    The padding is centered, so the input image and its bboxes end up in the center of the produced image.
+
+    Note: this transform assumes the input image dimensions are less than or equal to `output_size`.
+    """
+
+    def __init__(self, output_size: Tuple[int, int], pad_value: int):
+        """
+        Constructor for DetectionPadToSize transform.
+
+        :param output_size: Output image size (rows, cols)
+        :param pad_value: Padding value for image
+        """
+        super().__init__()
+        self.output_size = output_size
+        self.pad_value = pad_value
+
+    def __call__(self, sample: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
+        img, shift_w, shift_h = self._apply_to_image(img, final_shape=self.output_size, pad_value=self.pad_value)
+        sample["image"] = img
+        sample["target"] = self._apply_to_bboxes(targets, shift_w, shift_h)
+        if crowd_targets is not None:
+            sample["crowd_target"] = self._apply_to_bboxes(crowd_targets, shift_w, shift_h)
+        return sample
+
+    def _apply_to_bboxes(self, targets: np.ndarray, shift_w: float, shift_h: float) -> np.ndarray:
+        """Translate bboxes with respect to padding values.
+
+        :param targets:  Bboxes to transform, of shape (N, 5),
+                         in format [x1, y1, x2, y2, class_id, ...].
+        :param shift_w:  Horizontal shift in pixels.
+        :param shift_h:  Vertical shift in pixels.
+        :return:         Translated bboxes of the same shape and format.
+        """
+        targets = targets.copy() if len(targets) > 0 else np.zeros((0, 5), dtype=np.float32)
+        boxes, labels = targets[:, :4], targets[:, 4:]
+        boxes[:, [0, 2]] += shift_w
+        boxes[:, [1, 3]] += shift_h
+        return np.concatenate((boxes, labels), 1)
+
+    def _apply_to_image(self, image: np.ndarray, final_shape: Tuple[int, int], pad_value: int):
+        """
+        Pad image to final_shape with the original content centered.
+        :param image:       Input image of shape (rows, cols, channels).
+        :param final_shape: Output image size (rows, cols).
+        :param pad_value:   Padding value for image.
+        :return:            Tuple of (padded image, shift_w, shift_h).
+        """
+        pad_h, pad_w = final_shape[0] - image.shape[0], final_shape[1] - image.shape[1]
+        shift_h, shift_w = pad_h // 2, pad_w // 2
+        pad_h = (shift_h, pad_h - shift_h)
+        pad_w = (shift_w, pad_w - shift_w)
+
+        image = np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value)
+        return image, shift_w, shift_h
+
+
 class DetectionPaddedRescale(DetectionTransform):
     """
     Preprocessing transform to be applied last of all transforms for validation.
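To make the center-padding arithmetic concrete, a short self-contained example (import path taken from the test file below; the numbers are illustrative):

    import numpy as np

    from super_gradients.training.transforms.transforms import DetectionImagePermute, DetectionPadToSize

    # A 512x480 image padded to 640x640: pad_h = 128, pad_w = 160,
    # so boxes shift by shift_h = 64 and shift_w = 80.
    sample = {
        "image": np.full((512, 480, 3), 127, dtype=np.uint8),
        "target": np.array([[10, 20, 110, 120, 0]], dtype=np.float32),  # x1, y1, x2, y2, class_id
    }
    sample = DetectionPadToSize(output_size=(640, 640), pad_value=114)(sample)
    print(sample["image"].shape)  # (640, 640, 3)
    print(sample["target"])       # [[ 90.  84. 190. 184.   0.]]

    # DetectionImagePermute then converts HWC -> CHW for the model:
    sample = DetectionImagePermute()(sample)
    print(sample["image"].shape)  # (3, 640, 640)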
@@ -9,6 +9,7 @@ from super_gradients.training.transforms.keypoint_transforms import (
     KeypointsPadIfNeeded,
     KeypointsLongestMaxSize,
 )
+from super_gradients.training.transforms.transforms import DetectionImagePermute, DetectionPadToSize


 class TestTransforms(unittest.TestCase):
@@ -83,6 +84,32 @@ class TestTransforms(unittest.TestCase):
         self.assertTrue((aug_joints[..., 0] < aug_image.shape[1]).all())
         self.assertTrue((aug_joints[..., 1] < aug_image.shape[0]).all())

+    def test_detection_image_permute(self):
+        aug = DetectionImagePermute(dims=(2, 1, 0))
+        image = np.random.rand(640, 480, 3)
+        sample = {"image": image}
+
+        output = aug(sample)
+        self.assertEqual(output["image"].shape, (3, 480, 640))
+
+    def test_detection_pad_to_size(self):
+        aug = DetectionPadToSize(output_size=(640, 640), pad_value=114)
+        image = np.ones((512, 480, 3))
+
+        # Boxes in format (x1, y1, x2, y2, class_id)
+        boxes = np.array([[0, 0, 100, 100, 0], [100, 100, 200, 200, 1]])
+
+        sample = {"image": image, "target": boxes}
+        output = aug(sample)
+
+        shift_x = (640 - 480) // 2
+        shift_y = (640 - 512) // 2
+        expected_boxes = np.array(
+            [[0 + shift_x, 0 + shift_y, 100 + shift_x, 100 + shift_y, 0], [100 + shift_x, 100 + shift_y, 200 + shift_x, 200 + shift_y, 1]]
+        )
+        self.assertEqual(output["image"].shape, (640, 640, 3))
+        np.testing.assert_array_equal(output["target"], expected_boxes)
+

 if __name__ == "__main__":
     unittest.main()