params.py

from super_gradients.training.utils import HpmStruct
from copy import deepcopy

DEFAULT_TRAINING_PARAMS = {
    "lr_warmup_epochs": 0,
    "lr_warmup_steps": 0,
    "lr_cooldown_epochs": 0,
    "warmup_initial_lr": None,
    "cosine_final_lr_ratio": 0.01,
    "optimizer": "SGD",
    "optimizer_params": {},
    "criterion_params": {},
    "ema": False,
    "batch_accumulate": 1,  # number of batches to accumulate before every backward pass
    "ema_params": {},
    "zero_weight_decay_on_bias_and_bn": False,
    "load_opt_params": True,
    "run_validation_freq": 1,
    "run_test_freq": 1,
    "save_model": True,
    "metric_to_watch": "Accuracy",
    "launch_tensorboard": False,
    "tb_files_user_prompt": False,  # ask the user to confirm Tensorboard files deletion
    "silent_mode": False,  # silences the print-outs
    "mixed_precision": False,
    "tensorboard_port": None,
    "save_ckpt_epoch_list": [],  # indices where the ckpt will save automatically
    "average_best_models": True,
    "dataset_statistics": False,  # add a dataset statistical analysis and sample images to tensorboard
    "save_tensorboard_to_s3": False,
    "lr_schedule_function": None,
    "train_metrics_list": [],
    "valid_metrics_list": [],
    "greater_metric_to_watch_is_better": True,
    "precise_bn": False,
    "precise_bn_batch_size": None,
    "seed": 42,
    "lr_mode": None,
    "phase_callbacks": None,
    "log_installed_packages": True,
    "sg_logger": "base_sg_logger",
    "sg_logger_params": {
        "tb_files_user_prompt": False,  # ask the user to confirm Tensorboard files deletion
        "project_name": "",
        "launch_tensorboard": False,
        "tensorboard_port": None,
        "save_checkpoints_remote": False,  # upload checkpoint files to s3
        "save_tensorboard_remote": False,  # upload tensorboard files to s3
        "save_logs_remote": False,  # upload log files to s3
    },
    "warmup_mode": "LinearEpochLRWarmup",
    "step_lr_update_freq": None,
    "lr_updates": [],
    "initial_lr": None,
    "clip_grad_norm": None,
    "pre_prediction_callback": None,
    "ckpt_best_name": "ckpt_best.pth",
    "enable_qat": False,
    "resume": False,
    "resume_path": None,
    "ckpt_name": "ckpt_latest.pth",
    "resume_strict_load": False,
    "sync_bn": False,
    "kill_ddp_pgroup_on_end": True,  # whether to kill the DDP process group at the end of training
    "max_train_batches": None,  # for debugging - when not None, breaks out of the inner train loop
    # (i.e. iterating over train_loader) when reaching this number of batches
    "max_valid_batches": None,  # for debugging - when not None, breaks out of the inner valid loop
    # (i.e. iterating over valid_loader) when reaching this number of batches
    "resume_from_remote_sg_logger": False,  # When True, ckpt_name (the checkpoint filename to resume, ckpt_latest.pth
    # by default) will be downloaded into the experiment checkpoints directory prior to loading weights, and training
    # resumes from that checkpoint. The source is unique to every logger, and is currently supported for WandB loggers
    # only. Note that for this to work, the experiment must be run with sg_logger_params.save_checkpoints_remote=True.
    # For WandB loggers, one must also pass the run id through the wandb_id arg in sg_logger_params.
    "torch_compile": False,  # enable or disable use of torch.compile to optimize the model
    "torch_compile_loss": False,  # enable or disable use of torch.compile to optimize the loss
    "torch_compile_options": {
        "mode": "reduce-overhead",  # can be "default", "reduce-overhead" or "max-autotune"
        "fullgraph": False,  # whether it is ok to break the model into several subgraphs
        "dynamic": False,  # use dynamic shape tracing
        "backend": "inductor",  # backend to be used
        "options": None,  # a dictionary of options to pass to the backend
        "disable": False,  # turn torch.compile() into a no-op for testing
    },  # torch.compile options from https://pytorch.org/docs/stable/generated/torch.compile.html
    "finetune": False,  # Whether to freeze a fixed part of the model (supported only for models that implement
    # get_finetune_lr_dict, see SgModule.get_finetune_lr_dict. Tailored for each model class.)
}
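The dictionary above only supplies defaults. In typical SuperGradients usage you pass a partial training_params dict to Trainer.train, and it is merged over these values, so you only spell out the keys you want to change plus the keys the schema further down requires. A minimal, illustrative sketch, assuming the usual Trainer entry point; model, train_loader and valid_loader are placeholders, and the lr_mode / loss strings are example registry names, not prescribed by this file:

# Illustrative sketch only (not part of params.py).
from super_gradients import Trainer

trainer = Trainer(experiment_name="my_experiment")
trainer.train(
    model=model,
    training_params={
        "max_epochs": 10,
        "lr_mode": "CosineLRScheduler",
        "initial_lr": 0.01,
        "loss": "CrossEntropyLoss",
        "optimizer": "Adam",  # everything not listed here keeps the defaults above
    },
    train_loader=train_loader,
    valid_loader=valid_loader,
)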

DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
DEFAULT_OPTIMIZER_PARAMS_ADAM = {"weight_decay": 1e-4}
DEFAULT_OPTIMIZER_PARAMS_RMSPROP = {"weight_decay": 1e-4, "momentum": 0.9}
DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}

TRAINING_PARAM_SCHEMA = {
    "type": "object",
    "properties": {
        "max_epochs": {"type": "number", "minimum": 1},
        # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA - AS IT CAUSES HYDRA USE TO CRASH
        # "lr_updates": {"type": "array", "minItems": 1},
        "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
        "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
        "initial_lr": {
            "anyOf": [
                {"type": ["number", "string", "boolean", "null"]},
                {"type": "object", "patternProperties": {"^[a-zA-Z0-9_.]+$": {"type": "number"}}, "additionalProperties": False},
            ]
        },
    },
    "if": {"properties": {"lr_mode": {"const": "StepLRScheduler"}}},
    "then": {"required": ["lr_updates", "lr_decay_factor"]},
    "required": ["max_epochs", "lr_mode", "initial_lr", "loss"],
}
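This is a standard JSON Schema document: the "if"/"then" clause makes "lr_updates" and "lr_decay_factor" mandatory only when "lr_mode" is "StepLRScheduler", on top of the four always-required keys. A small illustrative check, assuming the jsonschema package (which HpmStruct's set_schema/validate appear to rely on) and a draft-7-or-later validator; the lr_mode and loss strings are example values:

# Illustrative sketch only (not part of params.py).
import jsonschema

ok = {"max_epochs": 5, "lr_mode": "CosineLRScheduler", "initial_lr": 0.01, "loss": "CrossEntropyLoss"}
jsonschema.validate(ok, TRAINING_PARAM_SCHEMA)  # passes

step = {"max_epochs": 5, "lr_mode": "StepLRScheduler", "initial_lr": 0.01, "loss": "CrossEntropyLoss"}
jsonschema.validate(step, TRAINING_PARAM_SCHEMA)  # raises ValidationError: "lr_updates" is a required property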

class TrainingParams(HpmStruct):
    def __init__(self, **entries):
        # Initialize with the default training params, then override with the provided entries.
        default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
        super().__init__(**default_training_params)
        self.set_schema(TRAINING_PARAM_SCHEMA)
        if len(entries) > 0:
            self.override(**entries)

    def override(self, **entries):
        super().override(**entries)
        self.validate()
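Construction is therefore a merge-then-validate step: TrainingParams starts from DEFAULT_TRAINING_PARAMS, applies the given overrides, and validates the merged result against TRAINING_PARAM_SCHEMA. Note that calling TrainingParams() with no entries skips override() and thus skips validation. A hedged usage sketch; the lr_mode and loss strings are example values, not prescribed by this file:

# Illustrative sketch only (not part of params.py).
params = TrainingParams(
    max_epochs=10,
    lr_mode="CosineLRScheduler",
    initial_lr=0.01,
    loss="CrossEntropyLoss",
    optimizer="Adam",        # replaces the "SGD" default
    mixed_precision=True,    # replaces the False default
)
print(params.optimizer)      # "Adam"
print(params.ema)            # False, inherited from DEFAULT_TRAINING_PARAMS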