lstm.py 19 KB

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import LanguagePairDataset

from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture


@register_model('lstm')
class LSTMModel(FairseqModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-embed-path', default=None, type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-hidden-size', type=int, metavar='N',
                            help='encoder hidden size')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='number of encoder layers')
        parser.add_argument('--encoder-bidirectional', action='store_true',
                            help='make all layers of encoder bidirectional')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-embed-path', default=None, type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
                            help='decoder hidden size')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='number of decoder layers')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
                            help='decoder attention')

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument('--encoder-dropout-in', type=float, metavar='D',
                            help='dropout probability for encoder input embedding')
        parser.add_argument('--encoder-dropout-out', type=float, metavar='D',
                            help='dropout probability for encoder output')
        parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
                            help='dropout probability for decoder input embedding')
        parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                            help='dropout probability for decoder output')

    @classmethod
    def build_model(cls, args, src_dict, dst_dict):
        """Build a new model instance."""
        if not hasattr(args, 'encoder_embed_path'):
            args.encoder_embed_path = None
        if not hasattr(args, 'decoder_embed_path'):
            args.decoder_embed_path = None
        if not hasattr(args, 'encoder_hidden_size'):
            args.encoder_hidden_size = args.encoder_embed_dim
        if not hasattr(args, 'decoder_hidden_size'):
            args.decoder_hidden_size = args.decoder_embed_dim
        if not hasattr(args, 'encoder_bidirectional'):
            args.encoder_bidirectional = False

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        pretrained_encoder_embed = None
        if args.encoder_embed_path:
            pretrained_encoder_embed = load_pretrained_embedding_from_file(
                args.encoder_embed_path, src_dict, args.encoder_embed_dim)
        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path, dst_dict, args.decoder_embed_dim)

        encoder = LSTMEncoder(
            dictionary=src_dict,
            embed_dim=args.encoder_embed_dim,
            hidden_size=args.encoder_hidden_size,
            num_layers=args.encoder_layers,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            bidirectional=args.encoder_bidirectional,
            pretrained_embed=pretrained_encoder_embed,
        )
        try:
            attention = bool(eval(args.decoder_attention))
        except TypeError:
            attention = bool(args.decoder_attention)
        decoder = LSTMDecoder(
            dictionary=dst_dict,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            attention=attention,
            encoder_embed_dim=args.encoder_embed_dim,
            encoder_output_units=encoder.output_units,
            pretrained_embed=pretrained_decoder_embed,
        )
        return cls(encoder, decoder)


class LSTMEncoder(FairseqEncoder):
    """LSTM encoder."""
    def __init__(
        self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
        dropout_in=0.1, dropout_out=0.1, bidirectional=False,
        left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
        pretrained_embed=None,
        padding_value=0.,
    ):
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.lstm = LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=self.dropout_out,
            bidirectional=bidirectional,
        )
        self.left_pad_source = left_pad_source
        self.padding_value = padding_value

        self.output_units = hidden_size
        if bidirectional:
            self.output_units *= 2

    def forward(self, src_tokens, src_lengths):
        if self.left_pad_source:
            # convert left-padding to right-padding
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                self.padding_idx,
                left_to_right=True,
            )

        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # pack embedded source tokens into a PackedSequence
        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())

        # apply LSTM
        if self.bidirectional:
            state_size = 2 * self.num_layers, bsz, self.hidden_size
        else:
            state_size = self.num_layers, bsz, self.hidden_size
        h0 = Variable(x.data.new(*state_size).zero_())
        c0 = Variable(x.data.new(*state_size).zero_())
        packed_outs, (final_hiddens, final_cells) = self.lstm(
            packed_x,
            (h0, c0),
        )

        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        assert list(x.size()) == [seqlen, bsz, self.output_units]
        if self.bidirectional:
            # concatenate the forward and backward final states of each layer
            # so that the dict returned below carries tensors of size
            # `output_units`, matching what LSTMDecoder.forward expects
            bi_final_hiddens, bi_final_cells = [], []
            for i in range(self.num_layers):
                bi_final_hiddens.append(
                    torch.cat(
                        (final_hiddens[2 * i], final_hiddens[2 * i + 1]),
                        dim=0).view(bsz, self.output_units))
                bi_final_cells.append(
                    torch.cat(
                        (final_cells[2 * i], final_cells[2 * i + 1]),
                        dim=0).view(bsz, self.output_units))
            final_hiddens, final_cells = bi_final_hiddens, bi_final_cells
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        return {
            'encoder_out': (x, final_hiddens, final_cells),
            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number


class AttentionLayer(nn.Module):
    def __init__(self, input_embed_dim, output_embed_dim):
        super().__init__()

        self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
        self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)

    def forward(self, input, source_hids, encoder_padding_mask):
        # input: bsz x input_embed_dim
        # source_hids: srclen x bsz x output_embed_dim
        # x: bsz x output_embed_dim
        x = self.input_proj(input)

        # compute attention
        attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)

        # don't attend over padding
        if encoder_padding_mask is not None:
            attn_scores = attn_scores.float().masked_fill_(
                encoder_padding_mask,
                float('-inf')
            ).type_as(attn_scores)  # FP16 support: cast to float and back

        attn_scores = F.softmax(attn_scores, dim=0)  # srclen x bsz

        # sum weighted sources
        x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)

        x = F.tanh(self.output_proj(torch.cat((x, input), dim=1)))
        return x, attn_scores


class LSTMDecoder(FairseqIncrementalDecoder):
    """LSTM decoder."""
    def __init__(
        self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
        num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
        encoder_embed_dim=512, encoder_output_units=512,
        pretrained_embed=None,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.encoder_output_units = encoder_output_units
        assert encoder_output_units == hidden_size, \
            '{} {}'.format(encoder_output_units, hidden_size)
        # TODO another Linear layer if not equal

        self.layers = nn.ModuleList([
            LSTMCell(
                input_size=encoder_output_units + embed_dim if layer == 0 else hidden_size,
                hidden_size=hidden_size,
            )
            for layer in range(num_layers)
        ])
        self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
        if hidden_size != out_embed_dim:
            self.additional_fc = Linear(hidden_size, out_embed_dim)
        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
        encoder_out = encoder_out_dict['encoder_out']
        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out[:3]
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())

        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask)
            else:
                out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # project back to size of vocabulary
        if hasattr(self, 'additional_fc'):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = self.fc_out(x)

        return x, attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder_out'] = tuple(
            eo.index_select(1, new_order)
            for eo in encoder_out_dict['encoder_out']
        )
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.uniform(m.weight, -0.1, 0.1)
    nn.init.constant(m.weight[padding_idx], 0)
    return m


def LSTM(input_size, hidden_size, **kwargs):
    m = nn.LSTM(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if 'weight' in name or 'bias' in name:
            param.data.uniform_(-0.1, 0.1)
    return m


def LSTMCell(input_size, hidden_size, **kwargs):
    m = nn.LSTMCell(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if 'weight' in name or 'bias' in name:
            param.data.uniform_(-0.1, 0.1)
    return m
def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.uniform_(-0.1, 0.1)
    if bias:
        m.bias.data.uniform_(-0.1, 0.1)
    return m
@register_model_architecture('lstm', 'lstm')
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 1)
    args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False)
    args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
    args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', 1)
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
    args.decoder_attention = getattr(args, 'decoder_attention', '1')
    args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)


@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
    base_architecture(args)
    args.encoder_embed_dim = 256
    args.encoder_hidden_size = 256
    args.encoder_layers = 1
    args.encoder_bidirectional = False
    args.encoder_dropout_in = 0
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 256
    args.decoder_hidden_size = 256
    args.decoder_layers = 1
    args.decoder_out_embed_dim = 256
    args.decoder_attention = '1'
    args.decoder_dropout_in = 0


@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
    base_architecture(args)
    args.encoder_embed_dim = 1000
    args.encoder_hidden_size = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.encoder_bidirectional = False
    args.decoder_embed_dim = 1000
    args.decoder_hidden_size = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_attention = '1'
    args.decoder_dropout_out = 0
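
For a quick shape check of the AttentionLayer defined above, the following sketch (not part of lstm.py) runs it on random tensors; it assumes AttentionLayer and its Linear helper are imported from this module and that the PyTorch version matches the one this file targets (hence the Variable wrapper).

import torch
from torch.autograd import Variable

srclen, bsz, dim = 7, 2, 512
layer = AttentionLayer(input_embed_dim=dim, output_embed_dim=dim)

decoder_hidden = Variable(torch.randn(bsz, dim))        # query: bsz x input_embed_dim
encoder_outs = Variable(torch.randn(srclen, bsz, dim))  # encoder states: srclen x bsz x output_embed_dim

context, attn = layer(decoder_hidden, encoder_outs, encoder_padding_mask=None)
print(context.size())  # bsz x output_embed_dim: tanh-fused context vector
print(attn.size())     # srclen x bsz: softmax-normalized scores over source positions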
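A minimal sketch of building the registered model directly, assuming Dictionary is importable from fairseq.data in this fairseq version; the toy dictionaries and hyperparameters are hypothetical and only illustrate how base_architecture and LSTMModel.build_model fit together.

import argparse
from fairseq.data import Dictionary  # assumption: exported here in this fairseq version

def toy_dict(words):
    # hypothetical helper: a tiny vocabulary on top of the default special symbols
    d = Dictionary()
    for w in words:
        d.add_symbol(w)
    return d

src_dict = toy_dict(['hallo', 'welt'])
dst_dict = toy_dict(['hello', 'world'])

args = argparse.Namespace(dropout=0.1)  # --dropout default from add_args
base_architecture(args)                 # fill in the defaults registered for 'lstm'
model = LSTMModel.build_model(args, src_dict, dst_dict)
print(model)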