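"""Shared pytest fixtures for the Chemformer (molbart) test suite: tokenizer and
masker construction, round-trip inference data, and a small CPU model/batch setup."""
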
from argparse import Namespace

import numpy as np
import omegaconf as oc
import pandas as pd
import pytest

import molbart.utils.data_utils as util
from molbart.data import SynthesisDataModule
from molbart.models import Chemformer
from molbart.utils.tokenizers import ChemformerTokenizer, SpanTokensMasker


@pytest.fixture
def example_tokens():
    return [
        ["^", "C", "(", "=", "O", ")", "unknown", "&"],
        ["^", "C", "C", "<SEP>", "C", "Br", "&"],
    ]


@pytest.fixture
def regex_tokens():
    # Split the combined SMILES regex into its individual token patterns.
    regex = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    return regex.split("|")


@pytest.fixture
def smiles_data():
    return ["CCO.Ccc", "CCClCCl", "C(=O)CBr"]


@pytest.fixture
def mock_random_choice(mocker):
    class ToggleBool:
        """Replacement for random.choices that returns k booleans alternating
        between True and False, making token masking deterministic in tests."""

        def __init__(self):
            self.state = True

        def __call__(self, *args, **kwargs):
            states = []
            for _ in range(kwargs["k"]):
                states.append(self.state)
                self.state = not self.state
            return states

    mocker.patch("molbart.utils.tokenizers.tokenizers.random.choices", side_effect=ToggleBool())


@pytest.fixture
def setup_tokenizer(regex_tokens, smiles_data):
    def wrapper(tokens=None):
        return ChemformerTokenizer(smiles=smiles_data, tokens=tokens, regex_token_patterns=regex_tokens)

    return wrapper


@pytest.fixture
def setup_masker(setup_tokenizer):
    def wrapper(cls=SpanTokensMasker):
        tokenizer = setup_tokenizer()
        return tokenizer, cls(tokenizer)

    return wrapper


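# Hypothetical usage sketch (not part of the fixtures above): a test requests
# the factory fixture and calls it to obtain a fresh tokenizer/masker pair.
# SpanTokensMasker is the default; other masker classes can be injected via
# the `cls` argument.
#
# def test_masker_smoke(setup_masker):
#     tokenizer, masker = setup_masker()
#     assert masker is not None

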
@pytest.fixture
def round_trip_params(shared_datadir):
    params = {
        "n_samples": 3,
        "beam_size": 5,
        "batch_size": 3,
        "round_trip_input_data": shared_datadir / "round_trip_input_data.csv",
    }
    return params


@pytest.fixture
def round_trip_namespace_args(shared_datadir):
    args = Namespace()
    args.input_data = shared_datadir / "example_data_uspto.csv"
    args.backward_predictions = shared_datadir / "example_data_backward_sampled_smiles_uspto50k.json"
    args.output_score_data = "temp_metrics.csv"
    args.dataset_part = "test"
    args.working_directory = "tests"
    args.target_column = "products"
    return args


@pytest.fixture
def round_trip_raw_prediction_data(shared_datadir):
    round_trip_df = pd.read_json(shared_datadir / "round_trip_predictions_raw.json", orient="table")
    round_trip_predictions = [np.array(smiles_lst) for smiles_lst in round_trip_df["round_trip_smiles"].values]
    data = {
        "sampled_smiles": round_trip_predictions,
        "target_smiles": round_trip_df["target_smiles"].values,
    }
    return data


@pytest.fixture
def round_trip_converted_prediction_data(shared_datadir):
    round_trip_df = pd.read_json(shared_datadir / "round_trip_predictions_converted.json", orient="table")
    round_trip_predictions = [np.array(smiles_lst) for smiles_lst in round_trip_df["round_trip_smiles"].values]
    data = {
        "sampled_smiles": round_trip_predictions,
        "target_smiles": round_trip_df["target_smiles"].values,
    }
    return data


@pytest.fixture
def model_batch_setup(round_trip_namespace_args):
    config = oc.OmegaConf.load("molbart/config/round_trip_inference.yaml")
    data = pd.read_csv(round_trip_namespace_args.input_data, sep="\t")

    # Shrink the model and force CPU so the fixture builds quickly in CI.
    config.d_model = 4
    config.batch_size = 3
    config.n_beams = 3
    config.n_layers = 1
    config.n_heads = 2
    config.d_feedforward = 2
    config.task = "forward_prediction"
    config.datamodule = None
    config.vocabulary_path = "bart_vocab_downstream.json"
    config.n_gpus = 0
    config.device = "cpu"
    config.data_device = "cpu"

    chemformer = Chemformer(config)

    datamodule = SynthesisDataModule(
        reactants=data["reactants"].values,
        products=data["products"].values,
        dataset_path="",
        tokenizer=chemformer.tokenizer,
        batch_size=config.batch_size,
        max_seq_len=util.DEFAULT_MAX_SEQ_LEN,
        reverse=False,
    )
    datamodule.setup()

    # Grab the first batch from the full dataloader for batch-level tests.
    dataloader = datamodule.full_dataloader()
    batch_idx, batch_input = next(enumerate(dataloader))

    output_data = {
        "chemformer": chemformer,
        "tokenizer": chemformer.tokenizer,
        "batch_idx": batch_idx,
        "batch_input": batch_input,
        "max_seq_len": util.DEFAULT_MAX_SEQ_LEN,
    }
    return output_data


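# Hypothetical usage sketch (not part of the fixtures above): a round-trip
# test pulls what it needs from the returned dictionary. Because the batch
# comes from next(enumerate(...)), batch_idx is always 0 here.
#
# def test_round_trip_batch(model_batch_setup):
#     chemformer = model_batch_setup["chemformer"]
#     batch_input = model_batch_setup["batch_input"]
#     assert model_batch_setup["batch_idx"] == 0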