distributed_utils.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import pickle

import torch.distributed

from fairseq import utils


def is_master(args):
    return args.distributed_rank == 0


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)

    if args.distributed_init_method.startswith('tcp://'):
        # tcp:// rendezvous needs the rank passed in explicitly
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size, rank=args.distributed_rank)
    else:
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size)

    args.distributed_rank = torch.distributed.get_rank()
    if not is_master(args):
        suppress_output()

    return args.distributed_rank


def suppress_output():
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # drop all output unless the caller explicitly passes force=True
        if 'force' in kwargs:
            force = kwargs.pop('force')
            if force:
                builtin_print(*args, **kwargs)

    __builtin__.print = print


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()

    # allocate CUDA buffers once and cache them on the function object;
    # reallocate if max_size changes between calls
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size():
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    # gather the fixed-size byte buffers from every rank
    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        # recover the payload length from the two-byte header, then unpickle
        size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
        result.append(
            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
        )
    return result
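
Usage note (not part of the original file): a minimal sketch of how these helpers might be wired together in a multi-process training script. It assumes the file is importable as `distributed_utils`, that each launched process is given its own `distributed_rank`, and that a CUDA device is available per process (required by `all_gather_list`). The `argparse.Namespace` below is hypothetical, but its field names match the attributes the functions above read.

import argparse

from distributed_utils import all_gather_list, distributed_init, is_master

# Hypothetical arguments; in practice these come from a command-line parser,
# with a distinct distributed_rank supplied to each launched process.
args = argparse.Namespace(
    distributed_world_size=2,
    distributed_rank=0,
    distributed_init_method='tcp://localhost:12345',
    distributed_backend='nccl',
)

rank = distributed_init(args)  # non-master ranks have print() suppressed from here on

# Every rank contributes a small picklable object; every rank gets back the
# full list, ordered by rank. The pickled payload must fit in max_size - 2 bytes.
stats = {'rank': rank, 'loss': 1.23}
gathered = all_gather_list(stats, max_size=4096)

if is_master(args):
    print('gathered from all ranks: {}'.format(gathered))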