Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

chemformer_service.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. import os
  2. from typing import List
  3. import omegaconf as oc
  4. from fastapi import FastAPI
  5. from service_utils import calculate_llhs, estimate_compound_llhs, get_predictions
  6. from molbart.models import Chemformer
  app = FastAPI()

  # Resolve the prediction config relative to this file instead of the current
  # working directory, so the service can be started from any directory. When
  # launched from this file's own directory this is the same file as the
  # previous CWD-relative path "../molbart/config/predict.yaml".
  _CONFIG_PATH = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "..", "molbart", "config", "predict.yaml"
  )

  # Base prediction config, overridden with this service's settings.
  # CHEMFORMER_MODEL, CHEMFORMER_TASK and CHEMFORMER_VOCAB must be set in the
  # environment; a missing one raises KeyError when this module is imported.
  config = oc.OmegaConf.load(_CONFIG_PATH)
  config.batch_size = 64
  config.n_gpus = 1
  config.model_path = os.environ["CHEMFORMER_MODEL"]
  config.model_type = "bart"
  config.n_beams = 10
  config.task = os.environ["CHEMFORMER_TASK"]
  config.vocabulary_path = os.environ["CHEMFORMER_VOCAB"]
  config.datamodule = None

  # Container for objects loaded once at startup of the REST API and shared by
  # every request handler below (currently only the Chemformer model).
  global_items = {"chemformer": Chemformer(config)}
  @app.post("/chemformer-api/predict")
  def predict(smiles_list: List[str], n_beams: int = 10):
      """Run Chemformer predictions for a list of SMILES strings.

      Args:
          smiles_list: SMILES strings to predict from.
          n_beams: beam count forwarded to ``get_predictions``.

      Returns:
          A list with one dict per item returned by ``get_predictions``:
          ``"input"`` is the original SMILES, ``"output"`` the predicted SMILES
          as a list, and ``"lhs"`` the associated scores (presumably
          log-likelihoods) converted to plain floats for JSON serialization.
      """
      model = global_items["chemformer"]
      predicted, likelihoods, sources = get_predictions(model, smiles_list, n_beams)
      return [
          {
              "input": source,
              "output": list(candidates),
              "lhs": list(map(float, scores)),
          }
          for candidates, scores, source in zip(predicted, likelihoods, sources)
      ]
  @app.post("/chemformer-api/log_likelihood")
  def log_likelihood(reactants: List[str], products: List[str]):
      """Compute a log-likelihood for each (reactants, product) pair.

      Args:
          reactants: reactant SMILES, index-aligned with ``products``.
          products: product SMILES, index-aligned with ``reactants``.

      Returns:
          A list of dicts with ``"product_smiles"``, ``"reactant_smiles"`` and
          ``"log_likelihood"`` (a plain float) for each pair.

      Raises:
          HTTPException: status 422 if ``reactants`` and ``products`` have
              different lengths.
      """
      from fastapi import HTTPException

      # zip() below stops at the shortest input, so a length mismatch would
      # otherwise silently return fewer results than requested. Reject it
      # before spending model time on the request.
      if len(reactants) != len(products):
          raise HTTPException(
              status_code=422,
              detail=(
                  f"reactants ({len(reactants)}) and products ({len(products)}) "
                  "must have the same length"
              ),
          )
      log_lhs = calculate_llhs(global_items["chemformer"], reactants, products)
      return [
          {
              "product_smiles": str(prod_smi),
              "reactant_smiles": str(react_smi),
              "log_likelihood": float(llhs),
          }
          for prod_smi, react_smi, llhs in zip(products, reactants, log_lhs)
      ]
  @app.post("/chemformer-api/compound_log_likelihood")
  def compound_log_likelihood(reactants: List[str], products: List[str], n_augments: int = 10):
      """Estimate a compound log-likelihood for each (reactants, product) pair.

      Args:
          reactants: reactant SMILES, index-aligned with ``products``.
          products: product SMILES, index-aligned with ``reactants``.
          n_augments: forwarded to ``estimate_compound_llhs`` as
              ``n_augments``.

      Returns:
          A list of dicts with ``"product_smiles"``, ``"reactant_smiles"`` and
          ``"log_likelihood"`` (a plain float) for each pair.

      Raises:
          HTTPException: status 422 if ``reactants`` and ``products`` have
              different lengths.
      """
      from fastapi import HTTPException

      # zip() below stops at the shortest input, so a length mismatch would
      # otherwise silently return fewer results than requested. Reject it
      # before spending model time on the request.
      if len(reactants) != len(products):
          raise HTTPException(
              status_code=422,
              detail=(
                  f"reactants ({len(reactants)}) and products ({len(products)}) "
                  "must have the same length"
              ),
          )
      log_lhs = estimate_compound_llhs(
          global_items["chemformer"], reactants, products, n_augments=n_augments
      )
      return [
          {
              "product_smiles": str(prod_smi),
              "reactant_smiles": str(react_smi),
              "log_likelihood": float(llhs),
          }
          for prod_smi, react_smi, llhs in zip(products, reactants, log_lhs)
      ]
  if __name__ == "__main__":
      # uvicorn is only needed when this file is executed directly as a script.
      import uvicorn

      server_options = {
          "host": "0.0.0.0",
          "port": 8003,
          "log_level": "info",
          "reload": False,
      }
      # The app is referenced by its "module:attribute" import string.
      uvicorn.run("chemformer_service:app", **server_options)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...