Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

disconnection_atom_mapper.py 6.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
  1. """Module containing atom-mapping functionality needed to run disconnection-Chemformer"""
  2. import numpy as np
  3. from rdkit import Chem
  4. RXN_MAPPER_ENV_OK = True
  5. try:
  6. from rxnmapper import RXNMapper
  7. except ImportError:
  8. RXN_MAPPER_ENV_OK = False
  9. from typing import Dict, List, Sequence, Tuple
  10. class DisconnectionAtomMapper:
  11. """Class for handling atom-mapping routines of multi-step disconnection-Chemformer"""
  12. def __init__(self):
  13. if RXN_MAPPER_ENV_OK:
  14. self.rxn_mapper = RXNMapper()
  15. def mapping_to_index(self, mol: Chem.rdchem.Mol) -> Dict[int, int]:
  16. """
  17. Atom-map-num to index mapping.
  18. Args:
  19. mol: rdkit Molecule
  20. Returns
  21. a dictionary which maps atom-map-number to atom-index"""
  22. mapping = {atom.GetAtomMapNum(): atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum()}
  23. return mapping
  24. def predictions_atom_mapping(
  25. self, smiles_list: List[str], predicted_smiles_list: List[np.ndarray]
  26. ) -> Tuple[np.ndarray, np.ndarray]:
  27. """
  28. Create atom-mapping on the predicted reactions using RXN-mapper.
  29. Requires RXNMapper to be installed.
  30. Args:
  31. rxn_mapper: RXN-mapper model
  32. smiles_list: batch of input product SMILES to predict atom-mapping on
  33. predicted_smiles_list: batch of predicted reactant SMILES
  34. Returns:
  35. the atom-mapped reactions and the mapping confidence
  36. """
  37. if not RXN_MAPPER_ENV_OK:
  38. raise ImportError("rxnmapper has to be installed in the environment!")
  39. rxn_smiles_list = []
  40. for product_smiles_mapped, reactants_smiles in zip(smiles_list, predicted_smiles_list):
  41. product_smiles = self.remove_atom_mapping(product_smiles_mapped)
  42. rxn_smiles_list.extend(self._reaction_smiles_lst(product_smiles, reactants_smiles))
  43. mapped_rxns = self.rxn_mapper.get_attention_guided_atom_maps(rxn_smiles_list, canonicalize_rxns=False)
  44. atom_map_confidences = np.array([rxnmapper_output["confidence"] for rxnmapper_output in mapped_rxns])
  45. mapped_rxns = np.array([rxnmapper_output["mapped_rxn"] for rxnmapper_output in mapped_rxns])
  46. return mapped_rxns, atom_map_confidences
  47. def propagate_input_mapping_to_reactants(
  48. self,
  49. product_input_mapping: str,
  50. predicted_reactants: str,
  51. product_new_mapping: str,
  52. ) -> str:
  53. """
  54. Propagate old atom-mapping to reactants using the new atom-mapping.
  55. Args:
  56. product_input_mapping: input product.
  57. predicted_reactants: predicted_reactants without atom-mapping.
  58. product_new_mapping: product with new mapping from rxn-mapper.
  59. Returns:
  60. reactant SMILES with the same atom-mapping as the input product.
  61. """
  62. product_input_mapping = self._canonicalize_mapped(product_input_mapping)
  63. product_new_mapping = self._canonicalize_mapped(product_new_mapping)
  64. mol_input_mapping = Chem.MolFromSmiles(product_input_mapping)
  65. mol_new_mapping = Chem.MolFromSmiles(product_new_mapping)
  66. reactants_mol = Chem.MolFromSmiles(predicted_reactants)
  67. reactants_map_to_index = self.mapping_to_index(reactants_mol)
  68. predicted_reactants = self.remove_atom_mapping(predicted_reactants, canonical=False)
  69. reactants_mol = Chem.MolFromSmiles(predicted_reactants)
  70. for atom_idx, atom_input in enumerate(mol_input_mapping.GetAtoms()):
  71. atom_new_mapping = mol_new_mapping.GetAtomWithIdx(atom_idx)
  72. atom_map_num_input = atom_input.GetAtomMapNum()
  73. atom_map_num_new_mapping = atom_new_mapping.GetAtomMapNum()
  74. try:
  75. atom_reactant = reactants_mol.GetAtomWithIdx(reactants_map_to_index[atom_map_num_new_mapping])
  76. atom_reactant.SetAtomMapNum(atom_map_num_input)
  77. except KeyError:
  78. continue
  79. return Chem.MolToSmiles(reactants_mol)
  80. def remove_atom_mapping(self, smiles_mapped: str, canonical: bool = True) -> str:
  81. """
  82. Remove atom-mapping from SMILES.
  83. Args:
  84. smiles_mapped: SMILES with atom-mapping
  85. canonical: whether to canonicalize the output SMILES
  86. Returns:
  87. SMILES without atom-mapping
  88. """
  89. mol = Chem.MolFromSmiles(smiles_mapped)
  90. for atom in mol.GetAtoms():
  91. atom.SetAtomMapNum(0)
  92. return Chem.MolToSmiles(mol, canonical=canonical)
  93. def tag_current_bond(self, product_smiles: str, bond_inds: Sequence[int]) -> str:
  94. """
  95. Remove atom-tagging on all atoms except those in bonds_inds.
  96. Tag bond_inds atoms as [<atom>:1] where <atom> is any atom.
  97. Args:
  98. mol: (product) SMILES with atom-mapping
  99. bond_inds: atom indices involved in current bond to break
  100. Returns:
  101. atom-map tagged SMILES
  102. """
  103. mol = Chem.MolFromSmiles(product_smiles)
  104. for atom in mol.GetAtoms():
  105. if atom.GetAtomMapNum() in bond_inds:
  106. atom.SetAtomMapNum(1)
  107. else:
  108. atom.SetAtomMapNum(0)
  109. return Chem.MolToSmiles(mol)
  110. def _canonicalize_mapped(self, smiles_mapped: str) -> str:
  111. smiles = self.remove_atom_mapping(smiles_mapped, canonical=False)
  112. mol_mapped = Chem.MolFromSmiles(smiles_mapped)
  113. mol_unmapped = Chem.MolFromSmiles(smiles)
  114. _, canonical_atom_order = tuple(
  115. zip(*sorted([(j, i) for i, j in enumerate(Chem.CanonicalRankAtoms(mol_unmapped))]))
  116. )
  117. mol_mapped = Chem.RenumberAtoms(mol_mapped, canonical_atom_order)
  118. return Chem.MolToSmiles(mol_mapped, canonical=False)
  119. def _reaction_smiles_lst(self, product_smiles: str, reactants_smiles: np.ndarray) -> List[str]:
  120. """
  121. Construct the reaction smiles given product and reactant SMILES.
  122. Args:
  123. product_smiles: input product SMILES
  124. reactants_smiles: list of predicted reactant SMILES
  125. Returns:
  126. list of reaction SMILES
  127. """
  128. rxn_smiles = [f"{reactants}>>{product_smiles}" for reactants in reactants_smiles]
  129. return rxn_smiles
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...