fairseq_criterion.py 1.3 KB

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):

    def __init__(self, args, src_dict, dst_dict):
        super().__init__()
        self.args = args
        self.padding_idx = dst_dict.pad()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        pass

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss, as a Variable
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError

    @staticmethod
    def grad_denom(sample_sizes):
        """Compute the gradient denominator for a set of sample sizes."""
        return sum(sample_sizes)
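
For context, a concrete criterion would subclass FairseqCriterion and implement forward and aggregate_logging_outputs. The sketch below is illustrative only: it assumes the model maps sample['net_input'] to a (batch, tgt_len, vocab) tensor of logits and that sample carries 'target' and 'ntokens' keys; the class name SimpleCrossEntropyCriterion and those assumptions are not part of this file.

import torch.nn.functional as F


class SimpleCrossEntropyCriterion(FairseqCriterion):
    """Illustrative token-level cross-entropy criterion built on the base class above."""

    def forward(self, model, sample, reduce=True):
        # Assumed interface: the model returns logits of shape (batch, tgt_len, vocab).
        logits = model(**sample['net_input'])
        target = sample['target'].view(-1)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            target,
            ignore_index=self.padding_idx,       # ignore padding positions
            reduction='sum' if reduce else 'none',
        )
        # Use the target token count as the gradient denominator (assumed key).
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.item() if reduce else loss.detach(),
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Sum per-worker logs so training can report a single averaged loss."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {'loss': loss_sum / max(sample_size, 1)}

The three return values of forward match the contract documented in the base class: the loss itself, the sample size used as the gradient denominator (see grad_denom), and a dict of logging outputs that data-parallel training later merges via aggregate_logging_outputs.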