fairseq_decoder.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn as nn
import torch.nn.functional as F


class FairseqDecoder(nn.Module):
    """Base class for decoders."""

    def __init__(self, dictionary):
        super().__init__()
        # Vocabulary (dictionary) associated with the decoder's output tokens.
        self.dictionary = dictionary

    def forward(self, prev_output_tokens, encoder_out):
        # Abstract: concrete decoders must implement the decoding step.
        raise NotImplementedError

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        # The first element of net_output holds the unnormalized scores;
        # cast to float32 so the softmax is computed in full precision.
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)

    def max_positions(self):
        """Maximum input length supported by the decoder."""
        raise NotImplementedError

    def upgrade_state_dict(self, state_dict):
        # Hook for adapting state dicts saved by older versions;
        # the base class returns the state dict unchanged.
        return state_dict
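`get_normalized_probs` applies `F.softmax` or `F.log_softmax` over the last (vocabulary) dimension of the decoder scores. The sketch below reproduces that arithmetic in plain Python for a single score vector, so it runs without PyTorch. The function name `normalized_probs` is hypothetical, and the max-subtraction trick shown is the standard way to keep the exponentials from overflowing; it is an illustration of the math, not PyTorch's actual kernel.

```python
import math
from typing import List


def normalized_probs(logits: List[float], log_probs: bool) -> List[float]:
    """Hypothetical helper: normalize one score vector like softmax / log_softmax."""
    # Subtract the maximum score before exponentiating so exp() cannot
    # overflow; the shift cancels out in the normalized result.
    m = max(logits)
    shifted = [x - m for x in logits]
    # Log of the normalizing constant (log-sum-exp of the shifted scores).
    log_z = math.log(sum(math.exp(x) for x in shifted))
    log_p = [x - log_z for x in shifted]
    if log_probs:
        return log_p
    # Probabilities are the exponentials of the log-probabilities.
    return [math.exp(x) for x in log_p]


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.1]
    probs = normalized_probs(scores, log_probs=False)
    print([round(p, 3) for p in probs])
    print(normalized_probs(scores, log_probs=True))
```

Returning log-probabilities directly (rather than taking `log` of the probabilities afterwards) avoids `log(0)` when a probability underflows, which is why a `log_probs` flag is useful for scoring and beam search.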
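The base `upgrade_state_dict` simply returns its input, so subclasses can override it to adapt checkpoints whose parameter names have changed. Below is a minimal sketch of such an override, assuming a hypothetical rename from `decoder.embed.` to `decoder.embed_tokens.`; in a real decoder this would be a method on a `FairseqDecoder` subclass, but a plain class (the hypothetical `LegacyKeyUpgrader`) is used here so the example runs without PyTorch.

```python
from typing import Dict


class LegacyKeyUpgrader:
    """Hypothetical stand-in for a decoder subclass overriding upgrade_state_dict."""

    # Hypothetical mapping from key prefixes used by an older checkpoint
    # format to the prefixes the current module expects.
    RENAMES = {"decoder.embed.": "decoder.embed_tokens."}

    def upgrade_state_dict(self, state_dict: Dict[str, object]) -> Dict[str, object]:
        for old_prefix, new_prefix in self.RENAMES.items():
            # Iterate over a snapshot of the keys because the dict is mutated.
            for key in list(state_dict.keys()):
                if key.startswith(old_prefix):
                    new_key = new_prefix + key[len(old_prefix):]
                    state_dict[new_key] = state_dict.pop(key)
        # Like the base class, return the (possibly modified) state dict.
        return state_dict


if __name__ == "__main__":
    old = {"decoder.embed.weight": "W", "decoder.layers.0.bias": "b"}
    print(LegacyKeyUpgrader().upgrade_state_dict(old))
```

Keys that already use the new prefix are left alone because the trailing dot in `decoder.embed.` does not match `decoder.embed_tokens.`, so running the upgrade twice is harmless.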