test_huggingface_import.py

  1. """
  2. Ensure that we can load huggingface/transformer GPTs into minGPT
  3. """
  4. import unittest
  5. import torch
  6. from transformers import GPT2Tokenizer, GPT2LMHeadModel
  7. from mingpt.model import GPT
  8. from mingpt.bpe import BPETokenizer
  9. # -----------------------------------------------------------------------------
  10. class TestHuggingFaceImport(unittest.TestCase):
  11. def test_gpt2(self):
  12. model_type = 'gpt2'
  13. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  14. prompt = "Hello!!!!!!!!!? 🤗, my dog is a little"
  15. # create a minGPT and a huggingface/transformers model
  16. model = GPT.from_pretrained(model_type)
  17. model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too
  18. # ship both to device
  19. model.to(device)
  20. model_hf.to(device)
  21. # set both to eval mode
  22. model.eval()
  23. model_hf.eval()
  24. # tokenize input prompt
  25. # ... with mingpt
  26. tokenizer = BPETokenizer()
  27. x1 = tokenizer(prompt).to(device)
  28. # ... with huggingface/transformers
  29. tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type)
  30. model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning
  31. encoded_input = tokenizer_hf(prompt, return_tensors='pt').to(device)
  32. x2 = encoded_input['input_ids']
  33. # ensure the logits match exactly
  34. logits1, loss = model(x1)
  35. logits2 = model_hf(x2).logits
  36. self.assertTrue(torch.allclose(logits1, logits2))
  37. # now draw the argmax samples from each
  38. y1 = model.generate(x1, max_new_tokens=20, do_sample=False)[0]
  39. y2 = model_hf.generate(x2, max_new_tokens=20, do_sample=False)[0]
  40. self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices
  41. # convert indices to strings
  42. out1 = tokenizer.decode(y1.cpu().squeeze())
  43. out2 = tokenizer_hf.decode(y2.cpu().squeeze())
  44. self.assertTrue(out1 == out2) # compare the exact output strings too
  45. if __name__ == '__main__':
  46. unittest.main()
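
For reference, the same import path works outside the test harness. The sketch below is a minimal example that reuses only the minGPT calls exercised in the test above (GPT.from_pretrained, BPETokenizer, model.generate); the prompt string is arbitrary.

import torch
from mingpt.model import GPT
from mingpt.bpe import BPETokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load the huggingface/transformers GPT-2 weights into a minGPT model
model = GPT.from_pretrained('gpt2')
model.to(device)
model.eval()

# encode a prompt, greedy-decode 20 new tokens, and print the completion
tokenizer = BPETokenizer()
x = tokenizer("Hello, my dog is a little").to(device)
y = model.generate(x, max_new_tokens=20, do_sample=False)[0]
print(tokenizer.decode(y.cpu().squeeze()))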