Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

seq2seq_data.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. """ Module containing classes to load seq2seq data"""
  2. import pandas as pd
  3. from rdkit import Chem
  4. from typing import Any, Dict, List, Tuple
  5. from molbart.data.base import ReactionListDataModule
  class Uspto50DataModule(ReactionListDataModule):
      """
      Data module serving reactions from the USPTO-50 dataset.

      Reactant molecules, product molecules and one reaction-type token per
      reaction are loaded from a pickled DataFrame. The type token is only
      prepended to a sequence when ``include_type_token`` is passed as a
      truthy keyword argument to the constructor.
      """

      def __init__(self, **kwargs):
          """
          Forward every keyword argument to the parent constructor and record
          the optional ``include_type_token`` flag (``False`` when absent).
          """
          super().__init__(**kwargs)
          self._include_type_token = kwargs.get("include_type_token", False)

      def _get_sequences(self, batch: List[Dict[str, Any]], train: bool) -> Tuple[List[str], List[str]]:
          """
          Build the reactant and product strings for one batch.

          :param batch: items holding "reactants", "products" and
              "type_tokens" entries
          :param train: when truthy, both lists are passed through
              ``self._batch_augmenter``
          :return: the reactant strings and the product strings, in that order
          """

          def to_smiles(key: str) -> List[str]:
              # One SMILES string per batch item for the molecule stored under `key`.
              return [Chem.MolToSmiles(entry[key]) for entry in batch]

          reactant_smiles = to_smiles("reactants")
          product_smiles = to_smiles("products")

          if train:
              reactant_smiles = self._batch_augmenter(reactant_smiles)
              product_smiles = self._batch_augmenter(product_smiles)

          if not self._include_type_token:
              return reactant_smiles, product_smiles

          # The type token goes on exactly one side, selected by `self.reverse`
          # (not set in this class; presumably provided by the base class).
          if self.reverse:
              product_smiles = [entry["type_tokens"] + smi for entry, smi in zip(batch, product_smiles)]
          else:
              reactant_smiles = [entry["type_tokens"] + smi for entry, smi in zip(batch, reactant_smiles)]
          return reactant_smiles, product_smiles

      def _load_all_data(self) -> None:
          """
          Read the pickled DataFrame at ``self.dataset_path`` and store its
          molecule and reaction-type columns as lists in ``self._all_data``.
          """
          # NOTE: unpickling can execute arbitrary code; only load trusted files.
          frame = pd.read_pickle(self.dataset_path).reset_index()
          # Key used in `self._all_data` -> column name in the DataFrame.
          column_for_key = {
              "reactants": "reactants_mol",
              "products": "products_mol",
              "type_tokens": "reaction_type",
          }
          self._all_data = {key: frame[column].tolist() for key, column in column_for_key.items()}
          self._set_split_indices_from_dataframe(frame)
  class UsptoMixedDataModule(ReactionListDataModule):
      """
      Data module serving reactions from the USPTO-Mixed dataset.

      Reactant and product molecules are loaded from a pickled DataFrame.
      """

      def _get_sequences(self, batch: List[Dict[str, Any]], train: bool) -> Tuple[List[str], List[str]]:
          """
          Build the reactant and product SMILES strings for one batch.

          :param batch: items holding "reactants" and "products" entries
          :param train: when truthy, both lists are passed through
              ``self._batch_augmenter``
          :return: the reactant strings and the product strings, in that order
          """

          def to_smiles(key: str) -> List[str]:
              # One SMILES string per batch item for the molecule stored under `key`.
              return [Chem.MolToSmiles(entry[key]) for entry in batch]

          reactant_smiles = to_smiles("reactants")
          product_smiles = to_smiles("products")

          if not train:
              return reactant_smiles, product_smiles

          reactant_smiles = self._batch_augmenter(reactant_smiles)
          product_smiles = self._batch_augmenter(product_smiles)
          return reactant_smiles, product_smiles

      def _load_all_data(self) -> None:
          """
          Read the pickled DataFrame at ``self.dataset_path`` and store its
          reactant and product columns as lists in ``self._all_data``.
          """
          # NOTE: unpickling can execute arbitrary code; only load trusted files.
          frame = pd.read_pickle(self.dataset_path).reset_index()
          # Key used in `self._all_data` -> column name in the DataFrame.
          column_for_key = {
              "reactants": "reactants_mol",
              "products": "products_mol",
          }
          self._all_data = {key: frame[column].tolist() for key, column in column_for_key.items()}
          self._set_split_indices_from_dataframe(frame)
  class UsptoSepDataModule(ReactionListDataModule):
      """
      Data module serving reactions from the USPTO-Separated dataset.

      Reactant, reagent and product molecules are loaded from a pickled
      DataFrame. Reactants and reagents are joined with ">" into a single
      string per reaction.
      """

      def _get_sequences(self, batch: List[Dict[str, Any]], train: bool) -> Tuple[List[str], List[str]]:
          """
          Build the "reactants>reagents" and product strings for one batch.

          :param batch: items holding "reactants", "reagents" and "products"
              entries
          :param train: when truthy, all three lists are passed through
              ``self._batch_augmenter`` before reactants and reagents are joined
          :return: the joined reactant/reagent strings and the product strings,
              in that order
          """

          def to_smiles(key: str) -> List[str]:
              # One SMILES string per batch item for the molecule stored under `key`.
              return [Chem.MolToSmiles(entry[key]) for entry in batch]

          reactant_smiles = to_smiles("reactants")
          reagent_smiles = to_smiles("reagents")
          product_smiles = to_smiles("products")

          if train:
              reactant_smiles = self._batch_augmenter(reactant_smiles)
              reagent_smiles = self._batch_augmenter(reagent_smiles)
              product_smiles = self._batch_augmenter(product_smiles)

          # Reactants and reagents are paired positionally and separated by ">".
          joined = [lhs + ">" + rhs for lhs, rhs in zip(reactant_smiles, reagent_smiles)]
          return joined, product_smiles

      def _load_all_data(self) -> None:
          """
          Read the pickled DataFrame at ``self.dataset_path`` and store its
          reactant, product and reagent columns as lists in ``self._all_data``.
          """
          # NOTE: unpickling can execute arbitrary code; only load trusted files.
          frame = pd.read_pickle(self.dataset_path).reset_index()
          # Key used in `self._all_data` -> column name in the DataFrame.
          column_for_key = {
              "reactants": "reactants_mol",
              "products": "products_mol",
              "reagents": "reagents_mol",
          }
          self._all_data = {key: frame[column].tolist() for key, column in column_for_key.items()}
          self._set_split_indices_from_dataframe(frame)
  class MolecularOptimizationDataModule(ReactionListDataModule):
      """
      Data module for a molecular optimization dataset.

      The input and output molecules, as well as the property tokens, are
      loaded from a pickled DataFrame. The property tokens are prepended to
      every input string.
      """

      def _get_sequences(self, batch: List[Dict[str, Any]], train: bool) -> Tuple[List[str], List[str]]:
          """
          Build the token-prefixed input strings and the output strings for one batch.

          :param batch: items holding "input_mols", "output_mols" and
              "prop_tokens" entries
          :param train: when truthy, both SMILES lists are passed through
              ``self._batch_augmenter`` before the property tokens are prepended
          :return: the input strings (property tokens + SMILES) and the output
              SMILES strings, in that order
          """

          def to_smiles(key: str) -> List[str]:
              # One SMILES string per batch item for the molecule stored under `key`.
              return [Chem.MolToSmiles(entry[key]) for entry in batch]

          source_smiles = to_smiles("input_mols")
          target_smiles = to_smiles("output_mols")

          if train:
              source_smiles = self._batch_augmenter(source_smiles)
              target_smiles = self._batch_augmenter(target_smiles)

          # Property tokens are always prepended, regardless of `train`.
          prefixed = [entry["prop_tokens"] + smi for entry, smi in zip(batch, source_smiles)]
          return prefixed, target_smiles

      def _load_all_data(self) -> None:
          """
          Read the pickled DataFrame at ``self.dataset_path`` and store its
          property-token, input and output columns as lists in ``self._all_data``.
          """
          # NOTE: unpickling can execute arbitrary code; only load trusted files.
          frame = pd.read_pickle(self.dataset_path).reset_index()
          # Key used in `self._all_data` -> column name in the DataFrame.
          column_for_key = {
              "prop_tokens": "property_tokens",
              "input_mols": "input_mols",
              "output_mols": "output_mols",
          }
          self._all_data = {key: frame[column].tolist() for key, column in column_for_key.items()}
          self._set_split_indices_from_dataframe(frame)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...