Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

predict.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import hydra
import pandas as pd

import molbart.utils.data_utils as util
from molbart.models import Chemformer
  5. def write_predictions(args, smiles, log_lhs, original_smiles):
  6. num_data = len(smiles)
  7. beam_width = len(smiles[0])
  8. beam_outputs = [[[]] * num_data for _ in range(beam_width)]
  9. beam_log_lhs = [[[]] * num_data for _ in range(beam_width)]
  10. for b_idx, (smiles_beams, log_lhs_beams) in enumerate(zip(smiles, log_lhs)):
  11. for beam_idx, (smi, log_lhs) in enumerate(zip(smiles_beams, log_lhs_beams)):
  12. beam_outputs[beam_idx][b_idx] = smi
  13. beam_log_lhs[beam_idx][b_idx] = log_lhs
  14. df_data = {"target_smiles": original_smiles}
  15. for beam_idx in range(beam_width):
  16. df_data["sampled_smiles_" + str(beam_idx + 1)] = beam_outputs[beam_idx]
  17. for beam_idx in range(beam_width):
  18. df_data["loglikelihood_" + str(beam_idx + 1)] = beam_log_lhs[beam_idx]
  19. df = pd.DataFrame(data=df_data)
  20. df.to_csv(args.output_sampled_smiles, sep="\t", index=False)
  21. @hydra.main(version_base=None, config_path="config", config_name="predict")
  22. def main(args):
  23. chemformer = Chemformer(args)
  24. print("Making predictions...")
  25. smiles, log_lhs, original_smiles = chemformer.predict(
  26. dataset=args.dataset_part,
  27. )
  28. write_predictions(args, smiles, log_lhs, original_smiles)
  29. print("Finished predictions.")
  30. return
  31. if __name__ == "__main__":
  32. main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...