stage_04_evaluate.py

import argparse
import os
import math
from tqdm import tqdm
import logging
from src.utils.all_utils import read_yaml, create_directory, save_reports
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, roc_curve

STAGE = "Four"

# Make sure the log directory exists before logging attaches a file handler to it.
os.makedirs("logs", exist_ok=True)

logging.basicConfig(
    filename=os.path.join("logs", "running_logs.log"),
    level=logging.INFO,
    format="[%(asctime)s: %(levelname)s: %(module)s]: %(message)s",
    filemode="a",
)


def main(config_path):
    config = read_yaml(config_path)
    # params = read_yaml(params_path)

    artifacts = config["artifacts"]
    featurized_data_dir_path = os.path.join(artifacts["ARTIFACTS_DIR"], artifacts["FEATURIZED_DATA"])
    featurized_test_data_path = os.path.join(featurized_data_dir_path, artifacts["FEATURIZED_OUT_TEST"])

    model_dir_path = os.path.join(artifacts["ARTIFACTS_DIR"], artifacts["MODEL_DIR"])
    model_path = os.path.join(model_dir_path, artifacts["MODEL_NAME"])

    # Load the featurized test matrix and the trained model.
    matrix = joblib.load(featurized_test_data_path)
    model = joblib.load(model_path)

    # Column 1 of the sparse matrix holds the labels; the features start at column 2.
    labels = np.squeeze(matrix[:, 1].toarray())
    X = matrix[:, 2:]

    # Predicted probability of the positive class.
    predictions_by_class = model.predict_proba(X)
    predictions = predictions_by_class[:, 1]

    PRC_json_path = config["plots"]["PRC"]
    ROC_json_path = config["plots"]["ROC"]
    scores_json_path = config["metrics"]["SCORES"]

    # Threshold-independent summary metrics.
    avg_prec = average_precision_score(labels, predictions)
    roc_auc = roc_auc_score(labels, predictions)
    scores = {
        "avg_prec": avg_prec,
        "roc_auc": roc_auc
    }
    save_reports(scores, scores_json_path)

    # Precision-recall curve, downsampled to at most ~1000 points for plotting.
    precision, recall, thresh = precision_recall_curve(labels, predictions)
    nth_point = math.ceil(len(thresh) / 1000)
    prc_points = list(zip(precision, recall, thresh))[::nth_point]
    prc_data = {
        "prc": [
            {"precision": p, "recall": r, "threshold": t}
            for p, r, t in prc_points
        ]
    }
    save_reports(prc_data, PRC_json_path)

    # ROC curve points.
    fpr, tpr, roc_threshold = roc_curve(labels, predictions)
    roc_data = {
        "roc": [
            {"fpr": fp, "tpr": tp, "threshold": t}
            for fp, tp, t in zip(fpr, tpr, roc_threshold)
        ]
    }
    save_reports(roc_data, ROC_json_path)


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--config", "-c", default="config/config.yaml")
    # args.add_argument("--params", "-p", default="params.yaml")
    parsed_args = args.parse_args()

    try:
        logging.info("\n********************")
        logging.info(f">>>>> stage {STAGE} started <<<<<")
        main(config_path=parsed_args.config)
        logging.info(f">>>>> stage {STAGE} completed! <<<<<\n")
    except Exception as e:
        logging.exception(e)
        raise e
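Usage note: the script exposes a single --config/-c flag that defaults to config/config.yaml. Because the imports resolve against the src package, it should be run from the repository root; the src/ location of the file in the command below is an assumption, not confirmed by this page.

    python src/stage_04_evaluate.py --config config/config.yaml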
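The keys the script reads imply a config.yaml shaped roughly as follows. All directory and file names here are placeholders chosen for illustration, not values taken from the actual repository:

    artifacts:
      ARTIFACTS_DIR: artifacts            # placeholder
      FEATURIZED_DATA: featurized_data    # placeholder
      FEATURIZED_OUT_TEST: test.joblib    # placeholder
      MODEL_DIR: model_dir                # placeholder
      MODEL_NAME: model.joblib            # placeholder
    plots:
      PRC: prc.json                       # placeholder
      ROC: roc.json                       # placeholder
    metrics:
      SCORES: scores.json                 # placeholder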
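read_yaml and save_reports come from src.utils.all_utils, which is not shown on this page. A minimal sketch of what such helpers could look like, assuming read_yaml parses a YAML file into a dict and save_reports dumps a dict to JSON; these are assumed implementations, not the repository's code:

    import json
    import os
    import yaml


    def read_yaml(path_to_yaml: str) -> dict:
        # Parse a YAML file into a plain dict (assumed behavior).
        with open(path_to_yaml) as yaml_file:
            return yaml.safe_load(yaml_file)


    def save_reports(report: dict, report_path: str) -> None:
        # Write a metrics/plot dict as indented JSON (assumed behavior).
        dir_path = os.path.dirname(report_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        with open(report_path, "w") as f:
            json.dump(report, f, indent=4)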