Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

base_collection.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  1. """ Module containing classes used to score the reaction routes.
  2. """
  3. from __future__ import annotations
  4. import importlib
  5. from typing import TYPE_CHECKING
  6. from omegaconf import DictConfig
  7. if TYPE_CHECKING:
  8. from typing import Any, Dict, List, Optional
  9. class BaseCollection:
  10. """
  11. Base class for collection classes (callback collection, score collection, etc.).
  12. """
  13. _collection_name = "base"
  14. def __init__(self) -> None:
  15. self._items: Dict[str, Any] = {}
  16. def __repr__(self) -> str:
  17. return f"{self._collection_name} ({', '.join(self.names)})"
  18. def load_from_config(self, config: DictConfig) -> None:
  19. """
  20. Load one or several items (e.g. score, callback, etc.) from a configuration dictionary
  21. The keys are the name of item class. If an item is not
  22. defined in the ``molbart.utils.items.items`` module, the module
  23. name can be appended, e.g. ``mypackage.item.AwesomeItem``.
  24. """
  25. raise NotImplementedError("BaseCollection.load_from_config() not implemented.")
  26. def names(self) -> List[str]:
  27. """Return a list of the names of all the loaded items"""
  28. return list(self._items.keys())
  29. def objects(self) -> List[Any]:
  30. """Return a list of all the loaded items"""
  31. return list(self._items.values())
  32. @staticmethod
  33. def load_dynamic_class(
  34. name_spec: str,
  35. default_module: Optional[str] = None,
  36. exception_cls: Any = ValueError,
  37. ) -> Any:
  38. """
  39. Load an object from a dynamic specification.
  40. The specification can be either:
  41. ClassName, in-case the module name is taken from the `default_module` argument
  42. or
  43. package_name.module_name.ClassName, in-case the module is taken as `package_name.module_name`
  44. Args:
  45. name_spec: the class specification
  46. default_module: the default module
  47. exception_cls: the exception class to raise on exception
  48. Returns
  49. the loaded class
  50. """
  51. if "." not in name_spec:
  52. name = name_spec
  53. if not default_module:
  54. raise exception_cls("Must provide default_module argument if not given in name_spec")
  55. module_name = default_module
  56. else:
  57. module_name, name = name_spec.rsplit(".", maxsplit=1)
  58. try:
  59. loaded_module = importlib.import_module(module_name)
  60. except ImportError:
  61. raise exception_cls(f"Unable to load module: {module_name}")
  62. if not hasattr(loaded_module, name):
  63. raise exception_cls(f"Module ({module_name}) does not have a class called {name}")
  64. return getattr(loaded_module, name)
  65. @staticmethod
  66. def _unravel_list_dict(input_data: List[Dict]):
  67. output = {}
  68. for data in input_data:
  69. for key, value in data.items():
  70. output[key] = value
  71. return output
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...