Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

score_collection.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
  1. """ Module containing classes used to score the reaction routes.
  2. """
  3. from __future__ import annotations
  4. import logging
  5. from typing import TYPE_CHECKING
  6. import yaml
  7. from omegaconf import ListConfig, OmegaConf
  8. from molbart.utils.base_collection import BaseCollection
  9. from molbart.utils.scores import BaseScore
  10. from molbart.utils.scores.scores import __name__ as score_module
  11. if TYPE_CHECKING:
  12. from typing import Any, Dict, List, Optional
  13. class ScoreCollection(BaseCollection):
  14. """
  15. Store score objects for the chemformer model.
  16. The scores can be obtained by name
  17. .. code-block::
  18. scores = ScoreCollection()
  19. score = scores['TopKAccuracy']
  20. """
  21. _collection_name = "scores"
  22. def __init__(self) -> None:
  23. super().__init__()
  24. self._logger = logging.Logger("score-collection")
  25. def __repr__(self) -> str:
  26. if self.selection:
  27. return f"{self._collection_name} ({', '.join(self.selection)})"
  28. return f"{self._collection_name} ({', '.join(self.items)})"
  29. def load(self, score: BaseScore) -> None: # type: ignore
  30. """
  31. Add a pre-initialized score object to the collection
  32. Args:
  33. score: the item to add
  34. """
  35. if not isinstance(score, BaseScore):
  36. raise ValueError("Only objects of classes inherited from " "molbart.scores.BaseScore can be added")
  37. self._items[repr(score)] = score
  38. self._logger.info(f"Loaded score: {repr(score)}")
  39. def load_from_config(self, scores_config: ListConfig) -> None:
  40. """
  41. Load one or several scores from a configuration dictionary
  42. The keys are the name of score class. If a score is not
  43. defined in the ``molbart.utils.scores.scores`` module, the module
  44. name can be appended, e.g. ``mypackage.scoring.AwesomeScore``.
  45. The values of the configuration is passed directly to the score
  46. class along with the ``config`` parameter.
  47. Args:
  48. scores_config: Config of scores
  49. """
  50. for item in scores_config:
  51. if isinstance(item, str):
  52. cls = self.load_dynamic_class(item, score_module)
  53. obj = cls()
  54. config_str = ""
  55. else:
  56. item = [(key, item[key]) for key in item.keys()][0]
  57. name, kwargs = item
  58. x = yaml.load(OmegaConf.to_yaml(kwargs), Loader=yaml.SafeLoader)
  59. kwargs = self._unravel_list_dict(x)
  60. cls = self.load_dynamic_class(name, score_module)
  61. obj = cls(**kwargs)
  62. config_str = f" with configuration '{kwargs}'"
  63. self._items[repr(obj)] = obj
  64. print(f"Loaded score: '{repr(obj)}'{config_str}")
  65. def score(self, sampled_smiles: List[List[str]], target_smiles: Optional[List[str]] = None) -> Dict[str, Any]:
  66. """
  67. Apply all scorers in collection to the given sampled and target SMILES.
  68. Args:
  69. sampled_smiles: top-N SMILES sampled by a model, such as Chemformer.
  70. target_smiles: ground truth SMILES.
  71. Returns:
  72. A dictionary with all the scores.
  73. """
  74. scores = []
  75. for score in self._items.values():
  76. scores.append(score(sampled_smiles, target_smiles))
  77. return self._unravel_list_dict(scores)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...