- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
-
- import torch.nn.functional as F
-
- from fairseq import utils
-
- from .conv_tbc import ConvTBC
-
-
- class LinearizedConvolution(ConvTBC):
-     """An optimized version of nn.Conv1d.
-
-     At training time, this module uses ConvTBC, which is an optimized version
-     of Conv1d. At inference time, it optimizes incremental generation (i.e.,
-     one time step at a time) by replacing the convolutions with linear layers.
-     Note that the input order changes from training to inference.
-     """
-
-     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
-         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-         self._linearized_weight = None
-         self.register_backward_hook(self._clear_linearized_weight)
-
-     def forward(self, input, incremental_state=None):
-         """
-         Input:
-             Time x Batch x Channel during training
-             Batch x Time x Channel during inference
-
-         Args:
-             incremental_state: Used to buffer signal; if not None, then input is
-                 expected to contain a single frame. If the input order changes
-                 between time steps, call reorder_incremental_state.
-         """
-         if incremental_state is None:
-             output = super().forward(input)
-             if self.kernel_size[0] > 1 and self.padding[0] > 0:
-                 # remove future timesteps added by padding
-                 output = output[:-self.padding[0], :, :]
-             return output
-
-         # reshape weight
-         weight = self._get_linearized_weight()
-         kw = self.kernel_size[0]
-
-         bsz = input.size(0)  # input: bsz x len x dim
-         if kw > 1:
-             input = input.data
-             input_buffer = self._get_input_buffer(incremental_state)
-             if input_buffer is None:
-                 input_buffer = input.new(bsz, kw, input.size(2)).zero_()
-                 self._set_input_buffer(incremental_state, input_buffer)
-             else:
-                 # shift buffer
-                 input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
-             # append next input
-             input_buffer[:, -1, :] = input[:, -1, :]
-             input = utils.volatile_variable(input_buffer)
-         with utils.maybe_no_grad():
-             output = F.linear(input.view(bsz, -1), weight, self.bias)
-         return output.view(bsz, 1, -1)
-
-     def reorder_incremental_state(self, incremental_state, new_order):
-         input_buffer = self._get_input_buffer(incremental_state)
-         if input_buffer is not None:
-             input_buffer = input_buffer.index_select(0, new_order)
-             self._set_input_buffer(incremental_state, input_buffer)
-
-     def _get_input_buffer(self, incremental_state):
-         return utils.get_incremental_state(self, incremental_state, 'input_buffer')
-
-     def _set_input_buffer(self, incremental_state, new_buffer):
-         return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
-
-     def _get_linearized_weight(self):
-         if self._linearized_weight is None:
-             kw = self.kernel_size[0]
-             weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
-             assert weight.size() == (self.out_channels, kw, self.in_channels)
-             self._linearized_weight = weight.view(self.out_channels, -1)
-         return self._linearized_weight
-
-     def _clear_linearized_weight(self, *args):
-         self._linearized_weight = None
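The linearization this diff removes is worth seeing in isolation: flattening the last `kw` input frames (time-major, then channel) and reshaping the kernel the same way turns one convolution step into a single `F.linear` call. Below is a minimal standalone sketch, not part of fairseq: it uses a plain `nn.Conv1d` (weight layout `out x in x kw`) in place of `ConvTBC` (weight layout `kw x in x out`), so the `permute` stands in for the two transposes in `_get_linearized_weight`, and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
bsz, tlen, in_channels, out_channels, kw = 2, 10, 4, 6, 3

conv = nn.Conv1d(in_channels, out_channels, kw)   # stand-in for ConvTBC
x = torch.randn(bsz, in_channels, tlen)           # Conv1d layout: Batch x Channel x Time

# Reference: run the whole convolution, keep only the newest output step.
ref = conv(x)[:, :, -1]                                          # bsz x out_channels

# Linearized: flatten the last kw frames time-major, then channel, and
# reshape the kernel to match; for Conv1d's (out, in, kw) weight this
# permute plays the role of the transposes in _get_linearized_weight.
window = x[:, :, -kw:].transpose(1, 2).reshape(bsz, -1)          # bsz x (kw*in)
lin_w = conv.weight.permute(0, 2, 1).reshape(out_channels, -1)   # out x (kw*in)
out = F.linear(window, lin_w, conv.bias)

print(torch.allclose(ref, out, atol=1e-6))  # True
```

The reshaped weight only changes when `self.weight` does, which is why the module caches it in `_linearized_weight` and registers a backward hook to invalidate the cache after each gradient step.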
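The incremental path can be checked the same way. The loop below, again a standalone sketch under the same assumptions (plain `nn.Conv1d`, arbitrary sizes), mirrors the shift/append logic in `forward()`: a rolling `(bsz, kw, in_channels)` buffer of the last `kw` frames reproduces, one step at a time, the full padded convolution after the `output[:-self.padding[0]]` trim that removes future timesteps.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
bsz, tlen, in_channels, out_channels, kw = 2, 10, 4, 6, 3

conv = nn.Conv1d(in_channels, out_channels, kw)
x = torch.randn(bsz, in_channels, tlen)
lin_w = conv.weight.permute(0, 2, 1).reshape(out_channels, -1)  # as in the previous sketch

with torch.no_grad():  # mirrors utils.maybe_no_grad() in forward()
    buf = x.new_zeros(bsz, kw, in_channels)  # fresh zeroed history, like input_buffer
    steps = []
    for t in range(tlen):
        buf[:, :-1] = buf[:, 1:].clone()     # shift buffer
        buf[:, -1] = x[:, :, t]              # append next input
        steps.append(F.linear(buf.reshape(bsz, -1), lin_w, conv.bias))
    inc = torch.stack(steps, dim=2)          # bsz x out_channels x tlen

    # One-shot equivalent: pad with kw-1 frames and drop the trailing kw-1
    # outputs (the "future timesteps" the training path slices away).
    full = F.conv1d(x, conv.weight, conv.bias, padding=kw - 1)[:, :, :tlen]
    print(torch.allclose(inc, full, atol=1e-5))  # True
```

Per generated token this costs one `(bsz, kw*in) x (kw*in, out)` matrix multiply instead of re-running the convolution over the growing prefix, which is the point of linearizing during incremental decoding.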