Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

round_trip_utils.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  1. from typing import Any, Dict, List, Tuple, Union
  2. import numpy as np
  3. from molbart.models import Chemformer
  4. def compute_round_trip_accuracy(
  5. chemformer: Chemformer,
  6. sampled_smiles: List[np.ndarray],
  7. target_smiles: List[List[str]],
  8. ) -> List[Dict[str, Any]]:
  9. """
  10. Calculating (round-trip) accuracy given sampled and target SMILES (products).
  11. Args:
  12. chemformer: a Chemformer model with a decoder sampler
  13. sampled_smiles: product SMILES sampled by forward Chemformer
  14. target_smiles: ground truth product SMILES
  15. """
  16. print("Evaluating predictions.")
  17. metrics_out = []
  18. batch_idx = 0
  19. for sampled_batch, target_batch in zip(sampled_smiles, target_smiles):
  20. metrics = chemformer.model.sampler.compute_sampling_metrics(
  21. sampled_batch,
  22. target_batch,
  23. is_canonical=False,
  24. )
  25. metrics.update({"sampled_molecules": sampled_batch, "target_smiles": target_batch})
  26. metrics_out.append(metrics)
  27. batch_idx += 1
  28. return metrics_out
  29. def batchify(smiles_lst: Union[List[str], np.ndarray], batch_size: int) -> Union[List[List[str]], List[np.ndarray]]:
  30. """
  31. Create batches given an input list of SMILES or list of list of SMILES.
  32. Args:
  33. smiles_list: list of SMILES
  34. batch_size: number of samples in batch
  35. Returns:
  36. batched SMILES in a list
  37. """
  38. n_samples = len(smiles_lst)
  39. n_batches = int(np.ceil(n_samples / batch_size))
  40. batched_smiles = []
  41. for i_batch in range(n_batches):
  42. if i_batch != n_batches - 1:
  43. batched_smiles.append(smiles_lst[i_batch * batch_size : (i_batch + 1) * batch_size])
  44. else:
  45. batched_smiles.append(smiles_lst[i_batch * batch_size : :])
  46. return batched_smiles
  47. def convert_to_input_format(
  48. sampled_smiles: List[List[str]],
  49. target_smiles: List[List[str]],
  50. sampled_data_params: Dict[str, Any],
  51. n_chunks: int = 1,
  52. ) -> Tuple[List[np.ndarray], List[List[str]]]:
  53. """
  54. Converting sampled data to original input format such that,
  55. sampled_smiles: [n_batches, batch_size, n_beams],
  56. target_smiles: [n_batches, batch_size, 1].
  57. Args:
  58. sampled_smiles: SMILES sampled in round-trip inference
  59. target_smiles: target SMILES (ground truth product)
  60. sampled_data_params: parameters of the input data from backward predictions
  61. (batch_size, beam_size, n_samples)
  62. Returns:
  63. Reshaped round-trip predictions.
  64. """
  65. batch_size = sampled_data_params["batch_size"]
  66. n_beams = sampled_data_params["beam_size"]
  67. n_samples = sampled_data_params["n_samples"]
  68. sampled_smiles = np.array(sampled_smiles)
  69. target_smiles = np.array(target_smiles)
  70. sampled_smiles = np.reshape(sampled_smiles, (-1, n_beams))
  71. target_smiles = np.reshape(target_smiles, (-1, n_beams))
  72. if n_chunks == 1:
  73. assert target_smiles.shape[0] == n_samples
  74. # Sanity-check that target smiles are the same within beams
  75. for tgt_beams in target_smiles:
  76. assert np.all(tgt_beams == tgt_beams[0])
  77. # Extract the target smiles for each original sample
  78. target_smiles = [tgt_smi[0] for tgt_smi in target_smiles]
  79. smpl_smiles_reform = batchify(sampled_smiles, batch_size)
  80. tgt_smiles_reform = batchify(target_smiles, batch_size)
  81. return smpl_smiles_reform, tgt_smiles_reform
  82. def set_output_files(args, chemformer):
  83. if args.output_score_data and args.output_sampled_smiles:
  84. for callback in chemformer.trainer.callbacks:
  85. if hasattr(callback, "set_output_files"):
  86. callback.set_output_files(args.output_score_data, args.output_sampled_smiles)
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...