Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
  1. import dagshub
  2. import joblib
  3. from typing import List
  4. import numpy as np
  5. import pandas as pd
  6. from src.base_trainer import BaseTrainer
  7. from sklearn.metrics import (
  8. roc_auc_score,
  9. average_precision_score,
  10. accuracy_score,
  11. precision_score,
  12. recall_score,
  13. f1_score,
  14. confusion_matrix,
  15. classification_report
  16. )
  17. # --------------------------------------
  18. # Trainer
  19. # --------------------------------------
  20. class Trainer(BaseTrainer):
  21. """
  22. Trainer class
  23. """
  24. def __init__(self, model, target, model_path, logs_path, random_state):
  25. super().__init__()
  26. self.model = model
  27. self.target = target
  28. self.model_path=model_path + "/model.joblib"
  29. self.metrics_path=logs_path + "/metrics.csv"
  30. self.params_path=logs_path + "/params.yml"
  31. self.random_state = random_state
  32. def fit(self, train, val, test):
  33. with dagshub.dagshub_logger(metrics_path=self.metrics_path, hparams_path=self.params_path) as logger:
  34. print("Training model...")
  35. X_train, X_val, X_test = train.drop(columns=self.target), val.drop(columns=self.target), test.drop(columns=self.target)
  36. y_train, y_val, y_test = train[self.target], val[self.target], test[self.target]
  37. self.model.fit(
  38. X_train, y_train,
  39. eval_set = [(X_val, y_val)],
  40. early_stopping_rounds=100,
  41. )
  42. joblib.dump(self.model, self.model_path)
  43. logger.log_hyperparams(model_class=type(self.model).__name__)
  44. logger.log_hyperparams({"model": self.model.get_params()})
  45. print("Evaluating model...")
  46. train_metrics = self.evaluate(X_train, y_train)
  47. print("Train metrics:")
  48. print(train_metrics)
  49. logger.log_metrics({f"train__{k}": v for k, v in train_metrics.items()})
  50. test_metrics = self.evaluate(X_test, y_test)
  51. print("Test metrics:")
  52. print(test_metrics)
  53. logger.log_metrics({f"test__{k}": v for k, v in test_metrics.items()})
  54. logger.save()
  55. logger.close()
  56. def evaluate(self, X, y):
  57. y_pred = self.model.predict(X)
  58. y_pred_proba = self.model.predict_proba(X)
  59. print(y.nunique())
  60. if y.nunique()[0] <= 2:
  61. print("confusion_matrix:\n {}".format(confusion_matrix(y, y_pred)))
  62. print("classification report:\n {}".format(classification_report(y, y_pred)))
  63. print("AUC: {}".format(roc_auc_score(y, y_pred_proba)))
  64. return {
  65. "roc_auc": roc_auc_score(y, y_pred_proba),
  66. "average_precision": average_precision_score(y, y_pred_proba),
  67. "accuracy": accuracy_score(y, y_pred),
  68. "precision": precision_score(y, y_pred),
  69. "recall": recall_score(y, y_pred),
  70. "f1": f1_score(y, y_pred),
  71. }
  72. else:
  73. print("confusion_matrix:\n {}".format(confusion_matrix(y, y_pred)))
  74. print("classification report:\n {}".format(classification_report(y, y_pred)))
  75. print("AUC: {}".format(roc_auc_score(y, y_pred_proba, multi_class='ovr')))
  76. return {
  77. "roc_auc": roc_auc_score(y, y_pred_proba, multi_class='ovr'),
  78. #"average_precision": average_precision_score(y, y_pred_proba),
  79. "accuracy": accuracy_score(y, y_pred),
  80. #"precision": precision_score(y, y_pred),
  81. #"recall": recall_score(y, y_pred),
  82. #"f1": f1_score(y, y_pred),
  83. }
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...