Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

service_utils.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. from typing import List, Tuple
  2. import numpy as np
  3. from pysmilesutils.augment import SMILESAugmenter
  4. import molbart.utils.data_utils as util
  5. from molbart.models import Chemformer
  6. from molbart.data import SynthesisDataModule
  7. def calculate_llhs(chemformer: Chemformer, reactants: List[str], products: List[str]) -> List[float]:
  8. """
  9. Calculate log-likelihood of reactant/product pairs.
  10. """
  11. datamodule = SynthesisDataModule(
  12. reactants=reactants,
  13. products=products,
  14. dataset_path="",
  15. tokenizer=chemformer.tokenizer,
  16. batch_size=chemformer.config.batch_size,
  17. max_seq_len=util.DEFAULT_MAX_SEQ_LEN,
  18. augment_prob=False,
  19. reverse=not chemformer.config.forward_prediction,
  20. )
  21. datamodule.setup()
  22. llhs = chemformer.log_likelihood(dataloader=datamodule.full_dataloader())
  23. return llhs
  24. def estimate_compound_llhs(
  25. chemformer: Chemformer,
  26. reactants: List[str],
  27. products: List[str],
  28. n_augments: int = 10,
  29. ) -> Tuple[np.ndarray, np.ndarray]:
  30. """
  31. Use SMILES augmentation to generate multiple SMILES representations of a
  32. compound and compute the log-likelihood of each SMILES.
  33. Returns the maximum log-likelihood.
  34. """
  35. augmenter = SMILESAugmenter()
  36. base_log_likelihoods = np.array(calculate_llhs(chemformer, reactants, products))
  37. all_llhs = []
  38. all_llhs.append(base_log_likelihoods[:, np.newaxis])
  39. for _ in range(n_augments - 1):
  40. if chemformer.data_args.forward_prediction:
  41. this_products = augmenter(products)
  42. this_reactants = reactants
  43. else:
  44. this_products = products
  45. this_reactants = augmenter(reactants)
  46. aug_log_likelihoods = np.array(calculate_llhs(chemformer, this_reactants, this_products))
  47. all_llhs.append(aug_log_likelihoods[:, np.newaxis])
  48. best_log_likelihoods = np.concatenate(all_llhs, axis=1)
  49. best_log_likelihoods = np.max(best_log_likelihoods, axis=1)
  50. return best_log_likelihoods
  51. def get_predictions(
  52. chemformer: Chemformer, smiles_list: List[str], n_beams: int = 10
  53. ) -> Tuple[List[List[str]], List[List[float]], List[str]]:
  54. """
  55. Predict with Chemformer on input smiles_list.
  56. """
  57. # Setting both reactants and products to smiles_list since we do
  58. # not have the "ground truth" data.
  59. datamodule = SynthesisDataModule(
  60. reactants=smiles_list,
  61. products=smiles_list,
  62. tokenizer=chemformer.tokenizer,
  63. batch_size=chemformer.config.batch_size,
  64. max_seq_len=util.DEFAULT_MAX_SEQ_LEN,
  65. dataset_path=""
  66. )
  67. datamodule.setup()
  68. chemformer.model.n_unique_beams = n_beams
  69. chemformer.model.num_beams = n_beams
  70. smiles, log_lhs, original_smiles = chemformer.predict(dataloader=datamodule.full_dataloader())
  71. return smiles, log_lhs, original_smiles
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...