selection.py

from dataclasses import dataclass, field
from functools import cached_property
from typing import Dict, List, Optional, Tuple

import lightgbm as lgb
import numpy as np
import pandas as pd
import shap
from lightgbm import Booster
from loguru import logger
from sklearn.base import TransformerMixin

from yspecies.models import Metrics, BasicMetrics
from yspecies.partition import ExpressionPartitions

@dataclass(frozen=True)
class Fold:
    '''
    Contains the information about one cross-validation fold; useful for reproducibility.
    '''

    num: int
    model: Booster
    partitions: ExpressionPartitions
    current_evals: List[BasicMetrics] = field(default_factory=list)

    @cached_property
    def explainer(self) -> shap.TreeExplainer:
        return shap.TreeExplainer(self.model)

    @cached_property
    def shap_values(self) -> List[np.ndarray]:
        return self.explainer.shap_values(X=self.partitions.X, y=self.partitions.Y)

    @cached_property
    def feature_weights(self) -> np.ndarray:
        return self.model.feature_importance(importance_type=self.partitions.features.importance_type)

    @cached_property
    def shap_dataframe(self) -> pd.DataFrame:
        return pd.DataFrame(data=self.shap_values, index=self.partitions.X.index, columns=self.partitions.X.columns)

    @cached_property
    def validation_species(self):
        return self.partitions.validation_species[self.num]

    @cached_property
    def _fold_train(self):
        return self.partitions.fold_train(self.num)

    @property
    def X_train(self):
        return self._fold_train[0]

    @property
    def y_train(self):
        return self._fold_train[1]

    @cached_property
    def X_test(self):
        return self.partitions.partitions_x[self.num]

    @cached_property
    def y_test(self):
        return self.partitions.partitions_y[self.num]

    @cached_property
    def fold_predictions(self):
        return self.model.predict(self.X_test)

    @cached_property
    def hold_out_predictions(self):
        return self.model.predict(self.partitions.hold_out_x) if self.partitions.n_hold_out > 0 else None

    @cached_property
    def validation_metrics(self) -> Metrics:
        # TODO: huber is wrong here
        return Metrics.calculate(self.partitions.hold_out_y, self.hold_out_predictions, None)

    @cached_property
    def metrics(self):
        return Metrics.calculate(self.y_test, self.fold_predictions, self.eval_metrics.huber)

    @property
    def eval_last_num(self) -> int:
        return len(self.current_evals) - 1

    @cached_property
    def eval_metrics(self):
        best_iteration_num = self.model.best_iteration
        eval_last_num = len(self.current_evals) - 1
        metrics_num = best_iteration_num if best_iteration_num is not None and eval_last_num > best_iteration_num >= 0 else eval_last_num
        if self.current_evals[metrics_num].huber < self.current_evals[eval_last_num].huber:
            return self.current_evals[metrics_num]
        else:
            return self.current_evals[eval_last_num]

    @cached_property
    def explanation(self):
        return self.explainer(self.partitions.X)

    @cached_property
    def interaction_values(self):
        return self.explainer.shap_interaction_values(self.partitions.X)

    @cached_property
    def shap_absolute_mean(self):
        return self.shap_dataframe.abs().mean(axis=0)

    @cached_property
    def shap_absolute_sum(self):
        return self.shap_dataframe.abs().sum(axis=0)

    @cached_property
    def shap_absolute_sum_non_zero(self):
        return self.shap_absolute_sum[self.shap_absolute_sum > 0.0].sort_values(ascending=False)

    @cached_property
    def expected_value(self):
        return self.explainer.expected_value

    def __repr__(self):
        # to avoid a Jupyter freeze (see https://github.com/ipython/ipython/issues/9771 )
        return self._repr_html_()

    def _repr_html_(self):
        '''
        Provides a nice HTML view in Jupyter Lab notebooks.
        :return:
        '''
        return f"<table border='2'>" \
               f"<caption>Fold</caption>" \
               f"<tr><th>metrics</th><th>validation species</th><th>shap</th><th>nonzero shap</th><th>evals</th></tr>" \
               f"<tr><td>{self.metrics}</td><td>{str(self.validation_species)}</td><td>{str(self.shap_dataframe.shape)}</td><td>{str(self.shap_absolute_sum_non_zero.shape)}</td><td>{self.eval_metrics}</td></tr>" \
               f"</table>"

@dataclass
class CrossValidator(TransformerMixin):

    early_stopping_rounds: int = 10
    models: List = field(default_factory=list)
    evals: List = field(default_factory=list)

    @logger.catch
    def fit(self, to_fit: Tuple[ExpressionPartitions, Dict], y=None) -> 'CrossValidator':
        """
        Trains one LightGBM model per cross-validation fold.
        :param to_fit: (partitions, parameters)
        :param y: ignored, kept for scikit-learn compatibility
        :return: self
        """
        partitions, parameters = to_fit
        self.models = []
        self.evals = []
        logger.info(f"===== fitting models with seed {partitions.seed} =====")
        logger.info(f"PARAMETERS:\n{parameters}")
        for i in range(0, partitions.n_folds - partitions.n_hold_out):
            X_train, X_test, y_train, y_test = partitions.split_fold(i)
            logger.info(f"SEED: {partitions.seed} | FOLD: {i} | VALIDATION_SPECIES: {str(partitions.validation_species[i])}")
            cat_index = partitions.categorical_index if len(partitions.categorical_index) > 0 else None
            model, eval_results = self.regression_model(X_train, X_test, y_train, y_test, parameters, cat_index, seed=partitions.seed)
            self.models.append(model)
            self.evals.append(eval_results)
        return self

    def regression_model(self, X_train, X_test, y_train, y_test, parameters: Dict, categorical=None,
                         num_boost_round: int = 250, seed: Optional[int] = None) -> Tuple[Booster, List[BasicMetrics]]:
        '''
        Trains a LightGBM regression model for one fold.
        :param X_train:
        :param X_test:
        :param y_train:
        :param y_test:
        :param parameters:
        :param categorical:
        :return: the trained booster together with its parsed evaluation metrics
        '''
        cat = categorical if (categorical is not None) and len(categorical) > 0 else "auto"
        lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=cat)
        lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
        evals_result = {}
        stopping_callback = lgb.early_stopping(self.early_stopping_rounds)
        if seed is not None:
            parameters["seed"] = seed
        gbm = lgb.train(parameters,
                        lgb_train,
                        num_boost_round=num_boost_round,
                        valid_sets=lgb_eval,
                        evals_result=evals_result,
                        verbose_eval=num_boost_round,
                        callbacks=[stopping_callback])
        return gbm, BasicMetrics.parse_eval(evals_result)

    @logger.catch
    def transform(self, to_select_from: Tuple[ExpressionPartitions, Dict]) -> Tuple[List[Fold], Dict]:
        partitions, parameters = to_select_from
        assert len(self.models) == partitions.n_cv_folds, "for each cross-validation fold there should be a trained model"
        folds = [Fold(i, self.models[i], partitions, self.evals[i]) for i in range(0, partitions.n_cv_folds)]
        return (folds, parameters)
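
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hypothetical example of how CrossValidator and Fold fit together,
# assuming `partitions` is an ExpressionPartitions instance and `parameters`
# is a LightGBM parameter dict prepared elsewhere in the pipeline:
#
#     cv = CrossValidator(early_stopping_rounds=10)
#     cv.fit((partitions, parameters))
#     folds, _ = cv.transform((partitions, parameters))
#     for fold in folds:
#         print(fold.metrics)                            # per-fold regression metrics
#         print(fold.shap_absolute_sum_non_zero.head())  # top features by absolute SHAP sum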