pretrain.py 4.1 KB

import os

import hydra

import molbart.utils.data_utils as util
import molbart.utils.trainer_utils as trainer_utils
from molbart.models.transformer_models import BARTModel, UnifiedModel
from molbart.utils.samplers.beam_search_samplers import DecodeSampler
from molbart.utils.tokenizer import ChemformerTokenizer, ReplaceTokensMasker, SpanTokensMasker


def build_model(args, sampler, vocab_size, total_steps, pad_token_idx):
    # These args don't affect the model directly but will be saved by lightning as hparams
    # Tensorboard doesn't like None so we need to convert to string
    train_tokens = "None" if args.train_tokens is None else args.train_tokens
    n_buckets = "None" if args.n_buckets is None else args.n_buckets

    extra_args = {
        "batch_size": args.batch_size,
        "acc_batches": args.acc_batches,
        "mask_prob": args.mask_prob,
        "epochs": args.n_epochs,
        "clip_grad": args.clip_grad,
        "train_tokens": train_tokens,
        "num_buckets": n_buckets,
        "limit_val_batches": args.limit_val_batches,
        "augment_prob": args.augmentation_probability,
        "task": args.task,
        "mask_scheme": args.mask_scheme,
        "model_type": args.model_type,
    }

    if args.model_type == "bart":
        model = BARTModel(
            sampler,
            pad_token_idx,
            vocab_size,
            args.d_model,
            args.n_layers,
            args.n_heads,
            args.d_feedforward,
            args.learning_rate,
            args.weight_decay,
            args.activation,
            total_steps,
            args.max_seq_len,
            schedule=args.schedule,
            warm_up_steps=args.warm_up_steps,
            dropout=util.DEFAULT_DROPOUT,
            **extra_args,
        )
    elif args.model_type == "unified":
        model = UnifiedModel(
            sampler,
            pad_token_idx,
            vocab_size,
            args.d_model,
            args.n_layers,
            args.n_heads,
            args.d_feedforward,
            args.learning_rate,
            args.weight_decay,
            args.activation,
            total_steps,
            args.max_seq_len,
            schedule=args.schedule,
            warm_up_steps=args.warm_up_steps,
            dropout=util.DEFAULT_DROPOUT,
            **extra_args,
        )
    else:
        raise ValueError(f"Unknown model type {args.model_type}")

    return model


@hydra.main(version_base=None, config_path="config", config_name="pretrain")
def main(args):
    util.seed_everything(args.seed)

    if args.dataset_type == "zinc" and args.train_tokens is not None:
        raise ValueError("train_tokens arg must be None when using zinc dataset.")
    if args.n_gpus > 1 and args.train_tokens is not None:
        raise ValueError("train_tokens arg must be None when training on multiple gpus.")

    print("Building tokeniser...")
    tokeniser = ChemformerTokenizer(filename=args.vocabulary_path)
    if args.mask_scheme == "replace":
        masker = ReplaceTokensMasker(tokenizer=tokeniser, mask_prob=args.mask_prob)
    else:
        masker = SpanTokensMasker(tokenizer=tokeniser, mask_prob=args.mask_prob)
    print("Finished tokeniser.")

    print("Building data module...")
    dm = util.build_molecule_datamodule(args, tokeniser, masker=masker)
    n_available_cpus = len(os.sched_getaffinity(0))
    n_workers = n_available_cpus // args.n_gpus
    dm._num_workers = n_workers
    print(f"Using {str(n_workers)} workers for data module.")
    print("Finished data module.")

    vocab_size = len(tokeniser)
    train_steps = trainer_utils.calc_train_steps(args, dm)
    print(f"Train steps: {train_steps}")

    sampler = DecodeSampler(tokeniser, args.max_seq_len)
    pad_token_idx = tokeniser["pad"]

    print("Building model...")
    model = build_model(args, sampler, vocab_size, train_steps, pad_token_idx)
    print("Finished model.")

    print("Building trainer...")
    trainer = trainer_utils.build_trainer(args)
    print("Finished trainer.")

    print("Fitting data module to model")
    trainer.fit(model, dm)
    print("Finished training.")


if __name__ == "__main__":
    main()
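Because main() is wrapped in @hydra.main with config_path="config" and config_name="pretrain", the script reads its arguments from config/pretrain.yaml and accepts key=value overrides on the command line. As a minimal sketch (the values and the vocabulary path below are placeholders, not project defaults; the keys are the fields the script reads from args):

    python pretrain.py model_type=bart dataset_type=zinc n_gpus=1 \
        vocabulary_path=path/to/vocab.json batch_size=128 n_epochs=10

Any field not overridden is taken from the Hydra config, and note the script's own checks: train_tokens must be null when dataset_type is zinc or when n_gpus > 1.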