params.py
from super_gradients.training.utils import HpmStruct

DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                           "cosine_final_lr_ratio": 0.01,
                           "optimizer": "SGD",
                           "criterion_params": {},
                           "ema": False,
                           "batch_accumulate": 1,  # number of batches to accumulate before every backward pass
                           "ema_params": {},
                           "zero_weight_decay_on_bias_and_bn": False,
                           "load_opt_params": True,
                           "run_validation_freq": 1,
                           "save_model": True,
                           "metric_to_watch": "Accuracy",
                           "launch_tensorboard": False,
                           "tb_files_user_prompt": False,  # asks the user before deleting existing Tensorboard files
                           "silent_mode": False,  # silences the print outs
                           "mixed_precision": False,
                           "tensorboard_port": None,
                           "save_ckpt_epoch_list": [],  # indices where the ckpt will save automatically
                           "average_best_models": True,
                           "dataset_statistics": False,  # add a dataset statistical analysis and sample images to tensorboard
                           "save_tensorboard_to_s3": False,
                           "lr_schedule_function": None,
                           "train_metrics_list": [],
                           "valid_metrics_list": [],
                           "loss_logging_items_names": ["Loss"],
                           "greater_metric_to_watch_is_better": True,
                           "precise_bn": False,
                           "precise_bn_batch_size": None,
                           "seed": 42,
                           "lr_mode": None,
                           "phase_callbacks": [],
                           "log_installed_packages": True,
                           "save_full_train_log": False,
                           "sg_logger": "base_sg_logger",
                           "sg_logger_params":
                               {"tb_files_user_prompt": False,  # asks the user before deleting existing Tensorboard files
                                "project_name": "",
                                "launch_tensorboard": False,
                                "tensorboard_port": None,
                                "save_checkpoints_remote": False,  # upload checkpoint files to s3
                                "save_tensorboard_remote": False,  # upload tensorboard files to s3
                                "save_logs_remote": False}  # upload log files to s3
                           }

DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}

DEFAULT_OPTIMIZER_PARAMS_ADAM = {"weight_decay": 1e-4}

DEFAULT_OPTIMIZER_PARAMS_RMSPROP = {"weight_decay": 1e-4, "momentum": 0.9}

DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}

# JSON schema used to validate user-provided training params.
TRAINING_PARAM_SCHEMA = {"type": "object",
                         "properties": {
                             "max_epochs": {"type": "number", "minimum": 1, "maximum": 800},
                             # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA - AS IT CAUSES HYDRA USE TO CRASH
                             # "lr_updates": {"type": "array", "minItems": 1},
                             "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
                             "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
                             "initial_lr": {"type": "number", "exclusiveMinimum": 0, "maximum": 10}
                         },
                         "if": {
                             "properties": {"lr_mode": {"const": "step"}}
                         },
                         "then": {
                             "required": ["lr_updates", "lr_decay_factor"]
                         },
                         "required": ["max_epochs", "lr_mode", "initial_lr", "loss"]
                         }


class TrainingParams(HpmStruct):

    def __init__(self, **entries):
        # We initialize with the default training params, overridden by the provided entries
        super().__init__(**DEFAULT_TRAINING_PARAMS)
        self.set_schema(TRAINING_PARAM_SCHEMA)
        if len(entries) > 0:
            self.override(**entries)

    def override(self, **entries):
        super().override(**entries)
        self.validate()