fp16_trainer.py 4.9 KB

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network on multiple GPUs.
"""

import torch

from fairseq import optim
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer


class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # detect inf and nan
        if grad_norm == float('inf') or grad_norm != grad_norm:
            return True
        return False


class FP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, model, criterion):
        super().__init__(args, model, criterion)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

    def _build_optimizer(self):
        # create FP32 copy of parameters and grads
        params = [p for p in self.model.parameters() if p.requires_grad]
        total_param_size = sum(p.data.numel() for p in params)
        self.fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
            offset += numel
        self.fp32_params = torch.nn.Parameter(self.fp32_params)
        self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

        # create optimizer using the copied FP32 params
        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    def zero_grad(self):
        # zero both the FP16 and FP32 grads
        self.model.zero_grad()      # FP16
        self.optimizer.zero_grad()  # FP32

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        return super()._backward(loss)

    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        # all-reduce and rescale gradients
        grad_norm = super()._all_reduce_and_rescale(grad_denom)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm

    def _get_flat_grads(self, out=None):
        if out is None:
            out = self.fp32_params.grad
        return super()._get_flat_grads(out)

    def _set_flat_grads(self, new_grads):
        # no-op
        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()

    def _opt(self):
        # take an optimization step using the FP32 params and grads
        super()._opt()

        # copy FP32 params back into FP16 model
        offset = 0
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            numel = p.data.numel()
            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
            offset += numel
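The file leaves the loss-scale dynamics implicit, so here is a minimal sketch (not part of the original file, assuming the DynamicLossScaler class above is in scope) of how the scale evolves: it is halved on every overflow and doubled again every scale_window iterations after the last overflow. The tiny scale_window and the simulated overflows are assumptions chosen only to keep the demo short.

scaler = DynamicLossScaler(init_scale=2.**7, scale_factor=2., scale_window=4)

for step in range(12):
    # pretend that steps 3 and 4 blow up into inf gradients
    grad_norm = float('inf') if step in (3, 4) else 1.0
    overflow = DynamicLossScaler.has_overflow(grad_norm)
    scaler.update_scale(overflow)
    print(step, overflow, scaler.loss_scale)

# the scale stays at 128, drops to 64 and then 32 on the two overflows,
# and recovers to 64 at step 8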
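The FP32 master-copy pattern from _build_optimizer() and _opt() can also be exercised outside of fairseq. Below is a self-contained sketch, assuming plain PyTorch, a throwaway Linear model, a fixed loss scale, an SGD optimizer and an available CUDA device; none of these names come from the file above, and batch-size normalization of the gradients is omitted.

import torch

model = torch.nn.Linear(16, 4).cuda().half()           # FP16 working copy
params = [p for p in model.parameters() if p.requires_grad]

# flat FP32 master copy of the parameters, mirroring _build_optimizer()
total = sum(p.data.numel() for p in params)
fp32_params = torch.zeros(total, dtype=torch.float32, device='cuda')
offset = 0
for p in params:
    n = p.data.numel()
    fp32_params[offset:offset + n].copy_(p.data.view(-1))
    offset += n
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = torch.zeros_like(fp32_params.data)

optimizer = torch.optim.SGD([fp32_params], lr=0.1)
loss_scale = 2.**7

x = torch.randn(8, 16, device='cuda', dtype=torch.float16)
target = torch.randn(8, 4, device='cuda')

# forward/backward in FP16 with a scaled loss, as in _backward()
out = model(x)
loss = torch.nn.functional.mse_loss(out.float(), target) * loss_scale

model.zero_grad()         # FP16 grads
fp32_params.grad.zero_()  # flat FP32 grad buffer
loss.backward()

# gather the FP16 grads into the FP32 buffer and undo the loss scale,
# the role played by _get_flat_grads() and _all_reduce_and_rescale()
offset = 0
for p in params:
    n = p.grad.data.numel()
    fp32_params.grad[offset:offset + n].copy_(p.grad.data.view(-1))
    offset += n
fp32_params.grad.div_(loss_scale)

# step in FP32, then copy the master weights back into the FP16 model,
# which is what _opt() does
optimizer.step()
offset = 0
for p in params:
    n = p.data.numel()
    p.data.copy_(fp32_params.data[offset:offset + n].view_as(p.data))
    offset += n

A full training step would also check the unscaled gradient norm for inf/nan before calling optimizer.step() and feed that result into DynamicLossScaler.update_scale, which is exactly what _all_reduce_and_rescale() adds on top of this sketch.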

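Finally, note that _all_reduce_and_rescale() raises OverflowError rather than silently skipping the update, so the surrounding training loop is expected to catch it. One plausible caller-side sketch is below; trainer, sample and the train_step name are placeholders assumed for illustration, not something this file defines.

# hypothetical caller-side handling of the OverflowError raised above
try:
    log_output = trainer.train_step(sample)
except OverflowError as e:
    # the scaler has already lowered its loss scale; skip this update
    print('| WARNING: overflow detected, ' + str(e))
    trainer.zero_grad()
    log_output = None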
