1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
- #!/usr/bin/env python3
- import argparse
- import collections
- import torch
- import os
- import re
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.

    Raises:
        ValueError: If ``inputs`` is empty.
        KeyError: If two checkpoints do not share the same parameter names.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        # Force all tensors onto CPU so checkpoints saved on GPU load anywhere.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state['model']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )
        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            # Upcast fp16 params before summing to avoid precision loss /
            # overflow in the accumulator. A dtype check is more robust than
            # the legacy `isinstance(p, torch.HalfTensor)`, which misses
            # half tensors that are not that exact (deprecated) class.
            if p.dtype == torch.float16:
                p = p.float()
            params_dict[k].append(p)
    if new_state is None:
        # Fail loudly instead of crashing below with an opaque TypeError.
        raise ValueError('No checkpoints were provided to average.')
    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensor.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    new_state['model'] = averaged_params
    return new_state
def last_n_checkpoints(paths, n):
    """Find the ``n`` most recent ``checkpointN.pt`` files in a directory.

    Args:
        paths: A single-element list containing the directory to scan.
        n: Number of most-recent checkpoints (by numeric suffix) to return.

    Returns:
        A list of ``n`` file paths, highest checkpoint number first.

    Raises:
        Exception: If fewer than ``n`` matching checkpoint files are found.
    """
    assert len(paths) == 1
    path = paths[0]
    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    entries = []
    for f in os.listdir(path):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
    if len(entries) < n:
        # Bug fix: the original passed the values as extra exception args
        # instead of interpolating them, so users saw the raw '{}' template.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    # Sort by (number, name) descending and keep the newest n.
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def main():
    """CLI entry point: average the params of several checkpoints into one.

    Parses ``--inputs``/``--output`` (and optionally ``--num``), averages
    the model weights via ``average_checkpoints``, and writes the result
    to ``--output`` with ``torch.save``.
    """
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to this '
             'path.',
    )
    parser.add_argument(
        '--num',
        type=int,
        # Help-text fix: the matcher accepts names like 'checkpointN.pt'
        # (no underscore), so the help must not advertise 'checkpoint_xx.pt'.
        help='if set, will try to find checkpoints with names checkpointN.pt '
             'in the path specified by input, and average last num of those',
    )
    args = parser.parse_args()
    print(args)
    if args.num is not None:
        # In this mode --inputs is interpreted as a single directory to scan.
        args.inputs = last_n_checkpoints(args.inputs, args.num)
        print('averaging checkpoints: ', args.inputs)
    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
|