import copy
from argparse import Namespace

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch

from molbart.data import SynthesisDataModule
from molbart.data.mol_data import ChemblDataModule, ZincDataModule
from molbart.data.seq2seq_data import (
    MolecularOptimizationDataModule,
    Uspto50DataModule,
    UsptoMixedDataModule,
    UsptoSepDataModule,
)
from molbart.models.transformer_models import BARTModel, UnifiedModel

# Default model hyperparams
DEFAULT_D_MODEL = 512
DEFAULT_NUM_LAYERS = 6
DEFAULT_NUM_HEADS = 8
DEFAULT_D_FEEDFORWARD = 2048
DEFAULT_ACTIVATION = "gelu"
DEFAULT_MAX_SEQ_LEN = 512
DEFAULT_DROPOUT = 0.1

DEFAULT_MODEL = "bart"
DEFAULT_DATASET_TYPE = "synthesis"
DEFAULT_DEEPSPEED_CONFIG_PATH = "ds_config.json"
DEFAULT_LOG_DIR = "tb_logs"
DEFAULT_VOCAB_PATH = "bart_vocab.json"
DEFAULT_CHEM_TOKEN_START = 272
REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"

DEFAULT_GPUS = 1
DEFAULT_NUM_NODES = 1

USE_GPU = True
use_gpu = USE_GPU and torch.cuda.is_available()
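
# Illustrative example: REGEX splits a SMILES string into chemically meaningful
# tokens. A minimal sketch, assuming only the standard `re` module:
#
#     >>> import re
#     >>> re.findall(REGEX, "CC(=O)Oc1ccccc1C(=O)O")  # aspirin
#     ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c',
#      'c', 'c', 'c', '1', 'C', '(', '=', 'O', ')', 'O']
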
def build_molecule_datamodule(args, tokenizer, masker=None):
    """Build a molecule (pre-training) data module for the ChEMBL or ZINC dataset."""
    dm_cls = {
        "chembl": ChemblDataModule,
        "zinc": ZincDataModule,
    }
    dm = dm_cls[args.dataset_type](
        task=args.task,
        augment_prob=args.augmentation_probability,
        masker=masker,
        dataset_path=args.data_path,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
        train_token_batch_size=args.train_tokens,
        num_buckets=args.n_buckets,
        unified_model=args.model_type == "unified",
    )
    return dm
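
# Illustrative usage sketch (not executed on import). The Namespace fields below
# mirror the attributes accessed in build_molecule_datamodule; the task name,
# data path, and hyperparameter values are placeholders, not values shipped
# with the repo.
def _example_build_molecule_datamodule(tokenizer):
    args = Namespace(
        dataset_type="zinc",
        task="mask",  # hypothetical task name
        augmentation_probability=0.5,
        data_path="data/zinc",  # hypothetical path
        batch_size=128,
        max_seq_len=DEFAULT_MAX_SEQ_LEN,
        train_tokens=None,
        n_buckets=None,
        model_type="bart",
    )
    return build_molecule_datamodule(args, tokenizer)
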
def build_seq2seq_datamodule(config, tokenizer, forward=True):
    """Build a sequence-to-sequence (reaction prediction) data module.

    `config` must support both attribute access and dict-style `.get(...)`.
    """
    dm_cls = {
        "uspto_50": Uspto50DataModule,
        "uspto_50_with_type": Uspto50DataModule,
        "uspto_mixed": UsptoMixedDataModule,
        "uspto_sep": UsptoSepDataModule,
        "mol_opt": MolecularOptimizationDataModule,
        "synthesis": SynthesisDataModule,
    }
    if config.dataset_type not in dm_cls:
        raise ValueError(f"Unknown dataset: {config.dataset_type}")

    # Dataset-specific keyword arguments
    kwargs = {
        "uspto_50_with_type": {
            "include_type_token": True,
        }
    }

    if config.get("n_chunks") is not None and config.get("i_chunk") is None:
        raise ValueError("n_chunks is specified in config, but i_chunk is not.")
    if config.get("i_chunk") is not None and config.get("n_chunks") is None:
        raise ValueError("i_chunk is specified in config, but n_chunks is not.")

    dm = dm_cls[config.dataset_type](
        augment_prob=config.get("augmentation_probability"),
        reverse=not forward,
        dataset_path=config.data_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_seq_len=config.get("max_seq_len", DEFAULT_MAX_SEQ_LEN),
        train_token_batch_size=config.get("train_tokens"),
        num_buckets=config.get("n_buckets"),
        unified_model=config.model_type == "unified",
        i_chunk=config.get("i_chunk", 0),
        n_chunks=config.get("n_chunks", 1),
        **kwargs.get(config.dataset_type, {}),
    )
    return dm
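
# Illustrative usage sketch (not executed on import). An omegaconf DictConfig
# satisfies both the attribute access and `.get(...)` used above; omegaconf
# itself and all values below are assumptions, not requirements of this module.
def _example_build_seq2seq_datamodule(tokenizer):
    from omegaconf import OmegaConf  # assumption: omegaconf is available

    config = OmegaConf.create(
        {
            "dataset_type": "synthesis",
            "data_path": "data/uspto_synthesis.csv",  # hypothetical path
            "batch_size": 64,
            "model_type": "bart",
            "augmentation_probability": 0.0,
        }
    )
    return build_seq2seq_datamodule(config, tokenizer, forward=True)
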
def seed_everything(seed):
    pl.utilities.seed.seed_everything(seed)

def load_bart(args, sampler):
    """Load a BART model from checkpoint and put it in evaluation mode."""
    model = BARTModel.load_from_checkpoint(args.model_path, decode_sampler=sampler)
    model.eval()
    return model


def load_unified(args, sampler):
    """Load a unified model from checkpoint and put it in evaluation mode."""
    model = UnifiedModel.load_from_checkpoint(args.model_path, decode_sampler=sampler)
    model.eval()
    return model
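
# Illustrative sketch (not executed on import): both loaders expect an object
# with a `model_path` attribute plus a decode sampler instance. The checkpoint
# path below is a placeholder, not a file shipped with the repo.
def _example_load_model(sampler, unified=False):
    args = Namespace(model_path="checkpoints/chemformer.ckpt")  # hypothetical path
    return load_unified(args, sampler) if unified else load_bart(args, sampler)
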
def _clean_string(x, expr_list):
    y = copy.copy(x)
    y = y.replace("''", "&")  # Mark empty SMILES string with dummy character
    for expr in expr_list:
        y = y.replace(expr, "")
    return y
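
# Example: stripping numpy-repr debris from one stringified array element:
#
#     >>> _clean_string("[array([0.95", ["[array([", " "])
#     '0.95'
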
def _convert_to_array(data_list):
    data_new = np.zeros(len(data_list), dtype="object")
    for ix, x in enumerate(data_list):
        data_new[ix] = x
    return data_new
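
# Example: an object array keeps ragged nested lists intact, where a plain
# np.array(...) call could try to broadcast them into a 2-D array:
#
#     >>> _convert_to_array([[1.0, 2.0], [3.0]])
#     array([list([1.0, 2.0]), list([3.0])], dtype=object)
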
def read_score_tsv(
    filename,
    str_to_list_columns,
    is_numeric,
    expr_list1=["'", "[array([", "array([", "[array(", "array(", " ", "\n"],
):
    """
    Read a TSV file generated by the Chemformer.score_model() function.

    Args:
        filename: str, path to the .tsv file
        str_to_list_columns: list(str), columns to convert from string to nested lists
        is_numeric: list(bool), flags denoting which of those columns hold strings
            that should be converted to lists of floats
    """
    sep = ","  # elements inside the stringified arrays are comma-separated
    numeric_expr_list = ["(", ")", "[", "]", "\n"]

    data = pd.read_csv(filename, sep="\t")
    for col, to_float in zip(str_to_list_columns, is_numeric):
        print("Converting string to data of column: " + col)
        data_str = data[col].values

        data_list = []
        for X in data_str:
            X = [x for x in X.split(sep) if "dtype=" not in x]

            inner_list = []
            X_new = []
            is_last_molecule = False
            for x in X:
                x = _clean_string(x, expr_list1)
                if x == "":
                    continue

                # An unbalanced closing ")" or "]" marks the end of one inner
                # list in the stringified (nested) numpy array.
                if x[-1] == ")" and sum([token == "(" for token in x]) < sum([token == ")" for token in x]):
                    x = x[:-1]
                    is_last_molecule = True
                if x[-1] == "]" and sum([token == "[" for token in x]) < sum([token == "]" for token in x]):
                    x = x[:-1]
                    is_last_molecule = True

                inner_list.append(x)

                if is_last_molecule:
                    if to_float:
                        inner_list = [_clean_string(element, numeric_expr_list) for element in inner_list]
                        inner_list = [float(element) if element != "" else np.nan for element in inner_list]
                    X_new.append(inner_list)
                    inner_list = []
                    is_last_molecule = False

            print("Batch size after cleaning (for validating cleaning): " + str(len(X_new)))
            data_list.append(X_new)

        data.drop(columns=[col], inplace=True)
        data_list = _convert_to_array(data_list)
        data[col] = data_list
    return data
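
# Illustrative usage sketch (not executed on import). The column names are
# assumptions about typical Chemformer.score_model() output, not guaranteed
# by this module; the path is a placeholder.
def _example_read_score_tsv():
    return read_score_tsv(
        "scores.tsv",  # hypothetical path
        str_to_list_columns=["sampled_molecules", "log_lhs"],  # assumed columns
        is_numeric=[False, True],
    )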