Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer_utils.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. from typing import List, Optional
  2. import hydra
  3. import math
  4. import pytorch_lightning as pl
  5. from omegaconf import DictConfig
  6. from pytorch_lightning.loggers import TensorBoardLogger
  7. from pytorch_lightning.plugins import Plugin
  8. from molbart.utils.callbacks import CallbackCollection
  9. from molbart.utils.scores import ScoreCollection
  10. def instantiate_callbacks(callbacks_config: Optional[DictConfig]) -> CallbackCollection:
  11. """Instantiates callbacks from config."""
  12. callbacks = CallbackCollection()
  13. if not callbacks_config:
  14. print("No callbacks configs found! Skipping...")
  15. return callbacks
  16. callbacks.load_from_config(callbacks_config)
  17. return callbacks
  18. def instantiate_scorers(scorer_config: Optional[DictConfig]) -> CallbackCollection:
  19. """Instantiates scorer from config."""
  20. scorer = ScoreCollection()
  21. if not scorer_config:
  22. print("No scorer configs found! Skipping...")
  23. return scorer
  24. scorer.load_from_config(scorer_config)
  25. return scorer
  26. def instantiate_logger(logger_config: Optional[DictConfig]) -> TensorBoardLogger:
  27. """Instantiates logger from config."""
  28. logger: TensorBoardLogger = []
  29. if not logger_config:
  30. print("No logger configs found! Skipping...")
  31. return logger
  32. if not isinstance(logger_config, DictConfig):
  33. raise TypeError("Logger config must be a DictConfig!")
  34. if isinstance(logger_config, DictConfig) and "_target_" in logger_config:
  35. print(f"Instantiating logger <{logger_config._target_}>")
  36. logger = hydra.utils.instantiate(logger_config)
  37. return logger
  38. def instantiate_plugins(plugin_cfg: Optional[DictConfig]) -> List[Plugin]:
  39. """Instantiates plugins from config."""
  40. plugin: list[Plugin] = []
  41. if not plugin_cfg:
  42. print("No plugin configs found! Skipping...")
  43. return plugin
  44. if not isinstance(plugin_cfg, DictConfig):
  45. raise TypeError("Plugin config must be a DictConfig!")
  46. for _, plugin_conf in plugin_cfg.items():
  47. if isinstance(plugin_conf, DictConfig) and "_target_" in plugin_conf:
  48. print(f"Instantiating logger <{plugin_conf._target_}>")
  49. plugin.append(hydra.utils.instantiate(plugin_conf))
  50. return plugin
  51. def calc_train_steps(args, dm, n_gpus=None):
  52. n_gpus = getattr(args, "n_gpus", n_gpus)
  53. dm.setup()
  54. if n_gpus > 0:
  55. batches_per_gpu = math.ceil(len(dm.train_dataloader()) / float(n_gpus))
  56. else:
  57. raise ValueError("Number of GPUs should be > 0 in training.")
  58. train_steps = math.ceil(batches_per_gpu / args.acc_batches) * args.n_epochs
  59. return train_steps
  60. def build_trainer(config, n_gpus=None):
  61. print("Instantiating loggers...")
  62. logger = instantiate_logger(config.get("logger"))
  63. print("Instantiating callbacks...")
  64. callbacks: CallbackCollection = instantiate_callbacks(config.get("callbacks"))
  65. print("Instantiating plugins...")
  66. plugins: list[Plugin] = instantiate_plugins(config.get("plugin"))
  67. if n_gpus > 1:
  68. config.trainer.accelerator = "ddp"
  69. else:
  70. plugins = None
  71. print("Building trainer...")
  72. trainer: pl.Trainer = hydra.utils.instantiate(
  73. config.trainer, callbacks=callbacks.objects(), logger=logger, plugins=plugins
  74. )
  75. print("Finished trainer.")
  76. print(f"Default logging and checkpointing directory: {trainer.default_root_dir} or {trainer.weights_save_path}")
  77. return trainer
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...