Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

callback_collection.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. """ Module containing the collection class used to store and load callbacks for the Chemformer model.
  2. """
  3. from __future__ import annotations
  4. import logging
  5. from typing import TYPE_CHECKING
  6. import yaml
  7. from omegaconf import ListConfig, OmegaConf
  8. from pytorch_lightning.callbacks import Callback
  9. from molbart.utils.base_collection import BaseCollection
  10. from molbart.utils.callbacks.callbacks import __name__ as callback_module
  11. if TYPE_CHECKING:
  12. from typing import List
  13. class CallbackCollection(BaseCollection):
  14. """
  15. Store callback objects for the chemformer model.
  16. The callbacks can be obtained by name
  17. .. code-block::
  18. callbacks = CallbackCollection()
  19. callback = callbacks['LearningRateMonitor']
  20. """
  21. _collection_name = "callbacks"
  22. def __init__(self) -> None:
  23. super().__init__()
  24. self._logger = logging.Logger("callback-collection")
  25. def __repr__(self) -> str:
  26. if self.selection:
  27. return f"{self._collection_name} ({', '.join(self.selection)})"
  28. return f"{self._collection_name} ({', '.join(self.items)})"
  29. def load(self, callback: Callback) -> None: # type: ignore
  30. """
  31. Add a pre-initialized callback object to the collection
  32. Args:
  33. callback: the item to add
  34. """
  35. if not isinstance(callback, Callback):
  36. raise ValueError(
  37. "Only objects of classes inherited from " "pytorch_lightning.callbacks.Callbacks can be added"
  38. )
  39. self._items[repr(callback)] = callback
  40. self._logger.info(f"Loaded callback: {repr(callback)}")
  41. def load_from_config(self, callbacks_config: ListConfig) -> None:
  42. """
  43. Load one or several callbacks from a configuration dictionary
  44. The keys are the name of callback class. If a callback is not
  45. defined in the ``molbart.utils.callbacks.callbacks`` module, the module
  46. name can be appended, e.g. ``mypackage.callbacks.AwesomeCallback``.
  47. The values of the configuration is passed directly to the callback
  48. class along with the ``config`` parameter.
  49. Args:
  50. callbacks_config: Config of callbacks
  51. """
  52. for item in callbacks_config:
  53. if isinstance(item, str):
  54. cls = self.load_dynamic_class(item, callback_module)
  55. obj = cls()
  56. config_str = ""
  57. else:
  58. item = [(key, item[key]) for key in item.keys()][0]
  59. name, kwargs = item
  60. x = yaml.load(OmegaConf.to_yaml(kwargs), Loader=yaml.SafeLoader)
  61. kwargs = self._unravel_list_dict(x)
  62. cls = self.load_dynamic_class(name, callback_module)
  63. obj = cls(**kwargs)
  64. config_str = f" with configuration '{kwargs}'"
  65. self._items[repr(obj)] = obj
  66. print(f"Loaded callback: '{repr(obj)}'{config_str}")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...