Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_statistics_test.py 5.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. import unittest
  2. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
  3. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  4. from super_gradients.training import Trainer
  5. from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
  6. from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
  7. DetectionTargetsFormat
  8. class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
  9. def test_dataset_statistics_tensorboard_logger(self):
  10. """
  11. ** IMPORTANT NOTE **
  12. This test is not the usual fail/pass test - it is a visual test. The success criteria is your own visual check
  13. After launching the test, follow the log the see where was the tensorboard opened. open the tensorboard in your
  14. browser and make sure the text and plots in the tensorboard are as expected.
  15. """
  16. # Create dataset
  17. dataset = CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
  18. "train_subdir": "images/train2017",
  19. "val_subdir": "images/val2017",
  20. "train_json_file": "instances_train2017.json",
  21. "val_json_file": "instances_val2017.json",
  22. "batch_size": 16,
  23. "val_batch_size": 128,
  24. "val_image_size": 640,
  25. "train_image_size": 640,
  26. "hgain": 5,
  27. "sgain": 30,
  28. "vgain": 30,
  29. "mixup_prob": 1.0,
  30. "degrees": 10.,
  31. "shear": 2.0,
  32. "flip_prob": 0.5,
  33. "hsv_prob": 1.0,
  34. "mosaic_scale": [0.1, 2],
  35. "mixup_scale": [0.5, 1.5],
  36. "mosaic_prob": 1.,
  37. "translate": 0.1,
  38. "val_collate_fn": CrowdDetectionCollateFN(),
  39. "train_collate_fn": DetectionCollateFN(),
  40. "cache_dir_path": None,
  41. "cache_train_images": False,
  42. "cache_val_images": False,
  43. "targets_format": DetectionTargetsFormat.LABEL_CXCYWH,
  44. "with_crowd": True,
  45. "filter_box_candidates": False,
  46. "wh_thr": 0,
  47. "ar_thr": 0,
  48. "area_thr": 0
  49. })
  50. trainer = Trainer('dataset_statistics_visual_test',
  51. model_checkpoints_location='local',
  52. post_prediction_callback=YoloPostPredictionCallback())
  53. trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
  54. trainer.build_model("yolox_s")
  55. training_params = {"max_epochs": 1, # we dont really need the actual training to run
  56. "lr_mode": "cosine",
  57. "initial_lr": 0.01,
  58. "loss": "yolox_loss",
  59. "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},
  60. "dataset_statistics": True,
  61. "launch_tensorboard": True,
  62. "valid_metrics_list": [
  63. DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
  64. normalize_targets=True,
  65. num_cls=80)],
  66. "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
  67. "metric_to_watch": "mAP@0.50:0.95",
  68. }
  69. trainer.train(training_params=training_params)
  70. if __name__ == '__main__':
  71. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...