Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_inference.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  1. import logging
  2. from contextlib import contextmanager
  3. from rdkit import Chem
  4. from hydra import compose, initialize
  5. from nemo_chem.models.megamolbart import NeMoMegaMolBARTWrapper
  6. log = logging.getLogger(__name__)
  7. _INFERER = None
  8. @contextmanager
  9. def load_model(inf_cfg):
  10. global _INFERER
  11. if _INFERER is None:
  12. _INFERER = NeMoMegaMolBARTWrapper(model_cfg=inf_cfg)
  13. yield _INFERER
  14. def test_smis_to_hiddens():
  15. with initialize(config_path="../examples/chem/conf"):
  16. cfg = compose(config_name="infer")
  17. with load_model(cfg) as inferer:
  18. smis = ['c1cc2ccccc2cc1',
  19. 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
  20. 'CC(=O)C(=O)N1CCC([C@H]2CCCCN2C(=O)c2ccc3c(n2)CCN(C(=O)OC(C)(C)C)C3)CC1']
  21. hidden_state, pad_masks = inferer.smis_to_hidden(smis)
  22. assert hidden_state is not None
  23. assert hidden_state.shape[0] == len(smis)
  24. assert hidden_state.shape[2] == inferer.cfg.max_position_embeddings
  25. assert pad_masks is not None
  26. def test_smis_to_embedding():
  27. with initialize(config_path="../examples/chem/conf"):
  28. cfg = compose(config_name="infer")
  29. with load_model(cfg) as inferer:
  30. smis = ['c1cc2ccccc2cc1',
  31. 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
  32. 'CC(=O)C(=O)N1CCC([C@H]2CCCCN2C(=O)c2ccc3c(n2)CCN(C(=O)OC(C)(C)C)C3)CC1']
  33. embedding = inferer.smis_to_embedding(smis)
  34. assert embedding is not None
  35. assert embedding.shape[0] == len(smis)
  36. assert embedding.shape[1] == inferer.cfg.max_position_embeddings
  37. def test_hidden_to_smis():
  38. with initialize(config_path="../examples/chem/conf"):
  39. cfg = compose(config_name="infer")
  40. with load_model(cfg) as inferer:
  41. smis = ['c1cc2ccccc2cc1',
  42. 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
  43. 'CC(=O)C(=O)N1CCC([C@H]2CCCCN2C(=O)c2ccc3c(n2)CCN(C(=O)OC(C)(C)C)C3)CC1']
  44. hidden_state, pad_masks = inferer.smis_to_hidden(smis)
  45. infered_smis = inferer.hidden_to_smis(hidden_state, pad_masks)
  46. log.info(f'Input SMILES and Infered: {smis}, {infered_smis}')
  47. assert(len(infered_smis) == len(smis))
  48. for smi, infered_smi in zip(smis, infered_smis):
  49. log.info(f'Input and Infered:{smi}, {infered_smi}')
  50. input_mol = Chem.MolFromSmiles(smi)
  51. infer_mol = Chem.MolFromSmiles(infered_smi)
  52. assert input_mol is not None and infer_mol is not None
  53. canonical_smi = Chem.MolToSmiles(input_mol, canonical=True)
  54. canonical_infered_smi = Chem.MolToSmiles(infer_mol, canonical=True)
  55. log.info(f'Canonical Input and Infered: {canonical_smi}, {canonical_infered_smi}')
  56. assert(canonical_smi == canonical_infered_smi)
  57. def test_sample():
  58. with initialize(config_path="../examples/chem/conf"):
  59. cfg = compose(config_name="infer")
  60. with load_model(cfg) as inferer:
  61. smis = ['c1cc2ccccc2cc1',
  62. 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
  63. 'CC(=O)C(=O)N1CCC([C@H]2CCCCN2C(=O)c2ccc3c(n2)CCN(C(=O)OC(C)(C)C)C3)CC1']
  64. samples = inferer.sample(smis, num_samples=10, sampling_method='greedy-perturbate')
  65. samples = set(samples)
  66. log.info('\n'.join(smis))
  67. log.info('\n'.join(samples))
  68. valid_molecules = []
  69. for smi in set(samples):
  70. isvalid = False
  71. mol = Chem.MolFromSmiles(smi)
  72. if mol:
  73. isvalid = True
  74. valid_molecules.append(smi)
  75. log.info(f'Sample: {smi}, {isvalid}')
  76. log.info('Valid Molecules' + "\n".join(valid_molecules))
  77. log.info(f'Total samples = {len(samples)} unique samples {len(set(samples))} valids {len(valid_molecules)}')
  78. if len(valid_molecules) < len(samples) * 0.3:
  79. log.warning("TOO FEW VALID SAMPLES")
  80. assert len(valid_molecules) != 0
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...