Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

chemformer_disconnect_service.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. import os
  2. from typing import Dict, List
  3. import numpy as np
  4. import omegaconf as oc
  5. from fastapi import FastAPI
  6. from service_utils import get_predictions
  7. from molbart.models import Chemformer
  8. from molbart.retrosynthesis.disconnection_aware import utils
  9. from molbart.retrosynthesis.disconnection_aware.disconnection_atom_mapper import (
  10. DisconnectionAtomMapper,
  11. )
  12. app = FastAPI()
  13. # Container for data, classes that can be loaded upon startup of the REST API
  14. config = oc.OmegaConf.load("../molbart/config/predict.yaml")
  15. config.batch_size = 64
  16. config.n_gpus = 1
  17. config.model_path = os.environ["CHEMFORMER_DISCONNECTION_MODEL"]
  18. config.model_type = "bart"
  19. config.n_beams = 10
  20. config.n_unique_beams = 10 # Make sure we output unique predictions
  21. config.task = os.environ["CHEMFORMER_TASK"]
  22. config.vocabulary_path = os.environ["CHEMFORMER_VOCAB"]
  23. config.datamodule = None
  24. CONDA_PATH = None
  25. RXNUTILS_ENV_PATH = None
  26. if "CONDA_PATH" in os.environ:
  27. CONDA_PATH = os.environ["CONDA_PATH"]
  28. if "RXNUTILS_ENV_PATH" in os.environ:
  29. RXNUTILS_ENV_PATH = os.environ["RXNUTILS_ENV_PATH"]
  30. MODELS = {
  31. "chemformer_disconnect": Chemformer(config),
  32. "atom_mapper": DisconnectionAtomMapper(),
  33. }
  34. def _get_n_predictions(predicted_reactants: List[List[str]]):
  35. return [len(smiles_list) for smiles_list in predicted_reactants]
  36. def _reshape(smiles_list: List[str], n_predictions: List[int]):
  37. reshaped_smiles_list = []
  38. counter = 0
  39. for n_pred in n_predictions:
  40. all_predictions = [smiles for smiles in smiles_list[counter : counter + n_pred]]
  41. counter += n_pred
  42. reshaped_smiles_list.append(all_predictions)
  43. return reshaped_smiles_list
  44. @app.post("/chemformer-disconnect-api/predict-disconnection")
  45. def predict_disconnection(smiles_list: List[str], bonds_list: List[List[int]], n_beams: int = 10) -> List[Dict]:
  46. """
  47. Make prediction with disconnection-Chemformer given list of input SMILES and
  48. corresponding list of bonds to break [one bond per input SMILES].
  49. Returns the basic predictions and input product (with new atom-mapping)
  50. for each bond in each product. Tailored to the multi-step disconnection
  51. approach in aizynthfinder.
  52. Args:
  53. smiles_list: batch of input SMILES to model
  54. bonds: list of bonds to break for each input SMILES (one bond per molecule)
  55. n_beams: number of beams in beam search
  56. """
  57. # Get input SMILES to the prediction and tag SMILES using the corresponding bonds
  58. # for that input.
  59. smiles_atom_map_tagged = [
  60. MODELS["atom_mapper"].tag_current_bond(smiles, bond_atom_inds)
  61. for smiles, bond_atom_inds in zip(smiles_list, bonds_list)
  62. ]
  63. smiles_tagged_list = utils.get_model_input(
  64. smiles_atom_map_tagged,
  65. rxnutils_env_path=RXNUTILS_ENV_PATH,
  66. conda_path=CONDA_PATH,
  67. )
  68. output = []
  69. predicted_smiles, log_lhs, _ = get_predictions(MODELS["chemformer_disconnect"], smiles_tagged_list, n_beams)
  70. n_predictions = _get_n_predictions(predicted_smiles)
  71. # Get atom-mapping of predicted reaction
  72. mapped_rxns, _ = MODELS["atom_mapper"].predictions_atom_mapping(smiles_list, predicted_smiles)
  73. reactants_mapped = np.array([mapped_rxn.split(">")[0] for mapped_rxn in mapped_rxns])
  74. product_new_mapping = np.array([mapped_rxn.split(">")[-1] for mapped_rxn in mapped_rxns])
  75. output = []
  76. for item_pred, item_lhs, item_smiles, item_mapped_product, item_bond in zip(
  77. _reshape(reactants_mapped, n_predictions),
  78. log_lhs,
  79. smiles_list,
  80. _reshape(product_new_mapping, n_predictions),
  81. bonds_list,
  82. ):
  83. output.append(
  84. {
  85. "input": item_smiles,
  86. "output": list(item_pred),
  87. "lhs": [float(val) for val in item_lhs],
  88. "product_new_mapping": list(item_mapped_product),
  89. "current_bond": item_bond,
  90. }
  91. )
  92. return output
  93. if __name__ == "__main__":
  94. import uvicorn
  95. uvicorn.run(
  96. "chemformer_disconnect_service:app",
  97. host="0.0.0.0",
  98. port=8023,
  99. log_level="info",
  100. reload=False,
  101. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...