# Copyright (c) 2022, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
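"""Entry point for MegaMolBART pretraining.

Depending on the config, this script either pretrains the MegaMolBART model or
runs the one-time dataset preprocessing (CSV preparation and optional binary
conversion) that pretraining consumes.
"""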
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPPlugin,
    PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager

from nemo_chem.data import CsvToBinary, MoleculeCsvDatasetConfig, Preprocess
from nemo_chem.models.megamolbart import MegaMolBARTModel
from nemo_chem.utils import recursive_make_dirs, update_dataclass_config


def setup_trainer(cfg):
    """Build the PyTorch Lightning Trainer with NeMo plugins and the experiment manager."""
    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
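    # NLPDDPPlugin is NeMo's DDP variant for Megatron-style models: it disables
    # the default DDP communication hook and skips unused-parameter detection.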
    plugins = [
        NLPDDPPlugin(
            no_ddp_communication_hook=True,
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )
    ]

    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
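        # FP16 training needs dynamic loss scaling to avoid gradient underflow;
        # bf16 has enough dynamic range to run without a scaler.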
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
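        # With megatron_amp_O2, NeMo's half-precision plugin runs the model in
        # half precision end to end; otherwise fall back to PTL-style mixed
        # precision adapted for pipeline parallelism.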
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
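
    # On NVIDIA Base Command Platform (BCP), rely on torchelastic for rank and
    # cluster-environment discovery.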
    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
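
    # exp_manager wires up logging, checkpointing, and the experiment directory
    # from the exp_manager section of the config.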
    exp_manager(trainer, cfg.get("exp_manager", None))
    # recursive_make_dirs(log_dir)
    # recursive_make_dirs(trainer.checkpoint_callback.dirpath)

    # Update the resume path with the checkpoint found by exp_manager, unless
    # one was given explicitly in the model config.
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override PTL's stateful Timer callback with a stateless one, so a resumed
    # run is not charged for time spent before the restart.
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)
    # Hydra interpolation does not work here, as the interpolation key is lost
    # when PTL saves hparams.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    return trainer


@hydra_runner(config_path="conf", config_name="megamolbart_pretrain_xsmall_span_aug")
def main(cfg) -> None:
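    """Pretrain MegaMolBART, or preprocess its dataset when do_training is False."""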
    # Fill any missing fields of the data config with MoleculeCsvDatasetConfig defaults.
    with open_dict(cfg):
        cfg.model.data = update_dataclass_config(cfg.model.data, MoleculeCsvDatasetConfig)
- logging.info("\n\n************** Experiment configuration ***********")
- logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
- trainer = setup_trainer(cfg)
- model = MegaMolBARTModel(cfg.model, trainer)
- logging.info("************** Model parameters and their sizes ***********")
- for name, param in model.named_parameters():
- logging.info(f'{name}: {param.size()}')
- logging.info("***********************************************************")
    if cfg.do_training:
        logging.info("************** Starting Training ***********")
        trainer.fit(model)
        logging.info("************** Finished Training ***********")
    else:
        logging.info("************** Starting Data PreProcessing ***********")
        logging.info("Processing data into CSV files")
        preprocess = Preprocess()
        preprocess.prepare_dataset(
            links_file=cfg.model.data.links_file,
            output_dir=cfg.model.data.dataset_path,
        )

        if cfg.model.data.dataset_format == "bin":
            logging.info("Converting CSV data into Binary")
            csvtobin = CsvToBinary(
                input_dir=cfg.model.data.dataset_path,
                out_dir=cfg.model.data.dataset_path,
                config=cfg,
                num_enumerations=cfg.model.data.num_enumerations,
                num_workers=cfg.model.data.num_workers,
            )
            csvtobin.prepare_dataset()

        logging.info("************** Finished Data PreProcessing ***********")
    if cfg.do_testing:
        logging.info("************** Starting Testing ***********")
        trainer.test(model)
        logging.info("************** Finished Testing ***********")


if __name__ == '__main__':
    main()
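
# Example invocations (illustrative; assumes this file is saved as
# megamolbart_pretrain.py and launched from a directory containing the conf/
# Hydra configs; any config key can be overridden on the command line):
#   python megamolbart_pretrain.py do_training=False   # one-time data preprocessing
#   python megamolbart_pretrain.py do_training=True    # pretraining run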