prepare_dataset.py

You have to be logged in to leave a comment. Sign In
import torch
from nemo.utils import logging
from rdkit import Chem
from pysmilesutils.augment import SMILESAugmenter
from typing import List
import numpy as np
import math
from nemo_chem.tokenizer import MolEncTokenizer
import time

__all__ = ['PrepareDataset']


class PrepareDataset:
    """Collate batches of token ids into padded encoder/decoder tensors and target SMILES strings."""

    def __init__(self, tokenizer: MolEncTokenizer, seq_length: int,
                 pad_size_divisible_by_8: bool, **kwargs):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.pad_size_divisible_by_8 = pad_size_divisible_by_8

    def _check_seq_len(self, tokens):
        """Warn the user and shorten sequences if the tokens are too long, otherwise return the original.

        Args:
            tokens (List[List[str]]): List of token sequences

        Returns:
            List[List[str]]: List of token sequences (shortened, if necessary)
        """
        seq_len = max([len(ts) for ts in tokens])
        if seq_len > self.seq_length:
            logging.warning(f'Sequence length {seq_len} exceeds maximum of {self.seq_length}; truncating.')
            tokens_short = [ts[:self.seq_length] for ts in tokens]
            return tokens_short
        return tokens

    def _canonicalize_smile(self, smile):
        """Canonicalize a single SMILES string with RDKit."""
        mol = Chem.MolFromSmiles(smile)
        canon_smile = Chem.MolToSmiles(mol, canonical=True)
        return canon_smile

    def convert_tokens_to_smiles(self, tokens, canonical: bool = True):
        """Take in a token array and convert it back to an (optionally canonicalized) SMILES string."""
        smiles = self.tokenizer.detokenize(tokens)
        if canonical:
            canon_smiles = [self._canonicalize_smile(smile) for smile in smiles]
            return canon_smiles
        return smiles

    def _pad_seqs(self, seqs, pad_token):
        """Pad all sequences to a common length and build the corresponding padding masks."""
        pad_length = max([len(seq) for seq in seqs])
        if self.pad_size_divisible_by_8:
            pad_length = int(math.ceil(pad_length / 8) * 8)
        padded = [np.append(seq, np.array([pad_token] * (pad_length - len(seq)))) for seq in seqs]
        masks = [([1] * len(seq)) + ([0] * (pad_length - len(seq))) for seq in seqs]  # 1/True = Active, 0/False = Inactive
        return padded, masks

    def _prepare_tokens(self, token_ids, canonicalize: bool = False):
        """Prepare tokens for the encoder or decoder from a batch of token id sequences.

        Args:
            token_ids (List): Batch of token id sequences
            canonicalize (bool, optional): Canonicalize the detokenized SMILES. Defaults to False.

        Returns:
            dict: token ids and target SMILES strings
        """
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        # detokenize (and optionally canonicalize) the target SMILES strings
        canon_target = self.convert_tokens_to_smiles(tokens, canonical=canonicalize)
        # truncate sequences that exceed the maximum length
        token_ids = self._check_seq_len(token_ids)
        token_output = {
            "token_ids": token_ids,
            "target_smiles": canon_target
        }
        return token_output
    def collate_fn(self, batch: List[np.array], label_pad: int = -1):
        """Collate a batch of token id sequences into model-ready tensors."""
        encoder_tokens = self._prepare_tokens(batch, canonicalize=False)
        enc_token_ids, enc_pad_mask = self._pad_seqs(encoder_tokens['token_ids'], self.tokenizer.pad_id)
        enc_token_ids = torch.tensor(enc_token_ids, dtype=torch.int64)  # converting a list into a torch tensor is very slow, convert to np.array first
        enc_pad_mask = torch.tensor(enc_pad_mask, dtype=torch.int64)

        decoder_tokens = self._prepare_tokens(batch, canonicalize=False)
        # assign label_ids before adding bos_id to the decoder input; list() guards against np.array
        # samples, for which `+` would broadcast element-wise instead of concatenating
        label_ids = [list(sample) + [self.tokenizer.eos_id] for sample in decoder_tokens['token_ids']]
        dec_token_ids = [[self.tokenizer.bos_id] + list(sample) for sample in decoder_tokens['token_ids']]
        dec_token_ids, dec_pad_mask = self._pad_seqs(dec_token_ids, self.tokenizer.pad_id)
        dec_token_ids = torch.tensor(dec_token_ids, dtype=torch.int64)
        dec_pad_mask = torch.tensor(dec_pad_mask, dtype=torch.int64)

        label_token_ids, loss_mask = self._pad_seqs(label_ids, self.tokenizer.pad_id)
        label_token_ids = torch.tensor(label_token_ids, dtype=torch.int64)
        loss_mask = torch.tensor(loss_mask, dtype=torch.int64)
        label_token_ids[~loss_mask.to(torch.bool)] = label_pad  # padding positions are excluded from the loss

        collate_output = {
            "text_enc": enc_token_ids,
            "enc_mask": enc_pad_mask,
            "text_dec": dec_token_ids,
            "dec_mask": dec_pad_mask,
            "labels": label_token_ids,
            "loss_mask": loss_mask,
            "target_smiles": encoder_tokens['target_smiles']}  # SMILES strings
        return collate_output
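
As a usage illustration (not part of the source file), the sketch below wires collate_fn into a torch.utils.data.DataLoader. The _ToyTokenizer stand-in, the sample token ids, and the seq_length value are assumptions made here purely so the collation path can be exercised end to end; a real pipeline would pass a MolEncTokenizer instance instead.

import numpy as np
from torch.utils.data import DataLoader

class _ToyTokenizer:
    """Hypothetical stand-in exposing only the attributes and methods collate_fn touches.
    A real run would use nemo_chem.tokenizer.MolEncTokenizer here."""
    pad_id, bos_id, eos_id = 0, 1, 2

    def convert_ids_to_tokens(self, ids_batch):
        # map each id to a placeholder token string
        return [[str(i) for i in ids] for ids in ids_batch]

    def detokenize(self, tokens_batch):
        # join tokens back into a dummy SMILES-like string
        return [''.join(tokens) for tokens in tokens_batch]

prepare = PrepareDataset(tokenizer=_ToyTokenizer(), seq_length=512,
                         pad_size_divisible_by_8=True)

# a plain list of np.array token-id sequences works as a map-style dataset
samples = [np.array([12, 7, 33]), np.array([5, 9]), np.array([4, 4, 4, 4])]
loader = DataLoader(samples, batch_size=2, collate_fn=prepare.collate_fn)

for batch in loader:
    # encoder/decoder ids and labels come out as padded int64 tensors
    print(batch['text_enc'].shape, batch['text_dec'].shape, batch['labels'].shape)

With pad_size_divisible_by_8=True, the second dimension of each output tensor is padded up to the next multiple of 8, and loss_mask marks which label positions are real tokens rather than padding.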