Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_v3_subclass.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. from super_gradients.training import SgModel, MultiGPUMode
  2. from super_gradients.training.datasets import CoCoDetectionDatasetInterface
  3. from super_gradients.training.utils.detection_utils import base_detection_collate_fn
  4. from super_gradients.training.datasets.datasets_utils import ComposedCollateFunction, MultiScaleCollateFunction
  5. from super_gradients.training.utils.detection_utils import YoloV3NonMaxSuppression
  6. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  7. collate_fn_holder = ComposedCollateFunction([base_detection_collate_fn, MultiScaleCollateFunction(320)])
  8. yolo_v3_dataset_params = {"batch_size": 16,
  9. "test_batch_size": 16,
  10. "dataset_dir": "/data/coco/",
  11. "s3_link": None,
  12. "image_size": 320,
  13. "test_collate_fn": base_detection_collate_fn,
  14. "train_collate_fn": collate_fn_holder,
  15. "class_inclusion_list": ['person']
  16. }
  17. yolo_v3_arch_params = {"image_size": 320, "iou_t": 0.225, "multi_gpu_mode": "distributed_data_parallel"}
  18. post_prediction_callback = YoloV3NonMaxSuppression()
  19. model = SgModel('yolo_v3_spp_example', model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF,
  20. post_prediction_callback=post_prediction_callback)
  21. coco_datasaet_interface = CoCoDetectionDatasetInterface(dataset_params=yolo_v3_dataset_params, cache_labels=True)
  22. model.connect_dataset_interface(coco_datasaet_interface, data_loader_num_workers=8)
  23. model.build_model('yolo_v3', arch_params=yolo_v3_arch_params, load_checkpoint=False)
  24. yolo_v3_training_params = {"max_epochs": 273, 'lr_mode': "step", "lr_updates": [219, 246], "lr_decay_factor": 0.1,
  25. "initial_lr": 0.00579, "batch_accumulate": 4,
  26. "loss": "detection_loss", "criterion_params": {"model": model}, "optimizer": "SGD",
  27. "optimizer_params": {"momentum": 0.937, "weight_decay": 0.000484, "nesterov": True},
  28. "mixed_precision": True,
  29. "train_metrics_list": [],
  30. "valid_metrics_list": [DetectionMetrics(post_prediction_callback=post_prediction_callback,
  31. num_cls=len(
  32. coco_datasaet_interface.coco_classes))],
  33. "loss_logging_items_names": ["GIoU", "obj", "cls", "Loss"],
  34. "metric_to_watch": "mAP@0.50:0.95",
  35. "greater_metric_to_watch_is_better": True}
  36. model.train(training_params=yolo_v3_training_params)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...