Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_utils_test.py 5.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
  1. import os
  2. import tempfile
  3. import unittest
  4. import numpy as np
  5. import torch.cuda
  6. from super_gradients.common.object_names import Models
  7. from super_gradients.training import utils as core_utils, models
  8. from super_gradients.training.dataloaders.dataloaders import coco2017_val
  9. from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
  10. from super_gradients.training.datasets.pose_estimation_datasets.yolo_nas_pose_collate_fn import flat_collate_tensors_with_batch_index
  11. from super_gradients.training.metrics import DetectionMetrics, DetectionMetrics_050
  12. from super_gradients.training.models.detection_models.yolo_base import YoloXPostPredictionCallback
  13. from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback, xyxy2cxcywh
  14. from tests.core_test_utils import is_data_available
  15. class TestDetectionUtils(unittest.TestCase):
  16. def setUp(self):
  17. self.device = "cuda" if torch.cuda.is_available() else "cpu"
  18. self.model = models.get(Models.YOLOX_N, pretrained_weights="coco").to(self.device)
  19. self.model.eval()
  20. def test_detection_metric_with_calc_best_score_thresholds(self):
  21. class DummyCallback(DetectionPostPredictionCallback):
  22. def forward(self, p, device=None):
  23. return p
  24. class_names = ["A", "B", "C"]
  25. num_classes = len(class_names)
  26. metric = DetectionMetrics(
  27. num_cls=num_classes,
  28. post_prediction_callback=DummyCallback(),
  29. normalize_targets=True,
  30. calc_best_score_thresholds=True,
  31. include_classwise_ap=True,
  32. class_names=class_names,
  33. )
  34. # x1, y1, x2, y2, confidence, class_label
  35. num_predictions = 100
  36. num_targets = 64
  37. preds = torch.cat(
  38. [
  39. torch.randint(0, 100, (num_predictions, 2)), # [x1,y1]
  40. torch.randint(100, 200, (num_predictions, 2)), # [x2,y2]
  41. torch.randn((num_predictions, 1)).sigmoid(),
  42. torch.randint(0, num_classes, (num_predictions, 1)), # [x2,y2]
  43. ],
  44. dim=-1,
  45. ).float()
  46. targets = torch.cat(
  47. [
  48. torch.randint(0, num_classes, (num_targets, 1)), # [x2,y2]
  49. torch.randint(0, 100, (num_targets, 2)), # [x1,y1]
  50. torch.randint(100, 200, (num_targets, 2)), # [x2,y2]
  51. ],
  52. dim=-1,
  53. ).float()
  54. targets[:, 1:] = xyxy2cxcywh(targets[:, 1:])
  55. targets_flat = flat_collate_tensors_with_batch_index([targets])
  56. metric(preds=[preds], target=targets_flat, device="cpu", inputs=torch.zeros((1, 3, 640, 640)))
  57. metric_values = metric.compute()
  58. self.assertTrue("Best_score_threshold" in metric_values)
  59. for metric_value_name in metric.best_threshold_per_class_names:
  60. self.assertTrue(metric_value_name in metric_values)
  61. @unittest.skipIf(not is_data_available(), "run only when /data is available")
  62. def test_visualization(self):
  63. with tempfile.TemporaryDirectory() as tmpdirname:
  64. valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0})
  65. post_prediction_callback = YoloXPostPredictionCallback()
  66. # Simulate one iteration of validation subset
  67. batch_i, batch = 0, next(iter(valid_loader))
  68. imgs, targets = batch[:2]
  69. imgs = core_utils.tensor_container_to_device(imgs, self.device)
  70. targets = core_utils.tensor_container_to_device(targets, self.device)
  71. output = self.model(imgs)
  72. output = post_prediction_callback(output)
  73. # Visualize the batch
  74. DetectionVisualization.visualize_batch(imgs, output, targets, batch_i, COCO_DETECTION_CLASSES_LIST, tmpdirname)
  75. # Assert images ware created and delete them
  76. img_name = "{}/{}_{}.jpg"
  77. for i in range(4):
  78. img_path = img_name.format(tmpdirname, batch_i, i)
  79. self.assertTrue(os.path.exists(img_path))
  80. os.remove(img_path)
  81. @unittest.skipIf(not is_data_available(), "run only when /data is available")
  82. def test_detection_metrics(self):
  83. valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0})
  84. metrics = [
  85. DetectionMetrics(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True),
  86. DetectionMetrics_050(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(), normalize_targets=True),
  87. DetectionMetrics(num_cls=80, post_prediction_callback=YoloXPostPredictionCallback(conf=2), normalize_targets=True),
  88. ]
  89. ref_values = [
  90. np.array([0.24701539, 0.40294355, 0.34654024, 0.28485271]),
  91. np.array([0.34666198, 0.56854934, 0.5079478, 0.40414381]),
  92. np.array([0.0, 0.0, 0.0, 0.0]),
  93. ]
  94. for met, ref_val in zip(metrics, ref_values):
  95. met.reset()
  96. for i, (imgs, targets, extras) in enumerate(valid_loader):
  97. if i > 5:
  98. break
  99. imgs = core_utils.tensor_container_to_device(imgs, self.device)
  100. targets = core_utils.tensor_container_to_device(targets, self.device)
  101. output = self.model(imgs)
  102. met.update(output, targets, device=self.device, inputs=imgs)
  103. results = met.compute()
  104. values = np.array([x.item() for x in list(results.values())])
  105. for expected, actual in zip(ref_val, values):
  106. self.assertAlmostEqual(expected, actual, delta=5e-3)
  107. if __name__ == "__main__":
  108. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...