options.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse

import torch

from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY


def get_training_parser():
    parser = get_parser('Trainer')
    add_dataset_args(parser, train=True)
    add_distributed_training_args(parser)
    add_model_args(parser)
    add_optimization_args(parser)
    add_checkpoint_args(parser)
    return parser


def get_generation_parser(interactive=False):
    parser = get_parser('Generation')
    add_dataset_args(parser, gen=True)
    add_generation_args(parser)
    if interactive:
        add_interactive_args(parser)
    return parser


def _eval_float_list(x):
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(x)
    except TypeError:
        return [float(x)]
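
# Illustrative behavior of _eval_float_list above (an explanatory note, not
# part of the code's API): a string such as '0.25,0.1' is eval'd to the tuple
# (0.25, 0.1) and returned as [0.25, 0.1], while a bare '0.25' evals to a
# float, fails the list() call, and falls back to [0.25]. Note that eval()
# trusts whatever appears on the command line.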


def parse_args_and_arch(parser, input_args=None):
    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    model_specific_group = parser.add_argument_group(
        'Model-specific configuration',
        # Only include attributes which are explicitly given as command-line
        # arguments or which have default values.
        argument_default=argparse.SUPPRESS,
    )
    ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    CRITERION_REGISTRY[args.criterion].add_args(parser)
    OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)

    # Parse a second time.
    args = parser.parse_args(input_args)

    # Post-process args.
    args.lr = _eval_float_list(args.lr)
    args.update_freq = _eval_float_list(args.update_freq)
    if args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    # Apply architecture configuration.
    ARCH_CONFIG_REGISTRY[args.arch](args)
    return args
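
# Typical usage, as a minimal sketch (the 'data-bin/iwslt14' data directory is
# a placeholder, not a value this file defines):
#
#   parser = get_training_parser()
#   args = parse_args_and_arch(parser)  # parses sys.argv
#
# or, programmatically, with an explicit argument list:
#
#   args = parse_args_and_arch(parser, input_args=[
#       'data-bin/iwslt14', '--arch', 'fconv', '--optimizer', 'nag',
#   ])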


def get_parser(desc):
    parser = argparse.ArgumentParser(
        description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    return parser


def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('data', metavar='DIR',
                       help='path to data directory')
    group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                       help='source language')
    group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                       help='target language')
    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                       help='max number of tokens in the source sequence')
    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                       help='max number of tokens in the target sequence')
    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                       help='ignore too long or too short lines in valid and test set')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    if train:
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma-separated list of data subsets to use for validation'
                                ' (e.g. train, valid, valid1, test, test1)')
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
        group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
                           help='if bigger than 0, use that number of mini-batches for each epoch,'
                                ' where each sample is drawn randomly without replacement from the'
                                ' dataset')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group


def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    group.add_argument('--distributed-world-size', type=int, metavar='N',
                       default=torch.cuda.device_count(),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=0, type=int,
                       help='rank of the current worker')
    group.add_argument('--distributed-backend', default='nccl', type=str,
                       help='distributed backend')
    group.add_argument('--distributed-init-method', default=None, type=str,
                       help='typically tcp://hostname:port that will be used to '
                            'establish initial connection')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
    group.add_argument('--device-id', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    return group
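
# Illustrative multi-GPU invocation (the hostname, port, and data directory
# are placeholders; each worker would pass its own --distributed-rank):
#
#   python train.py data-bin/iwslt14 \
#       --distributed-world-size 8 \
#       --distributed-init-method tcp://hostname:12345 \
#       --distributed-rank 0
#
# NCCL is the default backend, so --distributed-backend can usually be
# omitted on CUDA machines.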


def add_optimization_args(parser):
    group = parser.add_argument_group('Optimization')
    group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--sentence-avg', action='store_true',
                       help='normalize gradients by the number of sentences in a batch'
                            ' (default is to normalize by number of tokens)')
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')
    group.add_argument('--fp16', action='store_true',
                       help='use FP16 during training')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
    group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs > N use LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
                           ', '.join(LR_SCHEDULER_REGISTRY.keys())))
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')
    return group


def add_checkpoint_args(parser):
    group = parser.add_argument_group('Checkpointing')
    group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
                       help='path to save checkpoints')
    group.add_argument('--restore-file', default='checkpoint_last.pt',
                       help='filename in save-dir from which to load checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
                       help='only store last and best checkpoints')
    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
                       help='validate every N epochs')
    return group


def add_generation_args(parser):
    group = parser.add_argument_group('Generation')
    group.add_argument('--path', metavar='FILE', action='append',
                       help='path(s) to model file(s)')
    group.add_argument('--beam', default=5, type=int, metavar='N',
                       help='beam size')
    group.add_argument('--nbest', default=1, type=int, metavar='N',
                       help='number of hypotheses to output')
    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                       help='remove BPE tokens before scoring')
    group.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    group.add_argument('--unnormalized', action='store_true',
                       help='compare unnormalized hypothesis scores')
    group.add_argument('--cpu', action='store_true', help='generate on CPU')
    group.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')
    group.add_argument('--score-reference', action='store_true',
                       help='just score the reference translation')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                       help='initialize generation by target prefix of given length')
    group.add_argument('--sampling', action='store_true',
                       help='sample hypotheses instead of using beam search')
    return group
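
# Minimal generation sketch (the checkpoint and data paths are placeholders;
# generation parsers don't need the two-pass parse_args_and_arch, since the
# model architecture is recovered from the checkpoint rather than --arch):
#
#   parser = get_generation_parser()
#   args = parser.parse_args([
#       'data-bin/iwslt14', '--path', 'checkpoints/checkpoint_best.pt',
#       '--beam', '10', '--remove-bpe',
#   ])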


def add_interactive_args(parser):
    group = parser.add_argument_group('Interactive')
    group.add_argument('--buffer-size', default=0, type=int, metavar='N',
                       help='read this many sentences into a buffer before processing them')
    return group


def add_model_args(parser):
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    # 3) --encoder/decoder-* arguments (highest priority)
    # (see the illustrative command after this function)
    group.add_argument(
        '--arch', '-a', default='fconv', metavar='ARCH',
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture: {} (default: fconv)'.format(
            ', '.join(ARCH_MODEL_REGISTRY.keys())),
    )

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument(
        '--criterion', default='cross_entropy', metavar='CRIT',
        choices=CRITERION_REGISTRY.keys(),
        help='training criterion: {} (default: cross_entropy)'.format(
            ', '.join(CRITERION_REGISTRY.keys())),
    )
    return group
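
# Illustrative example of the priority order documented in add_model_args
# (--encoder-embed-dim is an assumption: a typical flag registered by a
# model's own add_args, not defined in this file):
#
#   python train.py data-bin/iwslt14 --arch fconv --encoder-embed-dim 256
#
# The fconv defaults are applied last, by ARCH_CONFIG_REGISTRY['fconv'] in
# parse_args_and_arch, but only for attributes the user did not set: the
# argument_default=argparse.SUPPRESS group keeps unset model flags absent
# from args, so arch config functions (which conventionally use a
# getattr(args, ..., default) pattern) fill them in without overriding the
# explicit --encoder-embed-dim value.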