fairseq_optimizer.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim


class FairseqOptimizer(object):

    def __init__(self, args, params):
        super().__init__()
        self.args = args
        self.params = params

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        pass

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, '_optimizer'):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
        return self._optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]['lr']

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self.optimizer_config)

    def step(self, closure=None):
        """Performs a single optimization step."""
        return self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        return self.optimizer.zero_grad()
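
As written, FairseqOptimizer is an abstract wrapper: a concrete subclass is expected to build a torch.optim.Optimizer, store it on self._optimizer, and implement optimizer_config. The following is a minimal sketch of such a subclass; the class name, the args attributes it reads, and their defaults are illustrative assumptions, not part of this file.

class SketchSGD(FairseqOptimizer):
    """Hypothetical subclass wrapping torch.optim.SGD, for illustration only."""

    def __init__(self, args, params):
        super().__init__(args, params)
        # The base class's `optimizer` property looks for the wrapped
        # torch.optim.Optimizer on self._optimizer.
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # These kwargs override whatever was saved in a checkpoint (see
        # load_state_dict above), so resuming with different command-line
        # args, e.g. a new learning rate, takes effect immediately.
        # The attribute names on self.args are assumed for this sketch.
        return {
            'lr': self.args.lr,
            'momentum': getattr(self.args, 'momentum', 0.0),
            'weight_decay': getattr(self.args, 'weight_decay', 0.0),
        }

With this pattern in place, get_lr, set_lr, state_dict, step, and zero_grad from the base class all operate on the wrapped SGD instance without further changes.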