Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

models.py 5.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
  1. from functools import cached_property
  2. import lightgbm as lgb
  3. from sklearn.base import TransformerMixin
  4. from sklearn.metrics import *
  5. from yspecies.partition import ExpressionPartitions
  6. from yspecies.utils import *
  7. @dataclass(frozen=True)
  8. class BasicMetrics:
  9. MAE: float
  10. MSE: float
  11. huber: float
  12. @staticmethod
  13. def from_dict(dict: Dict):
  14. return BasicMetrics(dict["l1"], dict["l2"], dict["huber"])
  15. @staticmethod
  16. def from_dict(dict: Dict, row: int):
  17. return BasicMetrics(dict["l1"][row], dict["l2"][row], dict["huber"][row])
  18. @staticmethod
  19. def parse_eval(evals_result: Dict):
  20. dict = list(evals_result.values())[0]
  21. l = len(dict["l1"])
  22. return [BasicMetrics.from_dict(dict, i) for i in range(0, l)]
  23. @dataclass(frozen=True)
  24. class Metrics:
  25. @staticmethod
  26. def from_numpy(arr: np.ndarray):
  27. return Metrics(arr[0], arr[1], arr[2], arr[3])
  28. @staticmethod
  29. def average(metrics: List['Metrics']) -> 'Metrics':
  30. return Metrics.from_numpy(np.average([m.to_numpy for m in metrics], axis=0))
  31. '''
  32. Class to store metrics
  33. '''
  34. @staticmethod
  35. def to_dataframe(metrics: Union[List['Metrics'], 'Metrics']) -> pd.DataFrame:
  36. metrics = [metrics] if isinstance(metrics, Metrics) else metrics
  37. mts = pd.DataFrame(np.zeros([len(metrics), 4]), columns=["R^2", "MAE", "MSE", "huber"]) #, "MSLE"
  38. for i, m in enumerate(metrics):
  39. mts.iloc[i] = m.to_numpy
  40. return mts
  41. @staticmethod
  42. def calculate(ground_truth, prediction, huber: float = None) -> 'Metrics':
  43. """
  44. Calculates metrics while getting huber from outside
  45. :param ground_truth:
  46. :param prediction:
  47. :param huber:
  48. :return:
  49. """
  50. return Metrics(
  51. r2_score(ground_truth, prediction),
  52. mean_absolute_error(ground_truth, prediction),
  53. mean_squared_error(ground_truth, prediction),
  54. huber=huber
  55. #mean_squared_log_error(ground_truth, prediction)
  56. )
  57. R2: float
  58. MAE: float
  59. MSE: float
  60. huber: float
  61. #MSLE: float
  62. @cached_property
  63. def to_numpy(self):
  64. return np.array([self.R2, self.MAE, self.MSE, self.huber])
  65. @dataclass(frozen=True)
  66. class ResultsCV:
  67. parameters: Dict
  68. evaluation: Dict
  69. @staticmethod
  70. def take_best(results: List['ResultsCV'], metrics: str = "huber", last: bool = False):
  71. result: float = None
  72. for r in results:
  73. value = r.last(metrics) if last else r.min(metrics)
  74. result = value if result is None or value < result else result
  75. return result
  76. @cached_property
  77. def keys(self):
  78. return list(self.evaluation.keys())
  79. @cached_property
  80. def mins(self):
  81. return {k: (np.array(self.evaluation[k]).min()) for k in self.keys}
  82. @cached_property
  83. def latest(self):
  84. return {k: (np.array(self.evaluation[k])[-1]) for k in self.keys}
  85. def min(self, metrics: str) -> float:
  86. return self.mins[metrics] if metrics in self.mins else self.mins[metrics+"-mean"]
  87. def last(self, metrics: str) -> float:
  88. return self.latest[metrics] if metrics in self.latest else self.latest[metrics+"-mean"]
  89. def _repr_html_(self):
  90. first = self.evaluation[self.keys[0]]
  91. return f"""<table border='2'>
  92. <caption><h3>CrossValidation results</h3><caption>
  93. <tr style='text-align:center'>{"".join([f'<th>{k}</th>' for k in self.keys])}</tr>
  94. {"".join(["<tr>" + "".join([f"<td>{self.evaluation[k][i]}</td>" for k in self.keys]) + "</tr>" for i in range(0, len(first))])}
  95. </table>"""
  96. @dataclass
  97. class BasicCrossValidator(TransformerMixin):
  98. evaluation: ResultsCV = None
  99. num_iterations: int = 200
  100. early_stopping_rounds: int = 10
  101. def num_boost_round(self, parameters: Dict):
  102. return parameters.get("num_iterations") if parameters.get("num_iterations") is not None else parameters.get("num_boost_round") if parameters.get("num_boost_round") is not None else self.num_iterations
  103. def fit(self, to_fit: Tuple[ExpressionPartitions, Dict], y=None) -> Dict:
  104. partitions, parameters = to_fit
  105. cat = partitions.categorical_index if partitions.features.has_categorical else "auto"
  106. lgb_train = lgb.Dataset(partitions.X, partitions.Y, categorical_feature=cat, free_raw_data=False)
  107. num_boost_round = self.num_boost_round(parameters)
  108. iterations = parameters.get("num_boost_round") if parameters.get("num_iterations") is None else parameters.get("num_boost_round")
  109. stopping_callback = lgb.early_stopping(self.early_stopping_rounds)
  110. eval_hist = lgb.cv(parameters,
  111. lgb_train,
  112. folds=partitions.folds,
  113. metrics=["mae", "mse", "huber"],
  114. categorical_feature=cat,
  115. show_stdv=True,
  116. verbose_eval=num_boost_round,
  117. seed=partitions.seed,
  118. num_boost_round=num_boost_round,
  119. #early_stopping_rounds=self.early_stopping_rounds,
  120. callbacks=[stopping_callback]
  121. )
  122. self.evaluation = ResultsCV(parameters, eval_hist)
  123. return self
  124. def transform(self, to_fit: Tuple[ExpressionPartitions, Dict]):
  125. assert self.evaluation is not None, "Cross validation should be fitted before calling transform!"
  126. return self.evaluation
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...