Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_atom_mapper.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  1. import pytest
  2. from rdkit import Chem
  3. from molbart.retrosynthesis.disconnection_aware.disconnection_atom_mapper import (
  4. DisconnectionAtomMapper,
  5. )
  6. from molbart.retrosynthesis.disconnection_aware.utils import (
  7. verify_disconnection,
  8. )
  9. @pytest.mark.parametrize(
  10. ("reactants_smiles", "expected"),
  11. [
  12. (
  13. "[Cl:2].[CH:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1",
  14. {2: 0, 1: 1, 7: 2, 6: 3, 5: 4, 4: 5, 3: 6},
  15. ),
  16. (
  17. "[Cl:6].[CH:1]1=[CH:17][CH:2]=[CH:5][CH:24]=[CH:3]1",
  18. {6: 0, 1: 1, 17: 2, 2: 3, 5: 4, 24: 5, 3: 6},
  19. ),
  20. ],
  21. )
  22. def test_mapping_to_index(reactants_smiles, expected):
  23. mapper = DisconnectionAtomMapper()
  24. mapping2idx = mapper.mapping_to_index(Chem.MolFromSmiles(reactants_smiles))
  25. assert mapping2idx == expected
  26. def test_remove_atom_mapping():
  27. mapper = DisconnectionAtomMapper()
  28. smiles = "[CH:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1"
  29. assert mapper.remove_atom_mapping(smiles) == "c1ccccc1"
  30. @pytest.mark.parametrize(
  31. ("reactants", "product_new_mapping", "product_old_mapping", "expected"),
  32. [
  33. (
  34. "[Cl:2].[CH:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1",
  35. "[Cl:2][C:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1",
  36. "[Cl:5][C:3]1=[CH:15][CH:1]=[CH:2][CH:7]=[CH:16]1",
  37. "[Cl:5].[cH:1]1[cH:2][cH:7][cH:16][cH:3][cH:15]1",
  38. ),
  39. (
  40. "[CH:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1",
  41. "[Cl:2][C:1]1=[CH:7][CH:6]=[CH:5][CH:4]=[CH:3]1",
  42. "[Cl:5][C:3]1=[CH:15][CH:1]=[CH:7][CH:8]=[CH:16]1",
  43. "[cH:1]1[cH:7][cH:8][cH:16][cH:3][cH:15]1",
  44. ),
  45. ],
  46. )
  47. def test_input_mapping_to_reactants(reactants, product_new_mapping, product_old_mapping, expected):
  48. mapper = DisconnectionAtomMapper()
  49. assert mapper.propagate_input_mapping_to_reactants(product_old_mapping, reactants, product_new_mapping) == expected
  50. @pytest.mark.parametrize(
  51. ("product_mapping", "bond_atom_inds", "expected"),
  52. [
  53. (
  54. "[Cl:5][C:3]1=[CH:15][CH:1]=[CH:2][CH:6]=[CH:16]1",
  55. [1, 15],
  56. "Clc1ccc[cH:1][cH:1]1",
  57. ),
  58. (
  59. "[Cl:5][C:3]1=[CH:15][CH:1]=[CH:2][CH:6]=[CH:16]1",
  60. [5, 3],
  61. "c1cc[c:1]([Cl:1])cc1",
  62. ),
  63. ],
  64. )
  65. def test_tag_current_bond(product_mapping, bond_atom_inds, expected):
  66. mapper = DisconnectionAtomMapper()
  67. assert mapper.tag_current_bond(product_mapping, bond_atom_inds) == expected
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...