Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

average_checkpoints.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  1. #!/usr/bin/env python3
  2. import argparse
  3. import collections
  4. import torch
  5. def average_checkpoints(inputs):
  6. """Loads checkpoints from inputs and returns a model with averaged weights.
  7. Args:
  8. inputs: An iterable of string paths of checkpoints to load from.
  9. Returns:
  10. A dict of string keys mapping to various values. The 'model' key
  11. from the returned dict should correspond to an OrderedDict mapping
  12. string parameter names to torch Tensors.
  13. """
  14. params_dict = collections.OrderedDict()
  15. params_keys = None
  16. new_state = None
  17. for f in inputs:
  18. state = torch.load(
  19. f,
  20. map_location=(
  21. lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
  22. ),
  23. )
  24. # Copies over the settings from the first checkpoint
  25. if new_state is None:
  26. new_state = state
  27. model_params = state['model']
  28. model_params_keys = list(model_params.keys())
  29. if params_keys is None:
  30. params_keys = model_params_keys
  31. elif params_keys != model_params_keys:
  32. raise KeyError(
  33. 'For checkpoint {}, expected list of params: {}, '
  34. 'but found: {}'.format(f, params_keys, model_params_keys)
  35. )
  36. for k in params_keys:
  37. if k not in params_dict:
  38. params_dict[k] = []
  39. p = model_params[k]
  40. if isinstance(p, torch.HalfTensor):
  41. p = p.float()
  42. params_dict[k].append(p)
  43. averaged_params = collections.OrderedDict()
  44. # v should be a list of torch Tensor.
  45. for k, v in params_dict.items():
  46. summed_v = None
  47. for x in v:
  48. summed_v = summed_v + x if summed_v is not None else x
  49. averaged_params[k] = summed_v / len(v)
  50. new_state['model'] = averaged_params
  51. return new_state
  52. def main():
  53. parser = argparse.ArgumentParser(
  54. description='Tool to average the params of input checkpoints to '
  55. 'produce a new checkpoint',
  56. )
  57. parser.add_argument(
  58. '--inputs',
  59. required=True,
  60. nargs='+',
  61. help='Input checkpoint file paths.',
  62. )
  63. parser.add_argument(
  64. '--output',
  65. required=True,
  66. metavar='FILE',
  67. help='Write the new checkpoint containing the averaged weights to this '
  68. 'path.',
  69. )
  70. args = parser.parse_args()
  71. print(args)
  72. new_state = average_checkpoints(args.inputs)
  73. torch.save(new_state, args.output)
  74. print('Finished writing averaged checkpoint to {}.'.format(args.output))
  75. if __name__ == '__main__':
  76. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...