train_model.py 1.0 KB

import pickle
import sys

import dask
import dask.distributed
import numpy as np
from sklearn.ensemble import RandomForestClassifier

import conf

# Connect to a Dask scheduler that must already be running at this address.
client = dask.distributed.Client('localhost:8786')

# Input and output paths are taken from the project configuration module.
INPUT = conf.train_matrix
OUTPUT = conf.model


@dask.delayed
def workflow(input, output, seed):
    # Load the pickled feature matrix (sparse, hence the .toarray() below).
    with open(input, 'rb') as fd:
        matrix = pickle.load(fd)

    # Column 1 is used as the labels; columns 2 onward are the features.
    labels = np.squeeze(matrix[:, 1].toarray())
    x = matrix[:, 2:]

    sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
    sys.stderr.write('X matrix size {}\n'.format(x.shape))
    sys.stderr.write('Y matrix size {}\n'.format(labels.shape))

    clf = RandomForestClassifier(n_estimators=100, n_jobs=2, random_state=seed)
    clf.fit(x, labels)

    # Persist the trained classifier.
    with open(output, 'wb') as fd:
        pickle.dump(clf, fd)


# The paths come from conf, so the seed is the only command-line argument.
if len(sys.argv) != 2:
    sys.stderr.write('Arguments error. Usage:\n')
    sys.stderr.write('\tpython train_model.py SEED\n')
    sys.exit(1)

seed = int(sys.argv[1])

# Build the delayed task and run it on the connected cluster.
workflow(INPUT, OUTPUT, seed).compute()