Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. """
  2. Compute model performance on test set and save results to `mlflow`.
  3. """
  4. import dask
  5. import dask.distributed
  6. import mlflow
  7. import pickle
  8. import sklearn.metrics as metrics
  9. from sklearn.metrics import precision_recall_curve
  10. import conf
  11. import featurization
  12. import split_train_test
  13. import train_model
  14. import xml_to_tsv
  15. @dask.delayed
  16. def evaluate(model_file, test_matrix_file):
  17. """Return AUC for text dataset"""
  18. with open(model_file, 'rb') as fd:
  19. model = pickle.load(fd)
  20. with open(test_matrix_file, 'rb') as fd:
  21. matrix = pickle.load(fd)
  22. labels = matrix[:, 1].toarray()
  23. x = matrix[:, 2:]
  24. predictions_by_class = model.predict_proba(x)
  25. predictions = predictions_by_class[:, 1]
  26. precision, recall, thresholds = precision_recall_curve(labels, predictions)
  27. auc = metrics.auc(recall, precision)
  28. return auc
  29. if __name__ == '__main__':
  30. client = dask.distributed.Client('localhost:8786')
  31. MODEL_FILE = conf.model
  32. TEST_MATRIX_FILE = conf.test_matrix
  33. METRICS_FILE = conf.metrics_file
  34. auc = evaluate(MODEL_FILE, TEST_MATRIX_FILE).compute()
  35. print('AUC={}'.format(auc))
  36. with open(METRICS_FILE, 'w') as fd:
  37. fd.write('AUC: {:4f}\n'.format(auc))
  38. CONFIGURATIONS = {
  39. 'xml_to_tsv': xml_to_tsv.get_params(),
  40. 'split_train_test': split_train_test.get_params(),
  41. 'featurization': featurization.get_params(),
  42. 'train_model': train_model.get_params()
  43. }
  44. with mlflow.start_run():
  45. for stage, params in CONFIGURATIONS.items():
  46. for param, value in CONFIGURATIONS.items():
  47. mlflow.log_param(param, value)
  48. mlflow.log_metric("AUC", auc)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...