dist_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import logging
import os
from datetime import timedelta
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing

logger = logging.getLogger(__name__)


def is_dist_initialized() -> bool:
    """Return True if torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank() -> int:
    """Return the global rank of this process, or 0 if not distributed."""
    if not is_dist_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """Return the local (per-node) rank, or 0 if not distributed."""
    if not is_dist_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])


def get_world_size() -> int:
    """Return the total number of processes, or 1 if not distributed."""
    if not is_dist_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """Return True for the rank-0 process."""
    return get_rank() == 0


def init_distributed(loggers: List[logging.Logger]) -> None:
    """Initializes the distributed backend."""
    torch.multiprocessing.set_start_method("spawn")
    if "RANK" not in os.environ:
        logger.error(
            "Cannot init distributed context, as environment variables are not set."
        )
        return
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(
        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
    )
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=180),
    )
    logger.info(f"Setting cuda:{local_rank} as main device")
    # Silence the given loggers on all ranks except the main process.
    if not is_main_process():
        for to_mute in loggers:
            to_mute.setLevel(logging.ERROR)
    torch.cuda.set_device(local_rank)
    dist.barrier()
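
Usage note: a minimal sketch of how these helpers might be wired into a training entry point launched with torchrun (which exports the RANK, WORLD_SIZE, and LOCAL_RANK variables that init_distributed reads). The script name train.py and its contents are hypothetical, not part of this repository.

# train.py -- hypothetical entry point, minimal sketch only.
# Launch with, e.g.:  torchrun --nproc_per_node=4 train.py
import logging

import torch

from dist_utils import (
    get_local_rank,
    get_world_size,
    init_distributed,
    is_main_process,
)

logger = logging.getLogger("train")


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    # Initialize NCCL from the env variables set by torchrun; pass in any
    # loggers that should be muted on non-main ranks.
    init_distributed([logger])

    device = torch.device(f"cuda:{get_local_rank()}")
    logger.info(f"world_size={get_world_size()}, device={device}")

    # Only the main process should write checkpoints or log metrics.
    if is_main_process():
        logger.info("Main process: handle checkpointing and metrics here.")


if __name__ == "__main__":
    main()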