Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

beamable_mm.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch
  8. import torch.nn as nn
  9. class BeamableMM(nn.Module):
  10. """This module provides an optimized MM for beam decoding with attention.
  11. It leverage the fact that the source-side of the input is replicated beam
  12. times and the target-side of the input is of width one. This layer speeds up
  13. inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
  14. with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
  15. """
  16. def __init__(self, beam_size=None):
  17. super(BeamableMM, self).__init__()
  18. self.beam_size = beam_size
  19. def forward(self, input1, input2):
  20. if (
  21. not self.training and # test mode
  22. self.beam_size is not None and # beam size is set
  23. input1.dim() == 3 and # only support batched input
  24. input1.size(1) == 1 # single time step update
  25. ):
  26. bsz, beam = input1.size(0), self.beam_size
  27. # bsz x 1 x nhu --> bsz/beam x beam x nhu
  28. input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
  29. # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
  30. input2 = input2.unfold(0, beam, beam)[:, :, :, 0]
  31. # use non batched operation if bsz = beam
  32. if input1.size(0) == 1:
  33. output = torch.mm(input1[0, :, :], input2[0, :, :])
  34. else:
  35. output = input1.bmm(input2)
  36. return output.view(bsz, 1, -1)
  37. else:
  38. return input1.bmm(input2)
  39. def set_beam_size(self, beam_size):
  40. self.beam_size = beam_size
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...