Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

megamolbart_pretrain.py 5.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
  1. # Copyright (c) 2022, NVIDIA CORPORATION.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from omegaconf.omegaconf import OmegaConf, open_dict
  15. from pytorch_lightning import Trainer
  16. from pytorch_lightning.callbacks import ModelSummary
  17. from pytorch_lightning.callbacks.timer import Timer
  18. from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
  19. from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
  20. from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
  21. from nemo.collections.nlp.parts.nlp_overrides import (
  22. GradScaler,
  23. MegatronHalfPrecisionPlugin,
  24. NLPDDPPlugin,
  25. PipelineMixedPrecisionPlugin,
  26. )
  27. from nemo.core.config import hydra_runner
  28. from nemo.utils import logging
  29. from nemo.utils.exp_manager import StatelessTimer, exp_manager
  30. from nemo_chem.models.megamolbart import MegaMolBARTModel
  31. from nemo_chem.data import MoleculeCsvDatasetConfig
  32. from nemo_chem.utils import recursive_make_dirs, update_dataclass_config
  33. from nemo_chem.data import Preprocess, CsvToBinary
  34. import os
  35. def setup_trainer(cfg):
  36. """Trainer setup functions"""
  37. megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
  38. plugins = [
  39. NLPDDPPlugin(
  40. no_ddp_communication_hook=True,
  41. gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
  42. find_unused_parameters=False,
  43. )
  44. ]
  45. if cfg.trainer.precision in [16, 'bf16']:
  46. scaler = None
  47. if cfg.trainer.precision == 16:
  48. scaler = GradScaler(
  49. init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
  50. growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
  51. hysteresis=cfg.model.get('hysteresis', 2),
  52. )
  53. if megatron_amp_o2:
  54. plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
  55. else:
  56. plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
  57. if cfg.get('cluster_type', None) == 'BCP':
  58. plugins.append(TorchElasticEnvironment())
  59. trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
  60. exp_manager(trainer, cfg.get("exp_manager", None))
  61. # recursive_make_dirs(log_dir)
  62. # recursive_make_dirs(trainer.checkpoint_callback.dirpath)
  63. # update resume from checkpoint found by exp_manager
  64. if cfg.model.resume_from_checkpoint is not None:
  65. resume_from_checkpoint = cfg.model.resume_from_checkpoint
  66. else:
  67. resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
  68. logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
  69. trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
  70. # Override timer callback to a stateless one
  71. for idx, callback in enumerate(trainer.callbacks):
  72. if isinstance(callback, Timer):
  73. trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)
  74. # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
  75. with open_dict(cfg):
  76. cfg.model.precision = cfg.trainer.precision
  77. return trainer
  78. @hydra_runner(config_path="conf", config_name="megamolbart_pretrain_xsmall_span_aug")
  79. def main(cfg) -> None:
  80. with open_dict(cfg):
  81. cfg.model.data = update_dataclass_config(cfg.model.data, MoleculeCsvDatasetConfig)
  82. logging.info("\n\n************** Experiment configuration ***********")
  83. logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
  84. trainer = setup_trainer(cfg)
  85. model = MegaMolBARTModel(cfg.model, trainer)
  86. logging.info("************** Model parameters and their sizes ***********")
  87. for name, param in model.named_parameters():
  88. logging.info(f'{name}: {param.size()}')
  89. logging.info("***********************************************************")
  90. if cfg.do_training:
  91. logging.info("************** Starting Training ***********")
  92. trainer.fit(model)
  93. logging.info("************** Finished Training ***********")
  94. else:
  95. logging.info("************** Starting Data PreProcessing ***********")
  96. logging.info("Processing data into CSV files")
  97. preprocess = Preprocess()
  98. preprocess.prepare_dataset(links_file=cfg.model.data.links_file,
  99. output_dir=cfg.model.data.dataset_path)
  100. if cfg.model.data.dataset_format == "bin":
  101. logging.info("Converting CSV data into Binary")
  102. csvtobin = CsvToBinary(input_dir=cfg.model.data.dataset_path,
  103. out_dir=cfg.model.data.dataset_path,
  104. config=cfg,
  105. num_enumerations=cfg.model.data.num_enumerations,
  106. num_workers=cfg.model.data.num_workers)
  107. csvtobin.prepare_dataset()
  108. logging.info("************** Finished Data PreProcessing ***********")
  109. if cfg.do_testing:
  110. logging.info("************** Starting Testing ***********")
  111. trainer.test(model)
  112. logging.info("************** Finished Testing ***********")
  113. if __name__ == '__main__':
  114. main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...