Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

multiprocessing_train.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. #!/usr/bin/env python3 -u
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. import os
  9. import random
  10. import signal
  11. import torch
  12. from fairseq import distributed_utils, options
  13. from train import main as single_process_main
  14. def main(args):
  15. # Set distributed training parameters for a single node.
  16. args.distributed_world_size = torch.cuda.device_count()
  17. args.distributed_init_method = 'tcp://localhost:{port}'.format(
  18. port=random.randint(10000, 20000))
  19. mp = torch.multiprocessing.get_context('spawn')
  20. # Create a thread to listen for errors in the child processes.
  21. error_queue = mp.SimpleQueue()
  22. error_handler = ErrorHandler(error_queue)
  23. # Train with multiprocessing.
  24. procs = []
  25. for i in range(args.distributed_world_size):
  26. args.distributed_rank = i
  27. args.device_id = i
  28. procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
  29. procs[i].start()
  30. error_handler.add_child(procs[i].pid)
  31. for p in procs:
  32. p.join()
  33. def run(args, error_queue):
  34. try:
  35. args.distributed_rank = distributed_utils.distributed_init(args)
  36. single_process_main(args)
  37. except KeyboardInterrupt:
  38. pass # killed by parent, do nothing
  39. except Exception:
  40. # propagate exception to parent process, keeping original traceback
  41. import traceback
  42. error_queue.put((args.distributed_rank, traceback.format_exc()))
  43. class ErrorHandler(object):
  44. """A class that listens for exceptions in children processes and propagates
  45. the tracebacks to the parent process."""
  46. def __init__(self, error_queue):
  47. import signal
  48. import threading
  49. self.error_queue = error_queue
  50. self.children_pids = []
  51. self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
  52. self.error_thread.start()
  53. signal.signal(signal.SIGUSR1, self.signal_handler)
  54. def add_child(self, pid):
  55. self.children_pids.append(pid)
  56. def error_listener(self):
  57. (rank, original_trace) = self.error_queue.get()
  58. self.error_queue.put((rank, original_trace))
  59. os.kill(os.getpid(), signal.SIGUSR1)
  60. def signal_handler(self, signalnum, stackframe):
  61. for pid in self.children_pids:
  62. os.kill(pid, signal.SIGINT) # kill children processes
  63. (rank, original_trace) = self.error_queue.get()
  64. msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
  65. msg += original_trace
  66. raise Exception(msg)
  67. if __name__ == '__main__':
  68. parser = options.get_training_parser()
  69. args = options.parse_args_and_arch(parser)
  70. main(args)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...