Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model.py 14 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
  1. """
  2. Full definition of a GPT Language Model, all of it in this single file.
  3. References:
  4. 1) the official GPT-2 TensorFlow implementation released by OpenAI:
  5. https://github.com/openai/gpt-2/blob/master/src/model.py
  6. 2) huggingface/transformers PyTorch implementation:
  7. https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
  8. """
  9. import math
  10. import torch
  11. import torch.nn as nn
  12. from torch.nn import functional as F
  13. from mingpt.utils import CfgNode as CN
  14. # -----------------------------------------------------------------------------
  class NewGELU(nn.Module):
      """
      Tanh-based GELU activation, the variant found in the Google BERT repo
      (identical to the one used by OpenAI GPT).

      Reference: Gaussian Error Linear Units (GELU) paper:
      https://arxiv.org/abs/1606.08415
      """

      def forward(self, x):
          """Apply 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) element-wise."""
          scale = math.sqrt(2.0 / math.pi)
          cubic_term = 0.044715 * torch.pow(x, 3.0)
          gate = 1.0 + torch.tanh(scale * (x + cubic_term))
          return 0.5 * x * gate
  class CausalSelfAttention(nn.Module):
      """
      Multi-head masked (causal) self-attention followed by an output projection.

      torch.nn.MultiheadAttention could be used instead; the computation is
      spelled out explicitly here to show every step.

      The config object must provide: n_embd, n_head (n_embd must be divisible
      by n_head), attn_pdrop, resid_pdrop and block_size.
      """

      def __init__(self, config):
          super().__init__()
          assert config.n_embd % config.n_head == 0
          width = config.n_embd
          ctx = config.block_size
          # one linear layer yields the query, key and value of every head at once
          self.c_attn = nn.Linear(width, 3 * width)
          # projection applied to the re-assembled head outputs
          self.c_proj = nn.Linear(width, width)
          # dropout on the attention weights and on the block output
          self.attn_dropout = nn.Dropout(config.attn_pdrop)
          self.resid_dropout = nn.Dropout(config.resid_pdrop)
          # lower-triangular mask: position i may only attend to positions <= i
          lower_tri = torch.tril(torch.ones(ctx, ctx))
          self.register_buffer("bias", lower_tri.view(1, 1, ctx, ctx))
          self.n_head = config.n_head
          self.n_embd = width

      def forward(self, x):
          """Attend over x of shape (B, T, C) and return a tensor of the same shape."""
          B, T, C = x.size()  # batch, sequence length, embedding width (n_embd)
          head_dim = C // self.n_head

          # split the fused projection, then move the head axis next to the batch axis
          query, key, value = (
              part.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
              for part in self.c_attn(x).split(self.n_embd, dim=2)
          )

          # scaled dot product: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
          scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
          future = self.bias[:, :, :T, :T] == 0
          scores = scores.masked_fill(future, float('-inf'))
          weights = self.attn_dropout(F.softmax(scores, dim=-1))

          # weighted sum of values: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
          mixed = weights @ value
          # put the heads back side by side: (B, T, C)
          merged = mixed.transpose(1, 2).contiguous().view(B, T, C)
          return self.resid_dropout(self.c_proj(merged))
  class Block(nn.Module):
      """
      A single Transformer block: pre-LayerNorm causal self-attention followed
      by a pre-LayerNorm MLP, each wrapped in a residual connection.

      The config object must provide n_embd and resid_pdrop, plus everything
      CausalSelfAttention reads from it.
      """

      def __init__(self, config):
          super().__init__()
          self.ln_1 = nn.LayerNorm(config.n_embd)
          self.attn = CausalSelfAttention(config)
          self.ln_2 = nn.LayerNorm(config.n_embd)
          # the MLP widens to 4 * n_embd and projects back down to n_embd
          self.mlp = nn.ModuleDict(dict(
              c_fc=nn.Linear(config.n_embd, 4 * config.n_embd),
              c_proj=nn.Linear(4 * config.n_embd, config.n_embd),
              act=NewGELU(),
              dropout=nn.Dropout(config.resid_pdrop),
          ))

      def mlpf(self, x):
          """
          MLP forward pass: c_fc -> activation -> c_proj -> dropout.

          This is a regular method rather than a lambda stored on the instance.
          A stored lambda cannot be pickled, and because copy.deepcopy copies
          functions by reference, a deep-copied block would keep running the
          original block's MLP weights.
          """
          mlp = self.mlp
          return mlp.dropout(mlp.c_proj(mlp.act(mlp.c_fc(x))))

      def forward(self, x):
          """Return x after the attention and MLP residual updates (same shape as x)."""
          x = x + self.attn(self.ln_1(x))
          x = x + self.mlpf(self.ln_2(x))
          return x
  class GPT(nn.Module):
      """
      GPT Language Model: token + position embeddings, a stack of Transformer
      blocks, a final LayerNorm and a linear head producing vocabulary logits.
      """

      @staticmethod
      def get_default_config():
          """
          Return a fresh CfgNode holding the default model hyperparameters.

          Either model_type or all of (n_layer, n_head, n_embd) must be given,
          but not both. The default model_type 'gpt' is not one of the presets
          known to __init__, so callers are expected to override it, or to set
          it to None and fill in the three sizes.
          """
          C = CN()
          # either model_type or (n_layer, n_head, n_embd) must be given in the config
          C.model_type = 'gpt'
          C.n_layer = None
          C.n_head = None
          C.n_embd = None
          # these options must be filled in externally
          C.vocab_size = None
          C.block_size = None
          # dropout hyperparameters
          C.embd_pdrop = 0.1
          C.resid_pdrop = 0.1
          C.attn_pdrop = 0.1
          return C

      def __init__(self, config):
          """
          Build the model described by config.

          config.vocab_size and config.block_size are required. When
          config.model_type is set, the matching preset is merged into config
          (config is mutated in place).
          """
          super().__init__()
          assert config.vocab_size is not None
          assert config.block_size is not None
          self.block_size = config.block_size

          type_given = config.model_type is not None
          params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
          assert type_given ^ params_given  # exactly one of these (XOR)
          if type_given:
              # translate from model_type to detailed configuration
              config.merge_from_dict({
                  # names follow the huggingface naming conventions
                  # GPT-1
                  'openai-gpt':   dict(n_layer=12, n_head=12, n_embd=768),   # 117M params
                  # GPT-2 configs
                  'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
                  'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
                  'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
                  'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
                  # Gophers
                  'gopher-44m':   dict(n_layer=8, n_head=16, n_embd=512),
                  # (there are a number more...)
                  # I made these tiny models up
                  'gpt-mini':     dict(n_layer=6, n_head=6, n_embd=192),
                  'gpt-micro':    dict(n_layer=4, n_head=4, n_embd=128),
                  'gpt-nano':     dict(n_layer=3, n_head=3, n_embd=48),
              }[config.model_type])

          self.transformer = nn.ModuleDict(dict(
              wte=nn.Embedding(config.vocab_size, config.n_embd),
              wpe=nn.Embedding(config.block_size, config.n_embd),
              drop=nn.Dropout(config.embd_pdrop),
              h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
              ln_f=nn.LayerNorm(config.n_embd),
          ))
          self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

          # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
          self.apply(self._init_weights)
          for pn, p in self.named_parameters():
              if pn.endswith('c_proj.weight'):
                  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

          # report number of parameters (note we don't count the decoder parameters in lm_head)
          n_params = sum(p.numel() for p in self.transformer.parameters())
          print("number of parameters: %.2fM" % (n_params/1e6,))

      def _init_weights(self, module):
          """
          Initialise one submodule in place (called through self.apply).

          Linear and Embedding weights are drawn from N(0, 0.02), Linear biases
          are zeroed, and LayerNorm is reset to weight 1 / bias 0.
          """
          if isinstance(module, nn.Linear):
              torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
              if module.bias is not None:
                  torch.nn.init.zeros_(module.bias)
          elif isinstance(module, nn.Embedding):
              torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
          elif isinstance(module, nn.LayerNorm):
              torch.nn.init.zeros_(module.bias)
              torch.nn.init.ones_(module.weight)

      @classmethod
      def from_pretrained(cls, model_type):
          """
          Initialize a pretrained GPT model by copying over the weights
          from a huggingface/transformers checkpoint.

          model_type must be one of 'gpt2', 'gpt2-medium', 'gpt2-large' or
          'gpt2-xl'. Requires the transformers package, which is imported
          lazily here.
          """
          assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
          from transformers import GPT2LMHeadModel

          # create a from-scratch initialized minGPT model
          config = cls.get_default_config()
          config.model_type = model_type
          config.vocab_size = 50257  # openai's model vocabulary
          config.block_size = 1024   # openai's model block_size
          model = cls(config)  # cls, not GPT, so subclasses get an instance of themselves
          sd = model.state_dict()

          # init a huggingface/transformers model
          model_hf = GPT2LMHeadModel.from_pretrained(model_type)
          sd_hf = model_hf.state_dict()

          # The causal-mask buffers are not learned weights: our own 'attn.bias'
          # is rebuilt in CausalSelfAttention.__init__, and whether the
          # checkpoint serializes its mask buffers depends on the installed
          # transformers version. Skip them on both sides so the key counts are
          # comparable either way. The leading dot matters: without it
          # 'attn.c_attn.bias' would also match the 'attn.bias' suffix.
          mask_suffixes = ('.attn.masked_bias', '.attn.bias')
          keys = [k for k in sd if not k.endswith(mask_suffixes)]
          keys_hf = [k for k in sd_hf if not k.endswith(mask_suffixes)]

          # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
          # this means that we have to transpose these weights when we import them
          transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')

          # copy while ensuring all of the parameters are aligned and match in names and shapes
          assert len(keys_hf) == len(keys), "mismatched keys: %d != %d" % (len(keys_hf), len(keys))
          with torch.no_grad():
              for k in keys_hf:
                  if k.endswith(transposed):
                      # special treatment for the Conv1D weights we need to transpose
                      assert sd_hf[k].shape[::-1] == sd[k].shape
                      sd[k].copy_(sd_hf[k].t())
                  else:
                      # vanilla copy over the other parameters
                      assert sd_hf[k].shape == sd[k].shape
                      sd[k].copy_(sd_hf[k])

          return model

      def configure_optimizers(self, train_config):
          """
          This long function is unfortunately doing something very simple and is being very defensive:
          We are separating out all parameters of the model into two buckets: those that will experience
          weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
          We are then returning the PyTorch optimizer object.

          train_config must provide weight_decay, learning_rate and betas.
          Returns a torch.optim.AdamW instance with two parameter groups.
          """
          # separate out all parameters to those that will and won't experience regularizing weight decay
          decay = set()
          no_decay = set()
          whitelist_weight_modules = (torch.nn.Linear, )
          blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
          for mn, m in self.named_modules():
              for pn, p in m.named_parameters():
                  fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                  # random note: because named_modules and named_parameters are recursive
                  # we will see the same tensors p many many times. but doing it this way
                  # allows us to know which parent module any tensor p belongs to...
                  if pn.endswith('bias'):
                      # all biases will not be decayed
                      no_decay.add(fpn)
                  elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                      # weights of whitelist modules will be weight decayed
                      decay.add(fpn)
                  elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                      # weights of blacklist modules will NOT be weight decayed
                      no_decay.add(fpn)

          # validate that we considered every parameter
          param_dict = {pn: p for pn, p in self.named_parameters()}
          inter_params = decay & no_decay
          union_params = decay | no_decay
          assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
          assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                      % (str(param_dict.keys() - union_params), )

          # create the pytorch optimizer object; names are sorted so the group order is deterministic
          optim_groups = [
              {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
              {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
          ]
          optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
          return optimizer

      def forward(self, idx, targets=None):
          """
          Run the model on token indices idx of shape (b, t), with t <= block_size.

          Returns (logits, loss). logits has shape (b, t, vocab_size). loss is
          None unless targets is given, in which case it is the cross-entropy
          between the flattened logits and targets, ignoring positions whose
          target is -1.
          """
          device = idx.device
          b, t = idx.size()
          assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
          pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

          # forward the GPT model itself
          tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
          pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
          x = self.transformer.drop(tok_emb + pos_emb)
          for block in self.transformer.h:
              x = block(x)
          x = self.transformer.ln_f(x)
          logits = self.lm_head(x)

          # if we are given some desired targets also calculate the loss
          loss = None
          if targets is not None:
              loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

          return logits, loss

      @torch.no_grad()
      def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
          """
          Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
          the sequence max_new_tokens times, feeding the predictions back into the model each time.
          Most likely you'll want to make sure to be in model.eval() mode of operation for this.

          temperature divides the final-step logits. With do_sample=True the
          next token is drawn from the softmax distribution, otherwise the
          most likely token is taken. top_k, when given, restricts the choice
          to the k highest-scoring tokens; values larger than the vocabulary
          are clamped to the vocabulary size. Returns idx extended to shape
          (b, t + max_new_tokens).
          """
          for _ in range(max_new_tokens):
              # if the sequence context is growing too long we must crop it at block_size
              idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
              # forward the model to get the logits for the index in the sequence
              logits, _ = self(idx_cond)
              # pluck the logits at the final step and scale by desired temperature
              logits = logits[:, -1, :] / temperature
              # optionally crop the logits to only the top k options
              if top_k is not None:
                  # torch.topk raises if k exceeds the size of the last dimension
                  k = min(top_k, logits.size(-1))
                  v, _ = torch.topk(logits, k)
                  logits[logits < v[:, [-1]]] = -float('Inf')
              # apply softmax to convert logits to (normalized) probabilities
              probs = F.softmax(logits, dim=-1)
              # either sample from the distribution or take the most likely element
              if do_sample:
                  idx_next = torch.multinomial(probs, num_samples=1)
              else:
                  _, idx_next = torch.topk(probs, k=1, dim=-1)
              # append sampled index to the running sequence and continue
              idx = torch.cat((idx, idx_next), dim=1)

          return idx
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...