distributed_training_utils.py 4.6 KB

import torch
from torch.cuda.amp import autocast
import torch.nn as nn
import itertools
from typing import List


def distributed_all_reduce_tensor_average(tensor, n):
    """
    This method performs a reduce operation on multiple nodes running distributed training.
    It first sums all of the results and then divides the summation by n.
    :param tensor: The tensor to perform the reduce operation for
    :param n: Number of nodes
    :return: Averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt


def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them"""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        validation_results_list[i] = distributed_all_reduce_tensor_average(torch.tensor(validation_result).to(device),
                                                                           torch.distributed.get_world_size())
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple


class MultiGPUModeAutocastWrapper:
    """Wraps a callable so that it always runs under torch.cuda.amp.autocast."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        with autocast():
            out = self.func(*args, **kwargs)
        return out


def scaled_all_reduce(tensors: List[torch.Tensor], num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place.
    Currently supports only the sum reduction operator.
    The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors

    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)

    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()

    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors


@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader,
                             precise_bn_batch_size: int, num_gpus: int):
    """
    :param model:                 The model being trained (ie: SgModel.net)
    :param loader:                Training dataloader (ie: SgModel.train_loader)
    :param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example,
                                  if we are training a model on 8 gpus, with a batch of 128 on each gpu,
                                  a good rule of thumb would be to give it 8192
                                  (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
                                  If precise_bn_batch_size is not provided in the training_params,
                                  the latter heuristic will be taken.
    :param num_gpus:              The number of gpus we are training on
    """
    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))

    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]

    # Remember momentum values
    momentums = [bn.momentum for bn in bns]

    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0

    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)

    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
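
For context, here is a minimal sketch of how the validation-reduction helpers above might be exercised; it is not part of the file. The spawn-based setup, the gloo backend, the port number, and the metric values are illustrative assumptions, and the import path assumes this module is importable as distributed_training_utils.

import os
import torch
import torch.multiprocessing as mp

from distributed_training_utils import reduce_results_tuple_for_ddp  # assumed import path


def _worker(rank, world_size):
    # Minimal process-group setup for the sketch; gloo lets it run on CPU-only machines.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    device = "cpu"  # with NCCL and one GPU per process this would be f"cuda:{rank}"

    # Pretend each rank measured its own loss/accuracy on its shard of the validation set.
    local_loss, local_acc = 0.5 + 0.1 * rank, 0.8 - 0.05 * rank

    # Average the per-rank metrics across the process group.
    avg_loss, avg_acc = reduce_results_tuple_for_ddp((local_loss, local_acc), device)
    if rank == 0:
        print(f"averaged loss={avg_loss.item():.3f}, acc={avg_acc.item():.3f}")

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

Each rank contributes its local metrics, all_reduce sums them across the process group, and the sum is divided by the world size, so every rank ends up holding the same averaged values.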
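Along the same lines, a single-GPU sketch of the precise-BN pass: with num_gpus=1 the scaled_all_reduce call is a no-op, so no process group is needed, but a CUDA device is required because the helper moves inputs to the GPU. The toy model, loader, and batch sizes are invented for illustration, and the import path is again an assumption.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from distributed_training_utils import compute_precise_bn_stats  # assumed import path

# A toy network and dataloader stand in for SgModel.net and SgModel.train_loader.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).cuda()
images = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

# Keep the model in train() mode: the helper relies on training-mode BN forward passes
# (with momentum temporarily set to 1.0) to refresh running_mean / running_var.
model.train()

# Re-estimate the BN statistics over an effective batch of 32 samples (2 minibatches here).
compute_precise_bn_stats(model, loader, precise_bn_batch_size=32, num_gpus=1)

In multi-GPU DDP training, the same call would be made on every rank with num_gpus set to the world size, and scaled_all_reduce would then average the per-rank statistics.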