Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

cross_entropy.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import math
  8. import torch.nn.functional as F
  9. from fairseq import utils
  10. from . import FairseqCriterion, register_criterion
  11. @register_criterion('cross_entropy')
  12. class CrossEntropyCriterion(FairseqCriterion):
  13. def __init__(self, args, src_dict, dst_dict):
  14. super().__init__(args, src_dict, dst_dict)
  15. def forward(self, model, sample, reduce=True):
  16. """Compute the loss for the given sample.
  17. Returns a tuple with three elements:
  18. 1) the loss, as a Variable
  19. 2) the sample size, which is used as the denominator for the gradient
  20. 3) logging outputs to display while training
  21. """
  22. net_output = model(**sample['net_input'])
  23. lprobs = model.get_normalized_probs(net_output, log_probs=True)
  24. lprobs = lprobs.view(-1, lprobs.size(-1))
  25. target = model.get_targets(sample, net_output).view(-1)
  26. loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
  27. reduce=reduce)
  28. sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
  29. logging_output = {
  30. 'loss': utils.item(loss.data) if reduce else loss.data,
  31. 'ntokens': sample['ntokens'],
  32. 'sample_size': sample_size,
  33. }
  34. return loss, sample_size, logging_output
  35. @staticmethod
  36. def aggregate_logging_outputs(logging_outputs):
  37. """Aggregate logging outputs from data parallel training."""
  38. loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
  39. ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
  40. sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
  41. agg_output = {
  42. 'loss': loss_sum / sample_size / math.log(2),
  43. 'sample_size': sample_size,
  44. }
  45. if sample_size != ntokens:
  46. agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
  47. return agg_output
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...