Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer.py 13 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. """
  8. Train a network across multiple GPUs.
  9. """
  10. from collections import defaultdict, OrderedDict
  11. from itertools import chain
  12. import torch
  13. from fairseq import distributed_utils, optim, utils
  14. from fairseq.meters import AverageMeter, TimeMeter
  15. from fairseq.optim import lr_scheduler
  16. class Trainer(object):
  17. """Main class for data parallel training.
  18. This class supports data parallel training, where multiple workers each
  19. have a full model replica and gradients are accumulated synchronously via
  20. torch.distributed.all_reduce.
  21. """
  22. def __init__(self, args, model, criterion):
  23. if not torch.cuda.is_available():
  24. raise NotImplementedError('Training on CPU is not supported')
  25. self.args = args
  26. # copy model and criterion to current device
  27. self.model = model.cuda()
  28. self.criterion = criterion.cuda()
  29. # initialize optimizer and LR scheduler
  30. self._build_optimizer()
  31. # initialize meters
  32. self.meters = OrderedDict()
  33. self.meters['train_loss'] = AverageMeter()
  34. self.meters['train_nll_loss'] = AverageMeter()
  35. self.meters['valid_loss'] = AverageMeter()
  36. self.meters['valid_nll_loss'] = AverageMeter()
  37. self.meters['wps'] = TimeMeter() # words per second
  38. self.meters['ups'] = TimeMeter() # updates per second
  39. self.meters['wpb'] = AverageMeter() # words per batch
  40. self.meters['bsz'] = AverageMeter() # sentences per batch
  41. self.meters['gnorm'] = AverageMeter() # gradient norm
  42. self.meters['clip'] = AverageMeter() # % of updates clipped
  43. self.meters['oom'] = AverageMeter() # out of memory
  44. self.meters['wall'] = TimeMeter() # wall time in seconds
  45. self._buffered_stats = defaultdict(lambda: [])
  46. self._flat_grads = None
  47. self._num_updates = 0
  48. self._optim_history = None
  49. def _build_optimizer(self):
  50. self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
  51. self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
  52. def save_checkpoint(self, filename, extra_state):
  53. """Save all training state in a checkpoint file."""
  54. if distributed_utils.is_master(self.args): # only save one checkpoint
  55. utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
  56. self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
  57. def load_checkpoint(self, filename):
  58. """Load all training state from a checkpoint file."""
  59. extra_state, self._optim_history, last_optim_state = \
  60. utils.load_model_state(filename, self.model)
  61. if last_optim_state is not None:
  62. # rebuild optimizer after loading model, since params may have changed
  63. self._build_optimizer()
  64. # only reload optimizer and lr_scheduler if they match
  65. last_optim = self._optim_history[-1]
  66. if last_optim['criterion_name'] == self.criterion.__class__.__name__:
  67. self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
  68. if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
  69. self.optimizer.load_state_dict(last_optim_state)
  70. self._num_updates = last_optim['num_updates']
  71. return extra_state
  72. def train_step(self, sample, update_params=True):
  73. """Do forward, backward and parameter update."""
  74. sample = self._prepare_sample(sample, volatile=False)
  75. # forward and backward pass
  76. loss, sample_size, logging_output, oom_fwd = self._forward(sample)
  77. oom_bwd = self._backward(loss)
  78. # buffer stats and logging outputs
  79. self._buffered_stats['sample_sizes'].append(sample_size)
  80. self._buffered_stats['logging_outputs'].append(logging_output)
  81. self._buffered_stats['ooms_fwd'].append(oom_fwd)
  82. self._buffered_stats['ooms_bwd'].append(oom_bwd)
  83. # update parameters
  84. if update_params:
  85. # gather logging outputs from all replicas
  86. sample_sizes = self._buffered_stats['sample_sizes']
  87. logging_outputs = self._buffered_stats['logging_outputs']
  88. ooms_fwd = self._buffered_stats['ooms_fwd']
  89. ooms_bwd = self._buffered_stats['ooms_bwd']
  90. if self.args.distributed_world_size > 1:
  91. sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
  92. lambda l: list(chain.from_iterable(l)),
  93. zip(*distributed_utils.all_gather_list(
  94. (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
  95. ))
  96. )
  97. ooms_fwd = sum(ooms_fwd)
  98. ooms_bwd = sum(ooms_bwd)
  99. # aggregate stats and logging outputs
  100. ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
  101. nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
  102. agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
  103. grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
  104. try:
  105. # all-reduce and rescale gradients, then take an optimization step
  106. grad_norm = self._all_reduce_and_rescale(grad_denom)
  107. self._opt()
  108. # update meters
  109. self.meters['wps'].update(ntokens)
  110. self.meters['ups'].update(1.)
  111. self.meters['wpb'].update(ntokens)
  112. self.meters['bsz'].update(nsentences)
  113. if grad_norm is not None:
  114. self.meters['gnorm'].update(grad_norm)
  115. self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
  116. self.meters['oom'].update(ooms_fwd + ooms_bwd)
  117. # update loss meters for training
  118. if 'loss' in agg_logging_output:
  119. self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
  120. # criterions can optionally log the NLL loss too
  121. if 'nll_loss' in agg_logging_output:
  122. self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
  123. except OverflowError as e:
  124. self.zero_grad()
  125. print('| WARNING: overflow detected, ' + str(e))
  126. self.clear_buffered_stats()
  127. return agg_logging_output
  128. else:
  129. return None # buffering updates
  130. def _forward(self, sample, eval=False):
  131. # prepare model and optimizer
  132. if eval:
  133. self.model.eval()
  134. else:
  135. self.model.train()
  136. loss = None
  137. sample_size = 0
  138. logging_output = {
  139. 'ntokens': sample['ntokens'] if sample is not None else 0,
  140. 'nsentences': sample['target'].size(0) if sample is not None else 0,
  141. }
  142. oom = 0
  143. if sample is not None:
  144. try:
  145. with utils.maybe_no_grad(eval):
  146. # calculate loss and sample size
  147. loss, sample_size, logging_output_ = self.criterion(self.model, sample)
  148. logging_output.update(logging_output_)
  149. except RuntimeError as e:
  150. if not eval and 'out of memory' in str(e):
  151. print('| WARNING: ran out of memory, skipping batch')
  152. oom = 1
  153. loss = None
  154. else:
  155. raise e
  156. return loss, sample_size, logging_output, oom
  157. def _backward(self, loss):
  158. oom = 0
  159. if loss is not None:
  160. try:
  161. # backward pass
  162. loss.backward()
  163. except RuntimeError as e:
  164. if 'out of memory' in str(e):
  165. print('| WARNING: ran out of memory, skipping batch')
  166. oom = 1
  167. self.zero_grad()
  168. else:
  169. raise e
  170. return oom
  171. def _all_reduce_and_rescale(self, grad_denom):
  172. # flatten grads into a single buffer and all-reduce
  173. flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
  174. if self.args.distributed_world_size > 1:
  175. torch.distributed.all_reduce(flat_grads)
  176. # rescale and clip gradients
  177. flat_grads.div_(grad_denom)
  178. grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
  179. # copy grads back into model parameters
  180. self._set_flat_grads(flat_grads)
  181. return grad_norm
  182. def _get_grads(self):
  183. grads = []
  184. for name, p in self.model.named_parameters():
  185. if not p.requires_grad:
  186. continue
  187. if p.grad is None:
  188. raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
  189. 'Use the param in the forward pass or set requires_grad=False')
  190. grads.append(p.grad.data)
  191. return grads
  192. def _get_flat_grads(self, out=None):
  193. grads = self._get_grads()
  194. if out is None:
  195. grads_size = sum(g.numel() for g in grads)
  196. out = grads[0].new(grads_size).zero_()
  197. offset = 0
  198. for g in grads:
  199. numel = g.numel()
  200. out[offset:offset+numel].copy_(g.view(-1))
  201. offset += numel
  202. return out[:offset]
  203. def _set_flat_grads(self, new_grads):
  204. grads = self._get_grads()
  205. offset = 0
  206. for g in grads:
  207. numel = g.numel()
  208. g.copy_(new_grads[offset:offset+numel].view_as(g))
  209. offset += numel
  210. def _opt(self):
  211. # take an optimization step
  212. self.optimizer.step()
  213. self.zero_grad()
  214. self._num_updates += 1
  215. # update learning rate
  216. self.lr_scheduler.step_update(self._num_updates)
  217. def valid_step(self, sample):
  218. """Do forward pass in evaluation mode."""
  219. sample = self._prepare_sample(sample, volatile=True)
  220. # forward pass
  221. _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
  222. assert not oom_fwd, 'Ran out of memory during validation'
  223. # gather logging outputs from all GPUs
  224. if self.args.distributed_world_size > 1:
  225. sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
  226. (sample_size, logging_output)
  227. ))
  228. else:
  229. sample_sizes = [sample_size]
  230. logging_outputs = [logging_output]
  231. # aggregate stats and logging outputs
  232. ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
  233. grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
  234. agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
  235. # update loss meters for validation
  236. if 'loss' in agg_logging_output:
  237. self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
  238. # criterions can optionally log the NLL loss too
  239. if 'nll_loss' in agg_logging_output:
  240. self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
  241. return agg_logging_output
  242. def dummy_train_step(self, dummy_batch):
  243. """Dummy training step for warming caching allocator."""
  244. self.train_step(dummy_batch, update_params=False)
  245. self.zero_grad()
  246. self.clear_buffered_stats()
  247. def zero_grad(self):
  248. self.optimizer.zero_grad()
  249. def clear_buffered_stats(self):
  250. self._buffered_stats.clear()
  251. def lr_step(self, epoch, val_loss=None):
  252. """Adjust the learning rate based on the validation loss."""
  253. return self.lr_scheduler.step(epoch, val_loss)
  254. def get_lr(self):
  255. """Get the current learning rate."""
  256. return self.optimizer.get_lr()
  257. def get_model(self):
  258. """Get the model replica."""
  259. return self.model
  260. def get_meter(self, name):
  261. """Get a specific meter by name."""
  262. if name not in self.meters:
  263. return None
  264. return self.meters[name]
  265. def get_num_updates(self):
  266. """Get the number of parameters updates."""
  267. return self._num_updates
  268. def _prepare_sample(self, sample, volatile):
  269. if sample is None or len(sample) == 0:
  270. return None
  271. return utils.make_variable(sample, volatile=volatile, cuda=True)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...