1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- from io import StringIO
- import os
- import random
- import sys
- import tempfile
- import unittest
- import torch
- from fairseq import options
- import preprocess
- import train
- import generate
- import interactive
class TestBinaries(unittest.TestCase):
    """End-to-end smoke test: preprocess, train, and decode on a tiny random corpus."""

    def test_binaries(self):
        """Run every pipeline stage in a scratch directory with stdout captured."""
        # comment this out to debug the unittest if it's failing
        self.mock_stdout()
        try:
            with tempfile.TemporaryDirectory() as data_dir:
                self.create_dummy_data(data_dir)
                self.preprocess_data(data_dir)
                self.train_model(data_dir)
                self.generate(data_dir)
        finally:
            # Restore stdout even when a stage raises; otherwise the failure
            # report and all later test output vanish into the StringIO buffer.
            self.unmock_stdout()

    def create_dummy_data(self, data_dir, num_examples=1000, maxlen=20):
        """Write random parallel text files for the train/valid/test splits.

        Creates ``{train,valid,test}.{in,out}`` inside *data_dir*. Each file has
        *num_examples* lines; each line holds between 1 and *maxlen*
        space-separated lowercase ASCII letters chosen at random.
        """
        def _create_dummy_data(filename):
            # Draw enough character codes for the longest possible file
            # (every line at maxlen); 97 == ord('a'), so codes span 'a'..'z'.
            data = torch.rand(num_examples * maxlen)
            data = 97 + torch.floor(26 * data).int()
            with open(os.path.join(data_dir, filename), 'w') as h:
                offset = 0
                for _ in range(num_examples):
                    ex_len = random.randint(1, maxlen)
                    # .tolist() yields plain Python ints, so chr() does not
                    # rely on 0-d tensors implementing __index__.
                    chars = data[offset:offset + ex_len].tolist()
                    print(' '.join(map(chr, chars)), file=h)
                    offset += ex_len

        for filename in ('train.in', 'train.out',
                         'valid.in', 'valid.out',
                         'test.in', 'test.out'):
            _create_dummy_data(filename)

    def preprocess_data(self, data_dir):
        """Run the preprocess script over the dummy corpus, writing into *data_dir*.

        Source/target languages are ``in``/``out`` and both vocabulary
        thresholds are set to 0.
        """
        preprocess_parser = preprocess.get_parser()
        preprocess_args = preprocess_parser.parse_args([
            '--source-lang', 'in',
            '--target-lang', 'out',
            '--trainpref', os.path.join(data_dir, 'train'),
            '--validpref', os.path.join(data_dir, 'valid'),
            '--testpref', os.path.join(data_dir, 'test'),
            '--thresholdtgt', '0',
            '--thresholdsrc', '0',
            '--destdir', data_dir,
        ])
        preprocess.main(preprocess_args)

    def train_model(self, data_dir):
        """Train an ``fconv_iwslt_de_en`` model for one epoch with NAG.

        Checkpoints are saved into *data_dir* (``--save-dir``), where
        :meth:`generate` later loads ``checkpoint_best.pt``.
        """
        train_parser = options.get_training_parser()
        train_args = options.parse_args_and_arch(
            train_parser,
            [
                data_dir,
                '--arch', 'fconv_iwslt_de_en',
                '--optimizer', 'nag',
                '--lr', '0.05',
                '--max-tokens', '500',
                '--save-dir', data_dir,
                '--max-epoch', '1',
                '--no-progress-bar',
                '--distributed-world-size', '1',
            ],
        )
        train.main(train_args)

    def generate(self, data_dir):
        """Decode with ``checkpoint_best.pt``: batch mode on ``valid``, then interactively.

        The interactive pass reads one sentence (``h e l l o``) from a
        temporarily substituted ``sys.stdin``.
        """
        generate_parser = options.get_generation_parser()
        generate_args = generate_parser.parse_args([
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_best.pt'),
            '--beam', '5',
            '--batch-size', '32',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ])

        # evaluate model in batch mode
        # (`generate` here is the module-level import, not this method.)
        generate.main(generate_args)

        # evaluate model interactively
        # NOTE(review): presumably --batch-size populates max_sentences and the
        # interactive path must not be capped by it; confirm against interactive.main.
        generate_args.max_sentences = None
        orig_stdin = sys.stdin
        sys.stdin = StringIO('h e l l o\n')
        try:
            interactive.main(generate_args)
        finally:
            # Restore stdin even if decoding fails, so later code can read it.
            sys.stdin = orig_stdin

    def mock_stdout(self):
        """Redirect ``sys.stdout`` into an in-memory buffer, remembering the original."""
        self._orig_stdout = sys.stdout
        sys.stdout = StringIO()

    def unmock_stdout(self):
        """Restore ``sys.stdout`` saved by :meth:`mock_stdout`; no-op if it never ran."""
        if hasattr(self, '_orig_stdout'):
            sys.stdout = self._orig_stdout
if __name__ == '__main__':
    # When this file is executed directly (not imported), hand control to the
    # unittest runner, which runs the TestCase classes defined in this module.
    unittest.main()
|