Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

linearized_convolution.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch.nn.functional as F
  8. from fairseq import utils
  9. from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
    """An optimized version of nn.Conv1d.

    At training time, this module uses ConvTBC, which is an optimized version
    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
    one time step at a time) by replacing the convolutions with linear layers.
    Note that the input order changes from training to inference.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Lazily-built cache of the convolution weight flattened into a
        # (out_channels, kw * in_channels) matrix for the F.linear fast path;
        # see _get_linearized_weight().
        self._linearized_weight = None
        # A backward pass implies the weights may change, so drop the cached
        # linearized copy and rebuild it on next use.
        self.register_backward_hook(self._clear_linearized_weight)

    def forward(self, input, incremental_state=None):
        """
        Input:
            Time x Batch x Channel during training
            Batch x Time x Channel during inference

        Args:
            incremental_state: Used to buffer signal; if not None, then input is
                expected to contain a single frame. If the input order changes
                between time steps, call reorder_incremental_state.
        """
        if incremental_state is None:
            # Training / full-sequence path: plain ConvTBC forward
            # (input is Time x Batch x Channel here, per the docstring).
            output = super().forward(input)
            if self.kernel_size[0] > 1 and self.padding[0] > 0:
                # remove future timesteps added by padding
                output = output[:-self.padding[0], :, :]
            return output

        # Incremental (one-frame-at-a-time) path: apply the convolution as a
        # single linear layer over the last `kw` buffered frames.
        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            # NOTE(review): .data / volatile_variable suggest this was written
            # for pre-0.4 PyTorch Variable semantics (no autograd tracking on
            # the incremental path) — confirm against the project's torch version.
            input = input.data
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                # First step: allocate a zero-filled rolling buffer holding the
                # last `kw` input frames, and register it in incremental_state.
                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
                self._set_input_buffer(incremental_state, input_buffer)
            else:
                # shift buffer
                # In-place left-shift by one frame. This mutates the SAME tensor
                # object stored in incremental_state (aliasing), which is why
                # _set_input_buffer is not called again after the first step.
                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
            # append next input
            input_buffer[:, -1, :] = input[:, -1, :]
            input = utils.volatile_variable(input_buffer)
        with utils.maybe_no_grad():
            # Flatten the (bsz, kw, dim) window and apply the linearized weight.
            output = F.linear(input.view(bsz, -1), weight, self.bias)
        # One output frame per call: Batch x 1 x Channel.
        return output.view(bsz, 1, -1)

    def reorder_incremental_state(self, incremental_state, new_order):
        # Re-permute the buffered frames along the batch dimension so the
        # buffer stays aligned with the caller's new batch order (e.g. after
        # beam-search reordering between time steps).
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        # Fetch this module's rolling input buffer (or None if not yet set).
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')

    def _set_input_buffer(self, incremental_state, new_buffer):
        # Store the rolling input buffer under this module's key.
        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

    def _get_linearized_weight(self):
        # Build (and cache) the weight as a 2-D matrix usable by F.linear.
        # The two transposes rearrange the ConvTBC weight so that, as the
        # assert checks, the result is (out_channels, kw, in_channels) before
        # flattening to (out_channels, kw * in_channels).
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        # Backward-hook callback: invalidate the cached linearized weight.
        self._linearized_weight = None
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...