Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. from omegaconf import DictConfig
  2. import hydra
  3. from super_gradients.training.sg_model import MultiGPUMode
  4. from super_gradients.common.abstractions.abstract_logger import get_logger
  5. import torch
  class Trainer:
      """
      Class for running SuperGradients' recipes.

      See the train_from_recipe example in the examples directory for a demonstration of its usage.
      """

      # FIXME: REMOVE PARAMETER MANIPULATION SPECIFIC FOR YOLO
      @staticmethod
      def scale_params_for_yolov5(cfg):
          """
          Rescale YOLOv5 training hyper-parameters stored in cfg. The cfg object is modified IN PLACE
          (calling this twice on the same cfg applies the scaling twice) and is also returned.

          Always scaled:
              * initial_lr and warmup_bias_lr: multiplied by the DDP world size,
              * weight_decay: divided by the DDP world size, then multiplied by [effective batch size] / 64,
              * EMA beta: set to max_epochs * len(train_loader) / 2000.

          Scaled only when training_hyperparams.loss == 'yolo_v5_loss':
              * box_loss_gain: by 3 / [num YOLO output layers],
              * cls_loss_gain: by ([num classes] / 80) * 3 / [num YOLO output layers],
              * obj_loss_gain: by ([train image size] / 640) ** 2 * 3 / [num YOLO output layers].

          The resulting values are reported through the module logger.

          @param cfg: instantiated recipe configuration; must expose sg_model, training_hyperparams and
                      dataset_params
          @return: the same cfg object, after scaling
          """
          logger = get_logger(__name__)
          hyperparams = cfg.training_hyperparams
          sg_model = cfg.sg_model

          # The world size is only queried when DDP was requested AND the process group is initialized
          is_ddp = sg_model.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and torch.distributed.is_initialized()
          world_size = torch.distributed.get_world_size() if is_ddp else 1

          # Scale LR and WD for DDP due to gradients being averaged between devices
          # Equivalent to loss * WORLD_SIZE in ultralytics
          hyperparams.initial_lr *= world_size
          hyperparams.warmup_bias_lr *= world_size
          hyperparams.optimizer_params.weight_decay /= world_size

          # Scale WD with a factor of [effective batch size]/64.
          batch_size = cfg.dataset_params.batch_size
          batch_accumulate = hyperparams.batch_accumulate
          batch_size_factor = sg_model.num_devices if is_ddp else sg_model.dataset_interface.batch_size_factor
          effective_batch_size = batch_size * batch_size_factor * batch_accumulate
          hyperparams.optimizer_params.weight_decay *= effective_batch_size / 64.

          # Scale EMA beta to match Ultralytics update
          hyperparams.ema_params.beta = hyperparams.max_epochs * len(sg_model.train_loader) / 2000.

          # The message is assembled from explicit lines so the logged text does not depend on
          # how this source file happens to be indented.
          ddp_label = 'DDP' if is_ddp else 'no DDP'
          log_msg = (
              "\n"
              "IMPORTANT:\n\n"
              f"Training with world size of {world_size}, {ddp_label}, effective batch size of {effective_batch_size},\n"
              "scaled:\n"
              f"* initial_lr to {hyperparams.initial_lr};\n"
              f"* warmup_bias_lr to {hyperparams.warmup_bias_lr};\n"
              f"* weight_decay to {hyperparams.optimizer_params.weight_decay};\n"
              f"* EMA beta to {hyperparams.ema_params.beta};\n"
          )

          if hyperparams.loss == 'yolo_v5_loss':
              # Scale loss gains
              criterion_params = hyperparams.criterion_params

              # Unwrap the network when it is held inside a wrapper exposing `.module`
              model = sg_model.net
              model = model.module if hasattr(model, 'module') else model
              num_levels = model._head._modules_list[-1].detection_layers_num
              train_image_size = cfg.dataset_params.train_image_size

              # Normalization factors relative to the reference values 3 / 80 / 640 used below
              num_branches_norm = 3. / num_levels
              num_classes_norm = len(sg_model.classes) / 80.
              image_size_norm = train_image_size / 640.

              criterion_params.box_loss_gain *= num_branches_norm
              criterion_params.cls_loss_gain *= num_classes_norm * num_branches_norm
              criterion_params.obj_loss_gain *= image_size_norm ** 2 * num_branches_norm

              log_msg += (
                  "\n"
                  f"* box_loss_gain to {criterion_params.box_loss_gain};\n"
                  f"* cls_loss_gain to {criterion_params.cls_loss_gain};\n"
                  f"* obj_loss_gain to {criterion_params.obj_loss_gain};\n"
              )

          logger.info(log_msg)
          return cfg

      @staticmethod
      def train(cfg: DictConfig):
          """
          Trains according to cfg recipe configuration.

          Steps: instantiate every object declared in cfg, connect the dataset interface to the model,
          build the network, apply the YOLOv5-specific parameter scaling when the architecture name
          starts with "yolo_v5", and finally launch training.

          @param cfg: The parsed DictConfig from yaml recipe files
          @return: output of sg_model.train(...) (i.e results tuple)
          """
          # INSTANTIATE ALL OBJECTS IN CFG
          cfg = hydra.utils.instantiate(cfg)

          # CONNECT THE DATASET INTERFACE WITH DECI MODEL
          cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)

          # BUILD NETWORK
          cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)

          # FIXME: REMOVE PARAMETER MANIPULATION SPECIFIC FOR YOLO
          if str(cfg.architecture).startswith("yolo_v5"):
              cfg = Trainer.scale_params_for_yolov5(cfg)

          # TRAIN
          # Hand the training output back to the caller, as the documented contract promises
          # (previously the result was silently discarded).
          return cfg.sg_model.train(training_params=cfg.training_hyperparams)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...