Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

bleu.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import ctypes
  8. import math
  9. import torch
  10. try:
  11. from fairseq import libbleu
  12. except ImportError as e:
  13. import sys
  14. sys.stderr.write('ERROR: missing libbleu.so. run `python setup.py install`\n')
  15. raise e
  16. C = ctypes.cdll.LoadLibrary(libbleu.__file__)
  17. class BleuStat(ctypes.Structure):
  18. _fields_ = [
  19. ('reflen', ctypes.c_size_t),
  20. ('predlen', ctypes.c_size_t),
  21. ('match1', ctypes.c_size_t),
  22. ('count1', ctypes.c_size_t),
  23. ('match2', ctypes.c_size_t),
  24. ('count2', ctypes.c_size_t),
  25. ('match3', ctypes.c_size_t),
  26. ('count3', ctypes.c_size_t),
  27. ('match4', ctypes.c_size_t),
  28. ('count4', ctypes.c_size_t),
  29. ]
  30. class Scorer(object):
  31. def __init__(self, pad, eos, unk):
  32. self.stat = BleuStat()
  33. self.pad = pad
  34. self.eos = eos
  35. self.unk = unk
  36. self.reset()
  37. def reset(self, one_init=False):
  38. if one_init:
  39. C.bleu_one_init(ctypes.byref(self.stat))
  40. else:
  41. C.bleu_zero_init(ctypes.byref(self.stat))
  42. def add(self, ref, pred):
  43. if not isinstance(ref, torch.IntTensor):
  44. raise TypeError('ref must be a torch.IntTensor (got {})'
  45. .format(type(ref)))
  46. if not isinstance(pred, torch.IntTensor):
  47. raise TypeError('pred must be a torch.IntTensor(got {})'
  48. .format(type(pred)))
  49. # don't match unknown words
  50. rref = ref.clone()
  51. assert not rref.lt(0).any()
  52. rref[rref.eq(self.unk)] = -999
  53. rref = rref.contiguous().view(-1)
  54. pred = pred.contiguous().view(-1)
  55. C.bleu_add(
  56. ctypes.byref(self.stat),
  57. ctypes.c_size_t(rref.size(0)),
  58. ctypes.c_void_p(rref.data_ptr()),
  59. ctypes.c_size_t(pred.size(0)),
  60. ctypes.c_void_p(pred.data_ptr()),
  61. ctypes.c_int(self.pad),
  62. ctypes.c_int(self.eos))
  63. def score(self, order=4):
  64. psum = sum(math.log(p) if p > 0 else float('-Inf')
  65. for p in self.precision()[:order])
  66. return self.brevity() * math.exp(psum / order) * 100
  67. def precision(self):
  68. def ratio(a, b):
  69. return a / b if b > 0 else 0
  70. return [
  71. ratio(self.stat.match1, self.stat.count1),
  72. ratio(self.stat.match2, self.stat.count2),
  73. ratio(self.stat.match3, self.stat.count3),
  74. ratio(self.stat.match4, self.stat.count4),
  75. ]
  76. def brevity(self):
  77. r = self.stat.reflen / self.stat.predlen
  78. return min(1, math.exp(1 - r))
  79. def result_string(self, order=4):
  80. assert order <= 4, "BLEU scores for order > 4 aren't supported"
  81. fmt = 'BLEU{} = {:2.2f}, {:2.1f}'
  82. for _ in range(1, order):
  83. fmt += '/{:2.1f}'
  84. fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
  85. bleup = [p * 100 for p in self.precision()[:order]]
  86. return fmt.format(order, self.score(order=order), *bleup,
  87. self.brevity(), self.stat.predlen/self.stat.reflen,
  88. self.stat.predlen, self.stat.reflen)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...