Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

add_ktiv_male.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. from models import KtivMaleModel
  2. from tasks import KtivMaleTask
  3. from transformers import CanineTokenizer
  4. import pandas as pd
  5. import torch
  6. from tqdm.auto import tqdm
  7. import logging
  8. MODEL_FN = 'models/ktiv_male/latest'
  9. DATA_FN = 'data/processed/nikud.csv'
  10. SAVE_FN = 'data/processed/nikud_with_ktiv_male.csv'
  11. def combined_text_gen(df, max_len = 2000):
  12. # note: max_len must be less than 2048 (max input length of model)
  13. # and should be somewhat smaller to account for added characters in ktiv male
  14. out = ''
  15. for text in df.text:
  16. if len(out) + len(text) + 1 <= max_len:
  17. if out != '':
  18. out += '\n'
  19. out += text
  20. else:
  21. yield out
  22. out = text
  23. if out != '':
  24. yield out
  25. def main():
  26. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  27. print('Device detected:', device)
  28. print('Loading data...')
  29. df = pd.read_csv(DATA_FN)
  30. print('Loading tokenizer...')
  31. tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
  32. print('Loading model...')
  33. model = KtivMaleModel.from_pretrained(MODEL_FN)
  34. model.to(device)
  35. model.eval()
  36. print('Creating task...')
  37. task = KtivMaleTask(tokenizer, model, device=device)
  38. print('Combining texts into larger units...')
  39. df = pd.DataFrame({
  40. 'haser': [t for t in tqdm(combined_text_gen(df))]
  41. })
  42. print('Adding ktiv male to text...')
  43. tqdm.pandas(desc='Generating ktiv male')
  44. df['male'] = df.haser.progress_apply(lambda text: task.nikud2male(text, split=True, sample=False))
  45. df = df[df.male != ''].copy()
  46. print(f'Saving to: {SAVE_FN}')
  47. df.to_csv(SAVE_FN, index=False)
  48. if __name__ == '__main__':
  49. main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...