Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

fairseq_incremental_decoder.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. from . import FairseqDecoder
  8. class FairseqIncrementalDecoder(FairseqDecoder):
  9. """Base class for incremental decoders."""
  10. def __init__(self, dictionary):
  11. super().__init__(dictionary)
  12. def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
  13. raise NotImplementedError
  14. def reorder_incremental_state(self, incremental_state, new_order):
  15. """Reorder incremental state.
  16. This should be called when the order of the input has changed from the
  17. previous time step. A typical use case is beam search, where the input
  18. order changes between time steps based on the selection of beams.
  19. """
  20. def apply_reorder_incremental_state(module):
  21. if module != self and hasattr(module, 'reorder_incremental_state'):
  22. module.reorder_incremental_state(
  23. incremental_state,
  24. new_order,
  25. )
  26. self.apply(apply_reorder_incremental_state)
  27. def reorder_encoder_out(self, encoder_out, new_order):
  28. return encoder_out
  29. def set_beam_size(self, beam_size):
  30. """Sets the beam size in the decoder and all children."""
  31. if getattr(self, '_beam_size', -1) != beam_size:
  32. def apply_set_beam_size(module):
  33. if module != self and hasattr(module, 'set_beam_size'):
  34. module.set_beam_size(beam_size)
  35. self.apply(apply_set_beam_size)
  36. self._beam_size = beam_size
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...