Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

fairseq_model.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch.nn as nn
  8. from . import FairseqDecoder, FairseqEncoder
  9. class FairseqModel(nn.Module):
  10. """Base class for encoder-decoder models."""
  11. def __init__(self, encoder, decoder):
  12. super().__init__()
  13. self.encoder = encoder
  14. self.decoder = decoder
  15. assert isinstance(self.encoder, FairseqEncoder)
  16. assert isinstance(self.decoder, FairseqDecoder)
  17. self.src_dict = encoder.dictionary
  18. self.dst_dict = decoder.dictionary
  19. assert self.src_dict.pad() == self.dst_dict.pad()
  20. assert self.src_dict.eos() == self.dst_dict.eos()
  21. assert self.src_dict.unk() == self.dst_dict.unk()
  22. self._is_generation_fast = False
  23. @staticmethod
  24. def add_args(parser):
  25. """Add model-specific arguments to the parser."""
  26. pass
  27. @classmethod
  28. def build_model(cls, args, src_dict, dst_dict):
  29. """Build a new model instance."""
  30. raise NotImplementedError
  31. def forward(self, src_tokens, src_lengths, prev_output_tokens):
  32. encoder_out = self.encoder(src_tokens, src_lengths)
  33. decoder_out = self.decoder(prev_output_tokens, encoder_out)
  34. return decoder_out
  35. def get_normalized_probs(self, net_output, log_probs):
  36. """Get normalized probabilities (or log probs) from a net's output."""
  37. return self.decoder.get_normalized_probs(net_output, log_probs)
  38. def get_targets(self, sample, net_output):
  39. """Get targets from either the sample or the net's output."""
  40. return sample['target']
  41. def max_encoder_positions(self):
  42. """Maximum input length supported by the encoder."""
  43. return self.encoder.max_positions()
  44. def max_decoder_positions(self):
  45. """Maximum output length supported by the decoder."""
  46. return self.decoder.max_positions()
  47. def load_state_dict(self, state_dict, strict=True):
  48. """Copies parameters and buffers from state_dict into this module and
  49. its descendants.
  50. Overrides the method in nn.Module; compared with that method this
  51. additionally "upgrades" state_dicts from old checkpoints.
  52. """
  53. state_dict = self.upgrade_state_dict(state_dict)
  54. super().load_state_dict(state_dict, strict)
  55. def upgrade_state_dict(self, state_dict):
  56. state_dict = self.encoder.upgrade_state_dict(state_dict)
  57. state_dict = self.decoder.upgrade_state_dict(state_dict)
  58. return state_dict
  59. def make_generation_fast_(self, **kwargs):
  60. """Optimize model for faster generation."""
  61. if self._is_generation_fast:
  62. return # only apply once
  63. self._is_generation_fast = True
  64. # remove weight norm from all modules in the network
  65. def apply_remove_weight_norm(module):
  66. try:
  67. nn.utils.remove_weight_norm(module)
  68. except ValueError: # this module didn't have weight norm
  69. return
  70. self.apply(apply_remove_weight_norm)
  71. def apply_make_generation_fast_(module):
  72. if module != self and hasattr(module, 'make_generation_fast_'):
  73. module.make_generation_fast_(**kwargs)
  74. self.apply(apply_make_generation_fast_)
  75. def train(mode):
  76. if mode:
  77. raise RuntimeError('cannot train after make_generation_fast')
  78. # this model should no longer be used for training
  79. self.eval()
  80. self.train = train
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...