
train.py 975 B

import os
import pickle
import sys

import numpy as np
import yaml
from sklearn.ensemble import RandomForestClassifier

# Read the hyperparameters from the 'train' section of params.yaml.
with open('params.yaml') as fd:
    params = yaml.safe_load(fd)['train']

if len(sys.argv) != 3:
    sys.stderr.write('Arguments error. Usage:\n')
    sys.stderr.write('\tpython train.py features model\n')
    sys.exit(1)

input_dir = sys.argv[1]  # directory that contains train.pkl (renamed to avoid shadowing input())
output = sys.argv[2]     # file the fitted model is pickled to

seed = params['seed']
n_est = params['n_est']
min_split = params['min_split']

# train.pkl holds a sparse matrix: column 1 is the label, columns 2 onward are the features.
with open(os.path.join(input_dir, 'train.pkl'), 'rb') as fd:
    matrix = pickle.load(fd)

labels = np.squeeze(matrix[:, 1].toarray())  # dense 1-D label vector
x = matrix[:, 2:]                            # sparse feature matrix

sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
sys.stderr.write('X matrix size {}\n'.format(x.shape))
sys.stderr.write('Y matrix size {}\n'.format(labels.shape))

clf = RandomForestClassifier(
    n_estimators=n_est,
    min_samples_split=min_split,
    n_jobs=2,
    random_state=seed
)
clf.fit(x, labels)

# Persist the fitted model.
with open(output, 'wb') as fd:
    pickle.dump(clf, fd)
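train.py relies on an implicit data contract: train.pkl must unpickle to a SciPy sparse matrix whose column 1 holds the labels and whose columns 2 onward hold the features. The sketch below builds a small synthetic matrix in that layout, round-trips it through a pickle file, and applies the same slicing and fit, so the contract can be checked without the rest of the pipeline. The hyperparameter values, the random data, and the assumption that column 0 is a row id are illustrative only; in the real pipeline they come from params.yaml and the upstream featurization step.

```python
import os
import pickle
import tempfile

import numpy as np
from scipy import sparse
from sklearn.ensemble import RandomForestClassifier

# Hypothetical stand-in for the 'train' section of params.yaml.
params = {'seed': 42, 'n_est': 10, 'min_split': 2}

rng = np.random.default_rng(0)
n_rows, n_features = 50, 8

ids = np.arange(n_rows).reshape(-1, 1)          # column 0: row id (assumed)
labels = rng.integers(0, 2, size=(n_rows, 1))   # column 1: binary label
features = rng.random((n_rows, n_features))     # columns 2+: features
matrix = sparse.csr_matrix(np.hstack([ids, labels, features]))

# Round-trip through a pickle file, as the pipeline does with train.pkl.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'train.pkl')
    with open(path, 'wb') as fd:
        pickle.dump(matrix, fd)
    with open(path, 'rb') as fd:
        loaded = pickle.load(fd)

# Same slicing as train.py.
y = np.squeeze(loaded[:, 1].toarray())
x = loaded[:, 2:]

clf = RandomForestClassifier(
    n_estimators=params['n_est'],
    min_samples_split=params['min_split'],
    n_jobs=2,
    random_state=params['seed'],
)
clf.fit(x, y)  # RandomForestClassifier accepts sparse CSR input directly
print(x.shape, y.shape)  # (50, 8) (50,)
```

If the upstream step ever changes the column order, the shapes printed on stderr by train.py are the first place the mismatch would show up.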
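The model file that train.py writes is a plain pickle of the fitted RandomForestClassifier, so a later step can unpickle it and score rows that use the same feature columns. The sketch below fits a tiny stand-in classifier so it runs on its own; the file name model.pkl, the helper names save_model and load_model, and the data are hypothetical. Pickle files can execute code when loaded, so only unpickle models from sources you trust.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.ensemble import RandomForestClassifier


def save_model(clf, path):
    """Serialize a fitted classifier the same way train.py does (hypothetical helper)."""
    with open(path, 'wb') as fd:
        pickle.dump(clf, fd)


def load_model(path):
    """Load a pickled classifier (hypothetical helper). Only unpickle trusted files."""
    with open(path, 'rb') as fd:
        return pickle.load(fd)


# Hypothetical training data standing in for the featurized matrix.
rng = np.random.default_rng(1)
x_train = rng.random((40, 5))
y_train = (x_train[:, 0] > 0.5).astype(int)

clf = RandomForestClassifier(n_estimators=5, random_state=0).fit(x_train, y_train)

with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, 'model.pkl')  # hypothetical output path
    save_model(clf, model_path)
    restored = load_model(model_path)

# Score new rows with the restored model; column 1 of predict_proba is class 1.
x_new = rng.random((3, 5))
probs = restored.predict_proba(x_new)[:, 1]
```

The restored object behaves exactly like the one that was pickled, which is why the same scikit-learn version should be used when loading as when training.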