1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from fairseq.data import LanguagePairDataset
- from fairseq.modules import (
- LearnedPositionalEmbedding, MultiheadAttention,
- SinusoidalPositionalEmbedding,
- )
- from . import (
- FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
- register_model, register_model_architecture,
- )
@register_model('transformer')
class TransformerModel(FairseqModel):
    """Encoder-decoder Transformer, registered under the name ``transformer``.

    Pairs a :class:`TransformerEncoder` with a :class:`TransformerDecoder`.
    Use :meth:`build_model` to construct both halves from parsed arguments.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', default=False, action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', default=False, action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', default=False, action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')

    @classmethod
    def build_model(cls, args, src_dict, dst_dict):
        """Build a new model instance.

        Args:
            args: parsed arguments; reads ``share_all_embeddings``,
                ``encoder_embed_dim`` and ``decoder_embed_dim``. When
                ``share_all_embeddings`` is set, ``args`` is mutated so that
                ``share_decoder_input_output_embed`` becomes ``True``.
            src_dict: source dictionary (must support ``len()`` and ``pad()``).
            dst_dict: target dictionary (same requirements).

        Returns:
            An instance of ``cls`` wrapping a freshly built encoder/decoder.

        Raises:
            RuntimeError: if ``--share-all-embeddings`` is requested but the
                dictionaries differ or the embedding dimensions do not match.
        """

        def build_embedding(dictionary, embed_dim):
            # One embedding row per dictionary entry; the pad symbol's index
            # is forwarded as the embedding's padding index.
            return Embedding(len(dictionary), embed_dim, dictionary.pad())

        if args.share_all_embeddings:
            if src_dict != dst_dict:
                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            # The very same module object is used on both sides.
            decoder_embed_tokens = encoder_embed_tokens
            # Sharing everything implies tying the decoder output projection too.
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)

        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
        # Instantiate via ``cls`` so subclasses calling build_model get an
        # instance of themselves rather than of the base class.
        return cls(encoder, decoder)
class TransformerEncoder(FairseqEncoder):
    """Transformer encoder.

    Scaled token embeddings plus positional embeddings, followed by a stack
    of ``args.encoder_layers`` :class:`TransformerEncoderLayer` modules.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout

        dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        # Token embeddings are multiplied by sqrt(embed_dim) in forward().
        self.embed_scale = math.sqrt(dim)
        self.embed_positions = PositionalEmbedding(
            1024, dim, self.padding_idx,
            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
            learned=args.encoder_learned_pos,
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

    def forward(self, src_tokens, src_lengths):
        """Encode ``src_tokens``.

        ``src_lengths`` is accepted but not used here.

        Returns:
            dict with ``'encoder_out'`` (T x B x C) and
            ``'encoder_padding_mask'`` (B x T, or ``None`` when the batch
            contains no padding symbols).
        """
        # scaled token embeddings + positional embeddings, then dropout
        x = self.embed_scale * self.embed_tokens(src_tokens)
        x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # mask of padding positions; dropped entirely when nothing is padded
        pad_mask = src_tokens.eq(self.padding_idx)
        encoder_padding_mask = pad_mask if pad_mask.any() else None

        # run the layer stack
        for encoder_layer in self.layers:
            x = encoder_layer(x, encoder_padding_mask)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Adjust positional-embedding entries of ``state_dict`` in place.

        Only acts when sinusoidal positional embeddings are in use: removes
        the ``weights`` entry and makes sure a ``_float_tensor`` entry exists.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            state_dict.pop('encoder.embed_positions.weights', None)
            float_key = 'encoder.embed_positions._float_tensor'
            if float_key not in state_dict:
                state_dict[float_key] = torch.FloatTensor()
        return state_dict
class TransformerDecoder(FairseqIncrementalDecoder):
    """Transformer decoder.

    Scaled token embeddings plus positional embeddings, a stack of
    ``args.decoder_layers`` :class:`TransformerDecoderLayer` modules, and a
    linear projection back to vocabulary size. The projection reuses the
    input embedding matrix when ``args.share_decoder_input_output_embed`` is
    set; otherwise a separate ``embed_out`` parameter is created.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        # Token embeddings are multiplied by sqrt(embed_dim) in forward().
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            1024, embed_dim, padding_idx,
            left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
            learned=args.decoder_learned_pos,
        )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args)
            for _ in range(args.decoder_layers)
        ])

        if not self.share_input_output_embed:
            # Separate output projection matrix of shape (vocab, embed_dim).
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """Decode one batch.

        Args:
            prev_output_tokens: previous target tokens (B x T).
            encoder_out: dict with ``'encoder_out'`` and
                ``'encoder_padding_mask'`` entries, as returned by the encoder.
            incremental_state: when not ``None``, only the last token and its
                position are fed through the network.

        Returns:
            Tuple ``(x, attn)``: ``x`` is the vocabulary projection
            (B x T x V) and ``attn`` is whatever the last decoder layer
            returned as attention (``None`` if there are no layers).
        """
        # embed positions (computed on the full prefix, then trimmed below)
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        )

        if incremental_state is not None:
            # incremental decoding: only the newest step is processed
            prev_output_tokens = prev_output_tokens[:, -1:]
            positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # decoder layers; ``attn`` is pre-bound so a zero-layer decoder does
        # not hit an unbound local at the return statement
        attn = None
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out['encoder_out'],
                encoder_out['encoder_padding_mask'],
                incremental_state,
            )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

        return x, attn

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        """Reorder the encoder padding mask along dim 0 by ``new_order``.

        The dict is modified in place and returned.
        """
        # NOTE(review): only 'encoder_padding_mask' is reordered here;
        # 'encoder_out' is left as is. Presumably this relies on the encoder
        # attention caching its keys/values (static_kv) -- confirm before
        # using this without incremental state.
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Adjust positional-embedding entries of ``state_dict`` in place.

        Only acts when sinusoidal positional embeddings are in use: removes
        the ``weights`` entry and makes sure a ``_float_tensor`` entry exists.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder.embed_positions.weights' in state_dict:
                del state_dict['decoder.embed_positions.weights']
            if 'decoder.embed_positions._float_tensor' not in state_dict:
                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
        return state_dict
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block: self-attention followed by a two-layer FFN.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: dropout -> add residual -> layernorm.
    In the tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    dropout -> add residual.
    We default to the approach in the paper, but the tensor2tensor approach can
    be enabled by setting `normalize_before=True`.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.encoder_normalize_before

        ffn_dim = args.encoder_ffn_embed_dim
        self.fc1 = Linear(self.embed_dim, ffn_dim)
        self.fc2 = Linear(ffn_dim, self.embed_dim)
        # index 0: self-attention sub-block, index 1: FFN sub-block
        self.layer_norms = nn.ModuleList(
            [LayerNorm(self.embed_dim) for _ in range(2)]
        )

    def forward(self, x, encoder_padding_mask):
        """Apply self-attention and the FFN, each wrapped in a residual connection."""
        # --- self-attention sub-block ---
        shortcut = x
        x = self.maybe_layer_norm(0, x, before=True)
        x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(0, shortcut + x, after=True)

        # --- position-wise feed-forward sub-block ---
        shortcut = x
        x = self.maybe_layer_norm(1, x, before=True)
        hidden = F.relu(self.fc1(x))
        hidden = F.dropout(hidden, p=self.relu_dropout, training=self.training)
        x = self.fc2(hidden)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(1, shortcut + x, after=True)
        return x

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` at the configured position, else return ``x``.

        Exactly one of ``before``/``after`` must be true. The norm is applied
        at the ``before`` call site when ``normalize_before`` is set and at
        the ``after`` call site otherwise.
        """
        assert before ^ after
        if not (after ^ self.normalize_before):
            return x
        return self.layer_norms[i](x)
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block: self-attention, encoder attention, then a two-layer FFN."""

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        num_heads = args.decoder_attention_heads
        self.self_attn = MultiheadAttention(
            self.embed_dim, num_heads,
            dropout=args.attention_dropout,
        )
        self.encoder_attn = MultiheadAttention(
            self.embed_dim, num_heads,
            dropout=args.attention_dropout,
        )

        ffn_dim = args.decoder_ffn_embed_dim
        self.fc1 = Linear(self.embed_dim, ffn_dim)
        self.fc2 = Linear(ffn_dim, self.embed_dim)
        # index 0: self-attention, 1: encoder attention, 2: FFN
        self.layer_norms = nn.ModuleList(
            [LayerNorm(self.embed_dim) for _ in range(3)]
        )

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
        """Run the three sub-blocks, each wrapped in a residual connection.

        Returns:
            Tuple ``(x, attn)`` where ``attn`` is the second value returned
            by the encoder-attention module.
        """
        # --- masked self-attention ---
        shortcut = x
        x = self.maybe_layer_norm(0, x, before=True)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            need_weights=False,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(0, shortcut + x, after=True)

        # --- attention over the encoder output ---
        shortcut = x
        x = self.maybe_layer_norm(1, x, before=True)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(1, shortcut + x, after=True)

        # --- position-wise feed-forward ---
        shortcut = x
        x = self.maybe_layer_norm(2, x, before=True)
        hidden = F.relu(self.fc1(x))
        hidden = F.dropout(hidden, p=self.relu_dropout, training=self.training)
        x = self.fc2(hidden)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(2, shortcut + x, after=True)
        return x, attn

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` at the configured position, else return ``x``.

        Exactly one of ``before``/``after`` must be true. The norm is applied
        at the ``before`` call site when ``normalize_before`` is set and at
        the ``after`` call site otherwise.
        """
        assert before ^ after
        if not (after ^ self.normalize_before):
            return x
        return self.layer_norms[i](x)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with N(0, embedding_dim**-0.5) weights.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        padding_idx: row index reserved for padding, or ``None`` for none.

    Returns:
        The initialised ``nn.Embedding`` module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        # The normal init above overwrites the zero row nn.Embedding set up
        # for the padding symbol, so zero it again.
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
def LayerNorm(embedding_dim):
    """Return an ``nn.LayerNorm`` over a last dimension of size ``embedding_dim``."""
    return nn.LayerNorm(embedding_dim)
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        bias: whether the layer has a bias term.

    Returns:
        The initialised ``nn.Linear`` module.
    """
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    # With bias=False the module has no bias tensor to initialise.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.)
    return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    """Build a positional embedding module.

    Args:
        num_embeddings: table size for the learned variant; passed as
            ``init_size`` to the sinusoidal variant.
        embedding_dim: size of each position vector.
        padding_idx: index of the padding symbol.
        left_pad: forwarded unchanged to the embedding class.
        learned: build a ``LearnedPositionalEmbedding`` when true, otherwise
            a ``SinusoidalPositionalEmbedding``.

    Returns:
        The constructed positional embedding module.
    """
    if learned:
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        # keep the padding position's vector at zero
        nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
    return m
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
    """Fill in base Transformer hyperparameters that are not yet set on ``args``.

    Attributes already present on ``args`` are kept. Decoder embedding and
    FFN sizes fall back to the (possibly just defaulted) encoder sizes.
    """
    encoder_defaults = (
        ('encoder_embed_dim', 512),
        ('encoder_ffn_embed_dim', 2048),
        ('encoder_layers', 6),
        ('encoder_attention_heads', 8),
    )
    for attr, fallback in encoder_defaults:
        setattr(args, attr, getattr(args, attr, fallback))

    # Built after the encoder values are final, since two of these mirror them.
    decoder_defaults = (
        ('decoder_embed_dim', args.encoder_embed_dim),
        ('decoder_ffn_embed_dim', args.encoder_ffn_embed_dim),
        ('decoder_layers', 6),
        ('decoder_attention_heads', 8),
    )
    for attr, fallback in decoder_defaults:
        setattr(args, attr, getattr(args, attr, fallback))
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
    """Small preset: 256-dim embeddings, 512-dim FFN, 3 layers, 4 heads.

    Each preset value is applied only when the attribute is not already set
    on ``args`` (the same convention as ``base_architecture``), so explicitly
    supplied values are no longer silently overwritten.
    """
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 3)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', 3)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    # Presets go first; base_architecture then fills anything still missing.
    base_architecture(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
    """Preset identical to the base architecture (512 / 2048 / 6 layers / 8 heads).

    The preset values are exactly the ``base_architecture`` defaults, so this
    simply delegates. Values already set on ``args`` are kept instead of
    being overwritten.
    """
    base_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
    """Big preset: 1024-dim embeddings, 4096-dim FFN, 6 layers, 16 heads.

    Each preset value is applied only when the attribute is not already set
    on ``args`` (the same convention as ``base_architecture``), so explicitly
    supplied values are no longer silently overwritten.
    """
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    # Presets go first; base_architecture then fills anything still missing.
    base_architecture(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
    """Big Vaswani preset with attention dropout forced to 0.1."""
    transformer_vaswani_wmt_en_de_big(args)
    # Assigned unconditionally, so it replaces whatever value args carried.
    setattr(args, 'attention_dropout', 0.1)
# default parameters used in tensor2tensor implementation
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
    """Big Vaswani preset with pre-layernorm blocks and 0.1 attention/ReLU dropout."""
    transformer_vaswani_wmt_en_de_big(args)
    # All four are assigned unconditionally, replacing any value on args.
    forced = (
        ('encoder_normalize_before', True),
        ('decoder_normalize_before', True),
        ('attention_dropout', 0.1),
        ('relu_dropout', 0.1),
    )
    for attr, value in forced:
        setattr(args, attr, value)
|