Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#641 Added new detection transforms that are used in PPYoloE

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-344-PPYolo-Detection-Transforms
@@ -48,7 +48,10 @@ class Transforms:
     DetectionRandomAffine = "DetectionRandomAffine"
     DetectionRandomAffine = "DetectionRandomAffine"
     DetectionMixup = "DetectionMixup"
     DetectionMixup = "DetectionMixup"
     DetectionHSV = "DetectionHSV"
     DetectionHSV = "DetectionHSV"
+    DetectionRGB2BGR = "DetectionRGB2BGR"
+    DetectionRandomRotate90 = "DetectionRandomRotate90"
     DetectionHorizontalFlip = "DetectionHorizontalFlip"
     DetectionHorizontalFlip = "DetectionHorizontalFlip"
+    DetectionRescale = "DetectionRescale"
     DetectionPaddedRescale = "DetectionPaddedRescale"
     DetectionPaddedRescale = "DetectionPaddedRescale"
     DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
     DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
     RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
     RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
Discard
@@ -21,7 +21,10 @@ from super_gradients.training.transforms.transforms import (
     DetectionRandomAffine,
     DetectionRandomAffine,
     DetectionMixup,
     DetectionMixup,
     DetectionHSV,
     DetectionHSV,
+    DetectionRGB2BGR,
+    DetectionRandomRotate90,
     DetectionHorizontalFlip,
     DetectionHorizontalFlip,
+    DetectionRescale,
     DetectionPaddedRescale,
     DetectionPaddedRescale,
     DetectionTargetsFormatTransform,
     DetectionTargetsFormatTransform,
     Standardize,
     Standardize,
@@ -79,7 +82,10 @@ TRANSFORMS = {
     Transforms.DetectionRandomAffine: DetectionRandomAffine,
     Transforms.DetectionRandomAffine: DetectionRandomAffine,
     Transforms.DetectionMixup: DetectionMixup,
     Transforms.DetectionMixup: DetectionMixup,
     Transforms.DetectionHSV: DetectionHSV,
     Transforms.DetectionHSV: DetectionHSV,
+    Transforms.DetectionRGB2BGR: DetectionRGB2BGR,
+    Transforms.DetectionRandomRotate90: DetectionRandomRotate90,
     Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
     Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
+    Transforms.DetectionRescale: DetectionRescale,
     Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
     Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
     Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
     Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
     Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
     Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
Discard
@@ -718,6 +718,126 @@ class DetectionHorizontalFlip(DetectionTransform):
         return sample
         return sample
 
 
 
 
+class DetectionRescale(DetectionTransform):
+    """
+    Resize image and bounding boxes to given image dimensions without preserving aspect ratio
+    Attributes:
+        input_dim: (tuple) (rows, cols)
+        swap: image axis's to be rearranged.
+    """
+
+    def __init__(self, input_dim: Tuple[int, int], swap=(2, 0, 1)):
+        super().__init__()
+        self.swap = swap
+        self.input_dim = input_dim
+
+    def __call__(self, sample: Dict[str, np.array]):
+        img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
+
+        img_resized, scale_factors = self._rescale_image(img)
+
+        sample["image"] = img_resized.transpose(self.swap).astype(np.float32, copy=True)
+        sample["target"] = self._rescale_target(targets, scale_factors)
+        if crowd_targets is not None:
+            sample["crowd_target"] = self._rescale_target(crowd_targets, scale_factors)
+        return sample
+
+    def _rescale_image(self, image):
+        sy, sx = self.input_dim[0] / image.shape[0], self.input_dim[1] / image.shape[1]
+        resized_img = cv2.resize(
+            image,
+            dsize=(int(self.input_dim[1]), int(self.input_dim[0])),
+            interpolation=cv2.INTER_LINEAR,
+        )
+        scale_factors = sy, sx
+        return resized_img, scale_factors
+
+    def _rescale_target(self, targets: np.array, scale_factors: Tuple[float, float]) -> np.array:
+        """SegRescale the target according to a coefficient used to rescale the image.
+        This is done to have images and targets at the same scale.
+        :param targets:  Target XYXY bboxes to rescale, shape (num_boxes, 5)
+        :param r:        SegRescale coefficient that was applied to the image
+        :return:         Rescaled targets, shape (num_boxes, 5)
+        """
+        sy, sx = scale_factors
+        targets = targets.astype(np.float32, copy=True) if len(targets) > 0 else np.zeros((0, 5), dtype=np.float32)
+        targets[:, 0:4] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
+        return targets
+
+
+class DetectionRandomRotate90(DetectionTransform):
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def __call__(self, sample: dict) -> dict:
+        if random.random() < self.prob:
+            k = random.randrange(0, 4)
+
+            img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
+
+            sample["image"] = np.ascontiguousarray(np.rot90(img, k))
+            sample["target"] = self.rotate_bboxes(targets, k, img.shape[:2])
+            if crowd_targets is not None:
+                sample["crowd_target"] = self.rotate_bboxes(crowd_targets, k, img.shape[:2])
+
+        return sample
+
+    @classmethod
+    def rotate_bboxes(cls, targets, k: int, image_shape):
+        if k == 0:
+            return targets
+        rows, cols = image_shape
+        targets = targets.copy()
+        targets[:, 0:4] = cls.xyxy_bbox_rot90(targets[:, 0:4], k, rows, cols)
+        return targets
+
+    @classmethod
+    def xyxy_bbox_rot90(cls, bboxes, factor: int, rows: int, cols: int):
+        """Rotates a bounding box by 90 degrees CCW (see np.rot90)
+        Args:
+            bbox: A bounding box tuple (x_min, y_min, x_max, y_max).
+            factor: Number of CCW rotations. Must be in set {0, 1, 2, 3} See np.rot90.
+            rows: Image rows.
+            cols: Image cols.
+        Returns:
+            tuple: A bounding box tuple (x_min, y_min, x_max, y_max).
+        """
+        x_min, y_min, x_max, y_max = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
+
+        if factor == 0:
+            bbox = x_min, y_min, x_max, y_max
+        elif factor == 1:
+            bbox = y_min, cols - x_max, y_max, cols - x_min
+        elif factor == 2:
+            bbox = cols - x_max, rows - y_max, cols - x_min, rows - y_min
+        elif factor == 3:
+            bbox = rows - y_max, x_min, rows - y_min, x_max
+        else:
+            raise ValueError("Parameter n must be in set {0, 1, 2, 3}")
+        return np.stack(bbox, axis=1)
+
+
+class DetectionRGB2BGR(DetectionTransform):
+    """
+    Detection change Red & Blue channel of the image
+    Attributes:
+        prob: (float) probability to apply the transform.
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def __call__(self, sample: dict) -> dict:
+        if sample["image"].shape[2] != 3:
+            raise ValueError("DetectionRGB2BGR expects image to have 3 channels, got: " + str(sample["image"].shape[2]))
+
+        if random.random() < self.prob:
+            sample["image"] = sample["image"][..., ::-1]
+        return sample
+
+
 class DetectionHSV(DetectionTransform):
 class DetectionHSV(DetectionTransform):
     """
     """
     Detection HSV transform.
     Detection HSV transform.
Discard