distributed_training_utils.py

import sys
import itertools
from contextlib import contextmanager
from typing import List

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda.amp import autocast
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from super_gradients.common.data_types.enum import MultiGPUMode
from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed
from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

def distributed_all_reduce_tensor_average(tensor, n):
    """
    Performs a reduce operation on multiple nodes running distributed training.
    It first sums all of the results and then divides the summation by n.
    :param tensor: The tensor to perform the reduce operation on
    :param n:      Number of nodes
    :return:       Averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt
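
# Usage sketch (illustrative, not part of the original module): averaging a
# per-rank scalar metric across all DDP processes. Assumes the default process
# group has already been initialized (e.g. via torch.distributed.init_process_group);
# `loss_value` and `device` are placeholders.
#
#     loss_tensor = torch.tensor(loss_value, device=device)
#     avg_loss = distributed_all_reduce_tensor_average(loss_tensor, n=get_world_size())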

def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them."""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        if torch.is_tensor(validation_result):
            validation_result = validation_result.clone().detach()
        else:
            validation_result = torch.tensor(validation_result)
        validation_results_list[i] = distributed_all_reduce_tensor_average(tensor=validation_result.to(device), n=torch.distributed.get_world_size())
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple
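
# Usage sketch (illustrative): reducing a tuple of per-rank validation metrics,
# e.g. (loss, accuracy), so every rank ends up with the same averaged values.
# Assumes DDP is initialized; `val_loss` and `val_accuracy` are placeholders
# produced by a local validation loop.
#
#     metrics = (val_loss, val_accuracy)
#     avg_loss, avg_accuracy = reduce_results_tuple_for_ddp(metrics, device=torch.device("cuda", get_local_rank()))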

class MultiGPUModeAutocastWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        with autocast():
            out = self.func(*args, **kwargs)
        return out
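
# Usage sketch (illustrative): wrapping a callable so it runs under
# torch.cuda.amp.autocast (mixed precision). `compute_loss` is a hypothetical
# function; any callable works.
#
#     compute_loss_amp = MultiGPUModeAutocastWrapper(compute_loss)
#     loss = compute_loss_amp(model, inputs, targets)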

def scaled_all_reduce(tensors: List[torch.Tensor], num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place.
    Currently supports only the sum reduction operator.
    The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors

    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)

    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()

    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors
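
# Usage sketch (illustrative): averaging a list of per-rank tensors in-place,
# e.g. statistics collected independently on each GPU. Assumes DDP is
# initialized and the tensors live on the current CUDA device.
#
#     stats = [torch.rand(10, device="cuda") for _ in range(3)]
#     stats = scaled_all_reduce(stats, num_gpus=get_world_size())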

@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
    """
    :param model:                 The model being trained (i.e. Trainer.net)
    :param loader:                Training dataloader (i.e. Trainer.train_loader)
    :param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
                                  on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
                                  (i.e. precise_bn_batch_size = effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
                                  If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be used.
    :param num_gpus:              The number of gpus we are training on
    """
    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))

    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]

    # Remember momentum values
    momentums = [bn.momentum for bn in bns]

    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0

    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)

    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
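
# Usage sketch (illustrative): recomputing precise BatchNorm statistics at the
# end of training, before validation or export. `net` and `train_loader` are
# placeholders for the trained model and its training dataloader.
#
#     compute_precise_bn_stats(model=net, loader=train_loader,
#                              precise_bn_batch_size=8192, num_gpus=get_world_size())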

def get_local_rank():
    """
    Returns the local rank if running in DDP, and 0 otherwise.
    :return: local rank
    """
    return dist.get_rank() if dist.is_initialized() else 0

def get_world_size() -> int:
    """
    Returns the world size if running in DDP, and 1 otherwise.
    :return: world size
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
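
# Usage sketch (illustrative): rank-aware logging so that a message is emitted
# once per job instead of once per process.
#
#     if get_local_rank() == 0:
#         logger.info(f"Training on {get_world_size()} process(es)")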

@contextmanager
def wait_for_the_master(local_rank: int):
    """
    Make all processes wait for the master (rank 0) to finish its task first.
    """
    if local_rank > 0:
        dist.barrier()
    yield
    if local_rank == 0:
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        else:
            dist.barrier()
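
# Usage sketch (illustrative): let rank 0 download/prepare a dataset while the
# other ranks wait at a barrier, then everyone proceeds with the cached result.
# `prepare_dataset` and `data_dir` are hypothetical placeholders.
#
#     with wait_for_the_master(get_local_rank()):
#         dataset = prepare_dataset(data_dir)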

def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
    """
    If required, launch ddp subprocesses.
    :param multi_gpu: DDP, DP or Off
    :param num_gpus:  Number of GPUs to use.
    """
    if multi_gpu == MultiGPUMode.AUTO and torch.cuda.device_count() > 1:
        multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL

    if require_gpu_setup(multi_gpu):
        num_gpus = num_gpus or torch.cuda.device_count()
        if num_gpus > torch.cuda.device_count():
            raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPUs are available")
        restart_script_with_ddp(num_gpus)
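
# Usage sketch (illustrative): requesting DDP at the top of a training script.
# If the current process is not already part of a distributed group, the script
# is relaunched with one subprocess per GPU via restart_script_with_ddp().
#
#     setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=4)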

def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
    """If required, launch ddp subprocesses (deprecated).
    :param gpu_mode: DDP, DP or Off
    :param num_gpus: Number of GPUs to use.
    """
    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device. This will be removed in the next version")
    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)

def require_gpu_setup(multi_gpu: MultiGPUMode) -> bool:
    """Check if the environment requires a setup in order to work with DDP."""
    return (multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL) and (not is_distributed())

@record
def restart_script_with_ddp(num_gpus: int = None):
    """Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used),
    but on subprocesses (i.e. with DDP).

    :param num_gpus: How many GPUs you want to run the script on. If not specified, every available device will be used.
    """
    ddp_port = find_free_port()

    # Get the value from the recipe if specified, otherwise take all available devices.
    num_gpus = num_gpus if num_gpus else torch.cuda.device_count()
    if num_gpus > torch.cuda.device_count():
        raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPUs are available")

    logger.info(
        "Launching DDP with:\n"
        f"   - ddp_port = {ddp_port}\n"
        f"   - num_gpus = {num_gpus}/{torch.cuda.device_count()} available\n"
        "-------------------------------------\n"
    )

    config = LaunchConfig(
        nproc_per_node=num_gpus,
        min_nodes=1,
        max_nodes=1,
        run_id="sg_initiated",
        role="default",
        rdzv_endpoint=f"127.0.0.1:{ddp_port}",
        rdzv_backend="static",
        rdzv_configs={"rank": 0, "timeout": 900},
        rdzv_timeout=-1,
        max_restarts=0,
        monitor_interval=5,
        start_method="spawn",
        log_dir=None,
        redirects=Std.NONE,
        tee=Std.NONE,
        metrics_cfg={},
    )

    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)

    # The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
    sys.exit("Main process finished")
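
# Note (illustrative): the relaunched subprocesses receive their ranks through
# the standard torchelastic environment variables, which downstream code can
# read if needed, e.g.:
#
#     import os
#     local_rank = int(os.environ.get("LOCAL_RANK", 0))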

def get_gpu_mem_utilization():
    """GPU memory managed by the caching allocator in bytes for a given device."""
    # Workaround to work on any torch version
    if hasattr(torch.cuda, "memory_reserved"):
        return torch.cuda.memory_reserved()
    else:
        return torch.cuda.memory_cached()
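
# Usage sketch (illustrative): converting the returned byte count to gigabytes
# for logging/monitoring.
#
#     logger.info(f"GPU memory reserved: {get_gpu_mem_utilization() / 1e9:.2f} GB")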