Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data_collection.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  1. """ Module containing the collection class used to create and store Chemformer datamodules.
  2. """
  3. from __future__ import annotations
  4. import logging
  5. from omegaconf import ListConfig, OmegaConf
  6. from typing import Any, Dict
  7. import yaml
  8. from molbart.data import _AbsDataModule
  9. from molbart.data.datamodules import __name__ as data_module
  10. from molbart.utils import data_utils
  11. from molbart.utils.tokenizers import ChemformerTokenizer
  12. from molbart.utils.base_collection import BaseCollection
  13. class DataCollection(BaseCollection):
  14. """
  15. Store datamodule object for the chemformer model.
  16. The datamodule can be obtained by name
  17. .. code-block::
  18. datamodule = DataCollection()
  19. """
  20. _collection_name = "data"
  21. def __init__(self, config: OmegaConf, tokenizer: ChemformerTokenizer) -> None:
  22. super().__init__()
  23. self._logger = logging.Logger("data-collection")
  24. self._config = config
  25. self._tokenizer = tokenizer
  26. def __repr__(self) -> str:
  27. if self.selection:
  28. return f"{self._collection_name} ({', '.join(self.selection)})"
  29. return f"{self._collection_name} ({', '.join(self.items)})"
  30. def load(self, datamodule: _AbsDataModule) -> None: # type: ignore
  31. """
  32. Load a datamodule object to the collection
  33. Args:
  34. datamodule: the item to add
  35. """
  36. if not isinstance(datamodule, _AbsDataModule):
  37. raise ValueError("Only objects of classes inherited from " "molbart.data._AbsDataModule can be added")
  38. self._items[repr(datamodule)] = datamodule
  39. self._logger.info(f"Loaded datamodule: {repr(datamodule)}")
  40. def load_from_config(self, datamodule_config: ListConfig) -> None:
  41. """
  42. Load a datamodule from a configuration dictionary
  43. The keys are the name of score class. If a score is not
  44. defined in the ``molbart.data.datamodules`` module, the module
  45. name can be appended, e.g. ``mypackage.data.AwesomeDataModule``.
  46. The values of the configuration is passed directly to the datamodule
  47. class along with the ``config`` parameter.
  48. Args:
  49. datamodule_config: Config of the datamodule
  50. """
  51. for item in datamodule_config:
  52. if isinstance(item, str):
  53. cls = self.load_dynamic_class(item, data_module)
  54. kwargs = self._set_datamodule_kwargs()
  55. else:
  56. item = [(key, item[key]) for key in item.keys()][0]
  57. name, kwargs = item
  58. cls = self.load_dynamic_class(name, data_module)
  59. x = yaml.load(OmegaConf.to_yaml(kwargs), Loader=yaml.SafeLoader)
  60. kwargs = self._unravel_list_dict(x)
  61. kwargs.update(self._set_datamodule_kwargs())
  62. obj = cls(**kwargs)
  63. config_str = f" with configuration '{kwargs}'"
  64. self._items[repr(obj)] = obj
  65. print(f"Loaded datamodule: '{repr(obj)}'{config_str}")
  66. def get_datamodule(self, datamodule_config: ListConfig) -> _AbsDataModule:
  67. """
  68. Return the datamodule which has been loaded from the config file
  69. """
  70. self.load_from_config(datamodule_config)
  71. return self.objects()[0]
  72. def _set_datamodule_kwargs(self) -> Dict[str, Any]:
  73. """
  74. Returns a dictionary with kwargs which are general to the _AbsDataModule.
  75. These are specified as single parameters in the config file
  76. """
  77. reverse = self._config.task == "backward_prediction"
  78. kwargs = {
  79. "reverse": reverse,
  80. "max_seq_len": self._config.get("max_seq_len", data_utils.DEFAULT_MAX_SEQ_LEN),
  81. "tokenizer": self._tokenizer,
  82. "augment_prob": self._config.get("augmentation_probability"),
  83. "augment_prob": self._config.get("augmentation_probability"),
  84. "unified_model": self._config.model_type == "unified",
  85. "dataset_path": self._config.data_path,
  86. "batch_size": self._config.batch_size,
  87. "train_token_batch_size": self._config.get("train_tokens"),
  88. "num_buckets": self._config.get("n_buckets"),
  89. "unified_model": self._config.model_type == "unified",
  90. "i_chunk": self._config.get("i_chunk", 0),
  91. "n_chunks": self._config.get("n_chunks", 1),
  92. }
  93. return kwargs
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...