sequence_scorer.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq import utils


class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(self, models):
        self.models = models
        self.pad = models[0].dst_dict.pad()
        # every ensemble member must share the same target-side padding index
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def score_batched_itr(self, data_itr, cuda=False, timer=None):
        """Iterate over a batched dataset and yield scored translations."""
        for sample in data_itr:
            s = utils.make_variable(sample, volatile=True, cuda=cuda)
            if timer is not None:
                timer.start()
            pos_scores, attn = self.score(s)
            if timer is not None:
                timer.stop(s['ntokens'])
            for i, id in enumerate(s['id'].data):
                # remove padding from ref
                src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                # scoring assumes a reference target is present
                tgt_len = ref.numel()
                pos_scores_i = pos_scores[i][:tgt_len]
                # sentence score = average per-token log-probability
                score_i = pos_scores_i.sum() / tgt_len
                attn_i = attn[i]
                # hard alignment from the attention argmax
                _, alignment = attn_i.max(dim=0)
                hypos = [{
                    'tokens': ref,
                    'score': score_i,
                    'attention': attn_i,
                    'alignment': alignment,
                    'positional_scores': pos_scores_i,
                }]
                # return results in the same format as SequenceGenerator
                yield id, src, ref, hypos

    def score(self, sample):
        """Score a batch of translations."""
        net_input = sample['net_input']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in self.models:
            with utils.maybe_no_grad():
                model.eval()
                encoder_out = model.encoder(
                    net_input['src_tokens'],
                    net_input['src_lengths'],
                )
                decoder_out = model.decoder(
                    net_input['prev_output_tokens'],
                    encoder_out,
                )
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out, log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        # average in probability space across the ensemble, then take logs
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        # pick out the log-probability of each reference target token
        avg_probs = avg_probs.gather(
            dim=2,
            index=sample['target'].data.unsqueeze(-1),
        )
        return avg_probs.squeeze(2), avg_attn
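
A note on what score() computes: ensemble probabilities are averaged in probability space and then logged in place, so the result is the log of the arithmetic mean of the models' probabilities, not the mean of their log-probabilities. The final gather then selects, for every batch element and target position, the log-probability of the reference token. A small self-contained illustration of that gather step, with toy numbers that are not from the file:

import torch

# toy probabilities for one sentence, two target positions, vocab of two:
# shape (bsz=1, tgt_len=2, vocab=2)
probs = torch.tensor([[[0.1, 0.9],
                       [0.7, 0.3]]])
target = torch.tensor([[1, 0]])  # reference token ids per position

# same pattern as score(): log, then gather the target column
pos_scores = probs.log().gather(dim=2, index=target.unsqueeze(-1)).squeeze(2)
# pos_scores == [[log 0.9, log 0.7]] -- one log-prob per reference token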
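
For context, a minimal usage sketch. Both `models` (a list of trained fairseq models sharing one target dictionary) and `itr` (a batched iterator yielding samples in the dict format consumed above) are assumed to come from the surrounding fairseq setup; neither is defined in this file:

# hypothetical inputs: `models` and `itr` are built elsewhere in fairseq
scorer = SequenceScorer(models)
scorer.cuda()  # optional: move every ensemble member to the GPU

for sample_id, src_tokens, ref_tokens, hypos in scorer.score_batched_itr(itr, cuda=True):
    hypo = hypos[0]  # exactly one "hypothesis" per sentence: the reference itself
    print(sample_id, hypo['score'], hypo['positional_scores'])

Each yielded `hypos` entry mirrors SequenceGenerator's output format, presumably so scoring and generation can share downstream reporting code.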