sequence_generator.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.models import FairseqIncrementalDecoder


class SequenceGenerator(object):
    def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
                 stop_early=True, normalize_scores=True, len_penalty=1,
                 unk_penalty=0, retain_dropout=False, sampling=False):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
        """
        self.models = models
        self.pad = models[0].dst_dict.pad()
        self.unk = models[0].dst_dict.unk()
        self.eos = models[0].dst_dict.eos()
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
        self.vocab_size = len(models[0].dst_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
        max_decoder_len -= 1  # we define maxlen not including the EOS marker
        self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.sampling = sampling
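
    # Illustrative usage sketch (not part of the original file): a generator is
    # typically built from an ensemble of fairseq models that share one target
    # dictionary (enforced by the asserts above). The `models` list below is
    # assumed to be provided by the caller:
    #
    #   generator = SequenceGenerator(models, beam_size=5, maxlen=200)
    #   generator.cuda()  # optional: move every ensemble member to the GPU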

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                             cuda=False, timer=None, prefix_size=0):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.make_variable(sample, volatile=True, cuda=cuda)
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with utils.maybe_no_grad():
                hypos = self.generate(
                    input['src_tokens'],
                    input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=int(maxlen_a*srclen + maxlen_b),
                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                yield id, src, ref, hypos[i]
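
    # Sketch of consuming the iterator above (illustrative, not part of the
    # original file; `generator` and the batch iterator `itr` are assumed to
    # exist in the caller's scope):
    #
    #   for sample_id, src_tokens, ref_tokens, hypos in \
    #           generator.generate_batched_itr(itr, cuda=True):
    #       best = hypos[0]  # hypotheses come sorted by descending score
    #       print(sample_id, best['score'], best['tokens'])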

    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        """Generate a batch of translations."""
        with utils.maybe_no_grad():
            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False
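
        # Worked example of the early-stopping check above (illustrative, not
        # part of the original file): with normalize_scores=True and maxlen=50,
        # if the worst finalized hypothesis scores -0.5 per token and the best
        # unfinalized cumulative score is -30, the optimistic bound is
        # -30 / 50 = -0.6, which is below -0.5, so no active beam can overtake
        # the finalized set and is_finished() returns True.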

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():
                    # remove padding tokens from attn scores
                    nonpad_idxs = src_tokens[sent].ne(self.pad)
                    hypo_attn = attn_clone[i][nonpad_idxs]
                    _, alignment = hypo_attn.max(dim=0)
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished
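
        # Illustrative note (not part of the original file) on the cum_unfin
        # bookkeeping above: finished sentences are dropped from the batch
        # between steps, so bbsz_idx indexes the *remaining* sentences and
        # cum_unfin maps them back to their original positions. For example,
        # with finished = [False, True, False, False], cum_unfin == [0, 1, 1],
        # so remaining indices 0, 1, 2 resolve to original sentences 0, 2, 3.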

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step - 1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams.resize_as_(cand_indices).fill_(0)
                elif self.sampling:
                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
                    exp_probs = probs.exp_().view(-1, self.vocab_size)
                    if step == 0:
                        # we exclude the first two vocab items, one of which is pad
                        torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
                        cand_indices.add_(2)
                    else:
                        torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
                        cand_indices.add_(2)
                    torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                    cand_scores.log_()
                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
                    if step == 0:
                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
                    else:
                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
                        # make scores cumulative
                        cand_scores.add_(
                            torch.gather(
                                scores[:, step - 1].view(bsz, beam_size), dim=1,
                                index=cand_beams,
                            )
                        )
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        probs.view(bsz, -1),
                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )
                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
                    cand_indices.fmod_(self.vocab_size)
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    probs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(
                    step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(bsz).type_as(cand_indices)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )
            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )
            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
            old_scores = scores
            scores = scores_buf
            scores_buf = old_scores
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized
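
    # Shape of the return value above (illustrative, not part of the original
    # file): `finalized` is a list with one entry per input sentence; each entry
    # holds at most beam_size hypothesis dicts, sorted by descending 'score',
    # with keys 'tokens', 'score', 'attention' (src_len x tgt_len), 'alignment'
    # and 'positional_scores'.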

    def _decode(self, tokens, encoder_outs, incremental_states):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            with utils.maybe_no_grad():
                if incremental_states[model] is not None:
                    decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
                else:
                    decoder_out = list(model.decoder(tokens, encoder_out))
                decoder_out[0] = decoder_out[0][:, -1, :]
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out, log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        return avg_probs, avg_attn
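
# Note on the ensembling above (illustrative, not part of the original file):
# _decode averages the models' output distributions in probability space and
# only then takes the log, i.e. for two models log((p1 + p2) / 2), rather than
# averaging log-probabilities; e.g. token probabilities 0.2 and 0.4 combine to
# log(0.3), not (log(0.2) + log(0.4)) / 2.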