evaluate.py 3.4 KB

  1. """
  2. Compute model performance on test set and save results to MLflow.
  3. Routine Listings
  4. ----------------
  5. evaluate()
  6. Return AUC for text dataset.
  7. save_mlflow_run(params, metrices, artifacts)
  8. Save MLflow run (params, metrices, artifacts) to tracking server.
  9. save_plot(path, train_auc, test_auc)
  10. Save plot of train and test AUC to file.
  11. """
  12. import dask
  13. import dask.distributed
  14. import mlflow
  15. import pandas as pd
  16. import pickle
  17. import sklearn.metrics as metrics
  18. from sklearn.metrics import precision_recall_curve
  19. import conf
  20. import featurization
  21. import split_train_test
  22. import train_model
  23. import xml_to_tsv
  24. @dask.delayed
  25. def evaluate(model_file, test_matrix_file):
  26. """Return AUC for text dataset."""
  27. with open(model_file, 'rb') as fd:
  28. model = pickle.load(fd)
  29. with open(test_matrix_file, 'rb') as fd:
  30. matrix = pickle.load(fd)
  31. labels = matrix[:, 1].toarray()
  32. x = matrix[:, 2:]
  33. predictions_by_class = model.predict_proba(x)
  34. predictions = predictions_by_class[:, 1]
  35. precision, recall, thresholds = precision_recall_curve(labels, predictions)
  36. auc = metrics.auc(recall, precision)
  37. return auc
  38. @dask.delayed
  39. def save_mlflow_run(params, metrices, artifacts):
  40. """Save MLflow run (params, metrices, artifacts) to tracking server."""
  41. mlflow.set_tracking_uri('http://localhost:5000')
  42. mlflow.set_experiment('dvc_dask_use_case')
  43. with mlflow.start_run():
  44. for stage, stage_params in params.items():
  45. for key, value in stage_params.items():
  46. mlflow.log_param(key, value)
  47. for metric, value in metrices.items():
  48. mlflow.log_metric(metric, value)
  49. for path in artifacts:
  50. mlflow.log_artifact(path)
  51. @dask.delayed
  52. def save_plot(path, train_auc, test_auc):
  53. """Save plot of train and test AUC to file."""
  54. data = pd.Series(
  55. {'train_auc': train_auc, 'test_auc': test_auc})
  56. ax = data.plot.bar()
  57. fig = ax.get_figure()
  58. fig.savefig(path)
  59. if __name__ == '__main__':
  60. client = dask.distributed.Client('localhost:8786')
  61. INPUT_TRAIN_MATRIX_PATH = conf.data_dir/'featurization'/'matrix-train.p'
  62. INPUT_TEST_MATRIX_PATH = conf.data_dir/'featurization'/'matrix-test.p'
  63. INPUT_MODEL_PATH = conf.data_dir/'train_model'/'model.p'
  64. dvc_stage_name = __file__.strip('.py')
  65. STAGE_OUTPUT_PATH = conf.data_dir/dvc_stage_name
  66. conf.remote_mkdir(STAGE_OUTPUT_PATH).compute()
  67. OUTPUT_METRICS_PATH = 'eval.txt'
  68. OUTPUT_PLOT_PATH = STAGE_OUTPUT_PATH/'train_test_auc_plot.png'
  69. train_auc = evaluate(INPUT_MODEL_PATH, INPUT_TRAIN_MATRIX_PATH).compute()
  70. test_auc = evaluate(INPUT_MODEL_PATH, INPUT_TEST_MATRIX_PATH).compute()
  71. save_plot(OUTPUT_PLOT_PATH, train_auc, test_auc).compute()
  72. print('TRAIN_AUC={}'.format(train_auc))
  73. print('TEST_AUC={}'.format(test_auc))
  74. with open(OUTPUT_METRICS_PATH, 'w') as fd:
  75. fd.write('TRAIN_AUC: {:4f}\n'.format(train_auc))
  76. fd.write('TEST_AUC: {:4f}\n'.format(test_auc))
  77. CONFIGURATIONS = {
  78. 'xml_to_tsv': xml_to_tsv.get_params(),
  79. 'split_train_test': split_train_test.get_params(),
  80. 'featurization': featurization.get_params(),
  81. 'train_model': train_model.get_params()
  82. }
  83. overall_scores = {
  84. 'TRAIN_AUC': train_auc,
  85. 'TEST_AUC': test_auc
  86. }
  87. save_mlflow_run(
  88. CONFIGURATIONS, overall_scores, [OUTPUT_PLOT_PATH]).compute()