Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_sliding_window_wrapper_test.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  1. import unittest
  2. from pathlib import Path
  3. from super_gradients.training import models
  4. from super_gradients.training.dataloaders import coco2017_val_yolo_nas
  5. from super_gradients.training import Trainer
  6. from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper
  7. from super_gradients.training.metrics import DetectionMetrics
  class SlidingWindowWrapperTest(unittest.TestCase):
      """Regression test for ``SlidingWindowInferenceDetectionWrapper`` around YOLO-NAS-S.

      A COCO-pretrained ``yolo_nas_s`` is wrapped in the sliding-window detection wrapper,
      evaluated with ``Trainer.test`` on the ``data/tinycoco`` directory that sits two
      levels above this file, and the metric's ``map_str`` entry is compared to a
      recorded reference value.
      """

      def setUp(self):
          """Store the path of the tiny COCO dataset used by the evaluation."""
          tests_root = Path(__file__).parent.parent
          self.mini_coco_data_dir = str(tests_root / "data" / "tinycoco")

      def test_yolo_nas_s_coco_with_sliding_window(self):
          """Evaluate a sliding-window-wrapped YOLO-NAS-S and check mAP against the reference."""
          trainer = Trainer("test_yolo_nas_s_coco_with_sliding_window")

          # NOTE(review): pretrained_weights="coco" presumably fetches weights on first use,
          # so this test likely needs network access (or a warm cache) — confirm.
          base_detector = models.get("yolo_nas_s", num_classes=80, pretrained_weights="coco")

          # 320-px tiles advanced by 160 px; presumably tile_step is the stride, giving
          # 50% overlap between neighbouring tiles — confirm against the wrapper.
          windowed_detector = SlidingWindowInferenceDetectionWrapper(
              tile_size=320,
              tile_step=160,
              model=base_detector,
              tile_nms_iou=0.65,
              tile_nms_conf=0.03,
          )

          val_loader = coco2017_val_yolo_nas(dataset_params={"data_dir": self.mini_coco_data_dir})

          # No post_prediction_callback is supplied; presumably the wrapper's own
          # tile NMS (tile_nms_* above) already yields final detections — confirm.
          map_metric = DetectionMetrics(
              normalize_targets=True,
              post_prediction_callback=None,
              num_cls=80,
          )

          results = trainer.test(
              model=windowed_detector,
              test_loader=val_loader,
              test_metrics_list=[map_metric],
          )

          expected_map = 0.342
          self.assertAlmostEqual(results[map_metric.map_str], expected_map, delta=0.001)
  if __name__ == "__main__":
      # Running this module directly executes its test cases through unittest's CLI runner.
      unittest.main(module="__main__")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...