Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import json
  2. import math
  3. import os
  4. import pickle
  5. import sys
  6. import sklearn.metrics as metrics
  7. if len(sys.argv) != 6:
  8. sys.stderr.write("Arguments error. Usage:\n")
  9. sys.stderr.write("\tpython evaluate.py model features scores prc roc\n")
  10. sys.exit(1)
  11. model_file = sys.argv[1]
  12. matrix_file = os.path.join(sys.argv[2], "test.pkl")
  13. scores_file = sys.argv[3]
  14. prc_file = sys.argv[4]
  15. roc_file = sys.argv[5]
  16. with open(model_file, "rb") as fd:
  17. model = pickle.load(fd)
  18. with open(matrix_file, "rb") as fd:
  19. matrix = pickle.load(fd)
  20. labels = matrix[:, 1].toarray()
  21. x = matrix[:, 2:]
  22. predictions_by_class = model.predict_proba(x)
  23. predictions = predictions_by_class[:, 1]
  24. precision, recall, prc_thresholds = metrics.precision_recall_curve(labels, predictions)
  25. fpr, tpr, roc_thresholds = metrics.roc_curve(labels, predictions)
  26. avg_prec = metrics.average_precision_score(labels, predictions)
  27. roc_auc = metrics.roc_auc_score(labels, predictions)
  28. with open(scores_file, "w") as fd:
  29. json.dump({"avg_prec": avg_prec, "roc_auc": roc_auc}, fd, indent=4)
  30. # ROC has a drop_intermediate arg that reduces the number of points.
  31. # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#sklearn.metrics.roc_curve.
  32. # PRC lacks this arg, so we manually reduce to 1000 points as a rough estimate.
  33. nth_point = math.ceil(len(prc_thresholds) / 1000)
  34. prc_points = list(zip(precision, recall, prc_thresholds))[::nth_point]
  35. with open(prc_file, "w") as fd:
  36. json.dump(
  37. {
  38. "prc": [
  39. {"precision": p, "recall": r, "threshold": t}
  40. for p, r, t in prc_points
  41. ]
  42. },
  43. fd,
  44. indent=4,
  45. )
  46. with open(roc_file, "w") as fd:
  47. json.dump(
  48. {
  49. "roc": [
  50. {"fpr": fp, "tpr": tp, "threshold": t}
  51. for fp, tp, t in zip(fpr, tpr, roc_thresholds)
  52. ]
  53. },
  54. fd,
  55. indent=4,
  56. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...