data_utils.py

import copy
from argparse import Namespace

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch

from molbart.models.transformer_models import BARTModel, UnifiedModel
from molbart.data import SynthesisDataModule
from molbart.data.mol_data import ChemblDataModule, ZincDataModule
from molbart.data.seq2seq_data import (
    MolecularOptimizationDataModule,
    Uspto50DataModule,
    UsptoMixedDataModule,
    UsptoSepDataModule,
)

# Default model hyperparams
DEFAULT_D_MODEL = 512
DEFAULT_NUM_LAYERS = 6
DEFAULT_NUM_HEADS = 8
DEFAULT_D_FEEDFORWARD = 2048
DEFAULT_ACTIVATION = "gelu"
DEFAULT_MAX_SEQ_LEN = 512
DEFAULT_DROPOUT = 0.1

DEFAULT_MODEL = "bart"
DEFAULT_DATASET_TYPE = "synthesis"
DEFAULT_DEEPSPEED_CONFIG_PATH = "ds_config.json"
DEFAULT_LOG_DIR = "tb_logs"
DEFAULT_VOCAB_PATH = "bart_vocab.json"
DEFAULT_CHEM_TOKEN_START = 272

# Regex used to tokenize SMILES strings
REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"

DEFAULT_GPUS = 1
DEFAULT_NUM_NODES = 1

USE_GPU = True
use_gpu = USE_GPU and torch.cuda.is_available()


def build_molecule_datamodule(args, tokenizer, masker=None):
    """Build a datamodule for pre-training on a molecule dataset (ChEMBL or ZINC)."""
    dm_cls = {
        "chembl": ChemblDataModule,
        "zinc": ZincDataModule,
    }
    dm = dm_cls[args.dataset_type](
        task=args.task,
        augment_prob=args.augmentation_probability,
        masker=masker,
        dataset_path=args.data_path,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
        train_token_batch_size=args.train_tokens,
        num_buckets=args.n_buckets,
        unified_model=args.model_type == "unified",
    )
    return dm
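
# Example usage (sketch, not from this repo): `args` is expected to carry the
# attributes read above, mirroring molbart's pre-training CLI flags. Every
# value below is illustrative, not a recommended setting.
#
#   args = Namespace(
#       dataset_type="chembl",
#       task="mask_aug",
#       data_path="data/chembl.pickle",
#       model_type="bart",
#       batch_size=128,
#       max_seq_len=DEFAULT_MAX_SEQ_LEN,
#       train_tokens=None,
#       n_buckets=24,
#       augmentation_probability=0.5,
#   )
#   dm = build_molecule_datamodule(args, tokenizer, masker)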


def build_seq2seq_datamodule(config, tokenizer, forward=True):
    """Build a datamodule for a sequence-to-sequence (reaction) dataset.

    `config` must support both attribute access and .get(). Set forward=False
    to swap source and target, i.e. run the retrosynthesis direction.
    """
    dm_cls = {
        "uspto_50": Uspto50DataModule,
        "uspto_50_with_type": Uspto50DataModule,
        "uspto_mixed": UsptoMixedDataModule,
        "uspto_sep": UsptoSepDataModule,
        "mol_opt": MolecularOptimizationDataModule,
        "synthesis": SynthesisDataModule,
    }
    if config.dataset_type not in dm_cls:
        raise ValueError(f"Unknown dataset: {config.dataset_type}")

    # Extra constructor arguments needed by specific dataset types
    kwargs = {
        "uspto_50_with_type": {
            "include_type_token": True,
        }
    }

    # Chunking arguments must be given together or not at all
    if config.get("n_chunks") is not None and config.get("i_chunk") is None:
        raise ValueError("n_chunks is specified in config, but i_chunk is not.")
    if config.get("i_chunk") is not None and config.get("n_chunks") is None:
        raise ValueError("i_chunk is specified in config, but n_chunks is not.")

    dm = dm_cls[config.dataset_type](
        augment_prob=config.get("augmentation_probability"),
        reverse=not forward,
        dataset_path=config.data_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_seq_len=getattr(config, "max_seq_len", DEFAULT_MAX_SEQ_LEN),
        train_token_batch_size=config.get("train_tokens"),
        num_buckets=config.get("n_buckets"),
        unified_model=config.model_type == "unified",
        i_chunk=config.get("i_chunk", 0),
        n_chunks=config.get("n_chunks", 1),
        **kwargs.get(config.dataset_type, {}),
    )
    return dm
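
# Example usage (sketch): the keys mirror the config reads above. An omegaconf
# DictConfig is one container that supports both attribute access and .get(),
# but any such object works. All values are illustrative.
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.create({
#       "dataset_type": "uspto_50",
#       "data_path": "data/uspto_50.pickle",
#       "model_type": "bart",
#       "batch_size": 64,
#       "augmentation_probability": 0.0,
#       "train_tokens": None,
#       "n_buckets": None,
#   })
#   dm = build_seq2seq_datamodule(config, tokenizer, forward=False)  # retrosynthesis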


def seed_everything(seed):
    """Seed the python, numpy and torch RNGs via PyTorch Lightning."""
    pl.utilities.seed.seed_everything(seed)


def load_bart(args, sampler):
    """Load a BART model from checkpoint and put it in eval mode."""
    model = BARTModel.load_from_checkpoint(args.model_path, decode_sampler=sampler)
    model.eval()
    return model


def load_unified(args, sampler):
    """Load a unified model from checkpoint and put it in eval mode."""
    model = UnifiedModel.load_from_checkpoint(args.model_path, decode_sampler=sampler)
    model.eval()
    return model
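
# Example usage (sketch): restore a trained checkpoint for inference. `sampler`
# is the decode sampler instance the model uses for beam search; its class and
# import path depend on the molbart version, so it is left abstract here. The
# checkpoint path is illustrative.
#
#   args = Namespace(model_path="models/bart/last.ckpt")
#   model = load_bart(args, sampler)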


def _clean_string(x, expr_list):
    y = copy.copy(x)
    y = y.replace("''", "&")  # Mark empty SMILES string with dummy character
    for expr in expr_list:
        y = y.replace(expr, "")
    return y


def _convert_to_array(data_list):
    # Build an object-dtype array so pandas stores each nested list as a single
    # cell instead of trying to broadcast it across rows.
    data_new = np.zeros(len(data_list), dtype="object")
    for ix, x in enumerate(data_list):
        data_new[ix] = x
    return data_new


def read_score_tsv(
    filename,
    str_to_list_columns,
    is_numeric,
    expr_list1=["'", "[array([", "array([", "[array(", "array(", " ", "\n"],
):
    """
    Read a TSV file generated by the Chemformer.score_model() function.

    Args:
        filename: str (path to the tab-separated score file)
        str_to_list_columns: list(str) (columns to convert from string to nested list)
        is_numeric: list(bool) (denotes which of those columns should be converted to lists of floats)
    """
    sep = ","  # separator of elements inside a stringified list cell
    numeric_expr_list = ["(", ")", "[", "]", "\n"]

    data = pd.read_csv(filename, sep="\t")
    for col, to_float in zip(str_to_list_columns, is_numeric):
        print("Converting string to data of column: " + col)
        data_str = data[col].values

        data_list = []
        for X in data_str:
            X = [x for x in X.split(sep) if "dtype=" not in x]

            inner_list = []
            X_new = []
            is_last_molecule = False
            for x in X:
                x = _clean_string(x, expr_list1)
                if x == "":
                    continue
                # An unbalanced closing bracket marks the end of one inner list
                if x[-1] == ")" and sum([token == "(" for token in x]) < sum([token == ")" for token in x]):
                    x = x[:-1]
                    is_last_molecule = True
                if x[-1] == "]" and sum([token == "[" for token in x]) < sum([token == "]" for token in x]):
                    x = x[:-1]
                    is_last_molecule = True

                inner_list.append(x)

                if is_last_molecule:
                    if to_float:
                        inner_list = [_clean_string(element, numeric_expr_list) for element in inner_list]
                        inner_list = [float(element) if element != "" else np.nan for element in inner_list]
                    X_new.append(inner_list)
                    inner_list = []
                    is_last_molecule = False

            print("Batch size after cleaning (for validating cleaning): " + str(len(X_new)))
            data_list.append(X_new)

        data.drop(columns=[col], inplace=True)
        data_list = _convert_to_array(data_list)
        data[col] = data_list
    return data
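
# Example usage (sketch): recover nested lists from a score file written by
# Chemformer.score_model(). The filename and column names below are
# illustrative; use the ones present in your file.
#
#   df = read_score_tsv(
#       "metrics_scores.csv",
#       str_to_list_columns=["sampled_molecules", "log_likelihoods"],
#       is_numeric=[False, True],
#   )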