Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

round_trip_inference.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. """Module for running round-trip inference and accuracy scoring of backward predictions
  2. using a forward Chemformer model"""
  3. import subprocess
  4. import tempfile
  5. from argparse import Namespace
  6. from typing import Any, Dict, List, Tuple
  7. import hydra
  8. import pandas as pd
  9. import pytorch_lightning as pl
  10. import molbart.utils.data_utils as util
  11. from molbart.models import Chemformer
  12. from molbart.retrosynthesis import round_trip_utils as rt_utils
  13. def create_round_trip_dataset(args: Namespace) -> Tuple[Namespace, Dict[str, Any]]:
  14. """
  15. Reading sampled smiles and creating dataframe on synthesis-datamodule format.
  16. Args:
  17. args: Input arguments with parameters for Chemformer, data paths etc.
  18. Returns:
  19. updated arguments and input-data metadata dictionary
  20. """
  21. print("Creating input data from sampled predictions.")
  22. _, round_trip_input_data = tempfile.mkstemp(suffix=".csv")
  23. input_data = pd.read_csv(args.input_data, sep="\t")
  24. input_data = input_data.iloc[input_data["set"].values == args.dataset_part]
  25. target_column = args.target_column
  26. input_targets = input_data[target_column].values
  27. predicted_data = pd.read_json(args.backward_predictions, orient="table")
  28. batch_size = len(predicted_data["sampled_molecules"].values[0])
  29. n_samples = sum([len(batch_smiles) for batch_smiles in predicted_data["sampled_molecules"].values])
  30. n_beams = len(predicted_data["sampled_molecules"].values[0][0])
  31. sampled_data_params = {
  32. "n_samples": n_samples,
  33. "beam_size": n_beams,
  34. "batch_size": batch_size,
  35. "round_trip_input_data": round_trip_input_data,
  36. }
  37. counter = 0
  38. sampled_smiles = []
  39. target_smiles = []
  40. # Unravel predictions
  41. for batch_smiles in predicted_data["sampled_molecules"].values:
  42. for top_n_smiles in batch_smiles:
  43. sampled_smiles.extend(top_n_smiles)
  44. target_smiles.extend([input_targets[counter] for _ in range(n_beams)])
  45. counter += 1
  46. input_data = pd.DataFrame(
  47. {
  48. "reactants": sampled_smiles,
  49. "products": target_smiles,
  50. "set": len(target_smiles) * ["test"],
  51. }
  52. )
  53. print(f"Writing data to temporary file: {round_trip_input_data}")
  54. input_data.to_csv(round_trip_input_data, sep="\t", index=False)
  55. args.data_path = round_trip_input_data
  56. return args, sampled_data_params
  57. def _run_test_callbacks(chemformer: Chemformer, metrics_scores: List[Dict[str, Any]]) -> None:
  58. """Run callback.on_test_batch_end on all (scoring) callbacks."""
  59. for batch_idx, scores in enumerate(metrics_scores):
  60. for callback in chemformer.trainer.callbacks:
  61. if not isinstance(callback, pl.callbacks.progress.ProgressBar):
  62. callback.on_test_batch_end(chemformer.trainer, chemformer.model, scores, {}, batch_idx, 0)
  63. @hydra.main(version_base=None, config_path="../config", config_name="round_trip_inference")
  64. def main(args) -> None:
  65. util.seed_everything(args.seed)
  66. args, sampled_data_params = create_round_trip_dataset(args)
  67. chemformer = Chemformer(args)
  68. rt_utils.set_output_files(args, chemformer)
  69. print("Running round-trip inference.")
  70. sampled_smiles, log_lhs, target_smiles = chemformer.predict()
  71. # Reformat on original shape [n_batches, batch_size, n_beams]
  72. sampled_smiles, target_smiles = rt_utils.convert_to_input_format(
  73. sampled_smiles, target_smiles, sampled_data_params, args.n_chunks
  74. )
  75. metrics = rt_utils.compute_round_trip_accuracy(chemformer, sampled_smiles, target_smiles)
  76. _run_test_callbacks(chemformer, metrics)
  77. print(f"Removing temporary file: {sampled_data_params['round_trip_input_data']}")
  78. subprocess.check_output(["rm", sampled_data_params["round_trip_input_data"]])
  79. print("Round-trip inference done!")
  80. return
  81. if __name__ == "__main__":
  82. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...