Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_v3_coco.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
  1. # YoLo V3 (Big) Detection training on CoCo Dataset:
  2. # mAP@0.5 ~ 56.8
  3. # The code is optimized for running with a Mini-Batch of 64 examples... So depending on the amount of GPUs,
  4. # you should change the "batch_accumulate" param in the training_params dict to be batch_size * gpu_num * batch_accumulate = 64.
  5. # IMPORTANT: The final step estimates the IOU threshold differently, in order to boost the mAP score - so please run it
  6. # all the way to the final epoch
  7. from super_gradients.training import SgModel, MultiGPUMode
  8. from super_gradients.training.datasets import CoCoDetectionDatasetInterface
  9. from super_gradients.training.utils.detection_utils import base_detection_collate_fn
  10. from super_gradients.training.datasets.datasets_utils import ComposedCollateFunction, MultiScaleCollateFunction
  11. from super_gradients.training.utils.detection_utils import YoloV3NonMaxSuppression
  12. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  13. collate_fn_holder = ComposedCollateFunction([base_detection_collate_fn, MultiScaleCollateFunction(320)])
  14. yolo_v3_dataset_params = {"batch_size": 16,
  15. "test_batch_size": 16,
  16. "dataset_dir": "/data/coco/",
  17. "s3_link": None,
  18. "image_size": 320,
  19. "test_collate_fn": base_detection_collate_fn,
  20. "train_collate_fn": collate_fn_holder,
  21. }
  22. yolo_v3_arch_params = {"image_size": 320, "iou_t": 0.225, "multi_gpu_mode": "distributed_data_parallel"}
  23. post_prediction_callback = YoloV3NonMaxSuppression()
  24. model = SgModel('yolo_v3_spp_example', model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF,
  25. post_prediction_callback=post_prediction_callback)
  26. coco_datasaet_interface = CoCoDetectionDatasetInterface(dataset_params=yolo_v3_dataset_params)
  27. model.connect_dataset_interface(coco_datasaet_interface, data_loader_num_workers=8)
  28. model.build_model('yolo_v3', arch_params=yolo_v3_arch_params, load_checkpoint=False)
  29. yolo_v3_training_params = {"max_epochs": 273, 'lr_mode': "step", "lr_updates": [219, 246], "lr_decay_factor": 0.1,
  30. "initial_lr": 0.00579, "batch_accumulate": 4,
  31. "loss": "detection_loss", "criterion_params": {"model": model}, "optimizer": "SGD",
  32. "optimizer_params": {"momentum": 0.937, "weight_decay": 0.000484, "nesterov": True},
  33. "mixed_precision": True,
  34. "train_metrics_list": [],
  35. "valid_metrics_list": [DetectionMetrics(post_prediction_callback=post_prediction_callback,
  36. num_cls=len(
  37. coco_datasaet_interface.coco_classes))],
  38. "loss_logging_items_names": ["GIoU", "obj", "cls", "Loss"],
  39. "metric_to_watch": "mAP@0.50:0.95",
  40. "greater_metric_to_watch_is_better": True}
  41. model.train(training_params=yolo_v3_training_params)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...