Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

selection.py 6.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
  1. import lightgbm as lgb
  2. import shap
  3. from lightgbm import Booster
  4. from scipy.stats import kendalltau
  5. from sklearn.metrics import *
  6. from functools import cached_property
  7. from sklearn.base import TransformerMixin
  8. from dataclasses import *
  9. from yspecies.partition import ExpressionPartitions
  10. from yspecies.utils import *
  11. from yspecies.models import *
  12. @dataclass
  13. class Fold:
  14. '''
  15. Class to contain information about the fold, useful for reproducibility
  16. '''
  17. feature_weights: np.ndarray
  18. shap_dataframe: pd.DataFrame
  19. metrics: Metrics
  20. validation_species: List = field(default_factory=lambda: [])
  21. @cached_property
  22. def shap_values(self) -> List[np.ndarray]:
  23. return self.shap_dataframe.to_numpy(copy=True)
  24. @cached_property
  25. def shap_absolute_sum(self):
  26. return self.shap_dataframe.abs().sum(axis=0)
  27. @cached_property
  28. def shap_absolute_sum_non_zero(self):
  29. return self.shap_absolute_sum[self.shap_absolute_sum>0.0].sort_values(ascending=False)
  30. from yspecies.results import FeatureResults
  31. @dataclass
  32. class ShapSelector(TransformerMixin):
  33. '''
  34. Class that gets partioner and model factory and selects best features.
  35. TODO: rewrite everything to Pipeline
  36. '''
  37. model_factory: ModelFactory
  38. models: List = field(default_factory=lambda: [])
  39. select_by_gain: bool = True #if we should use gain for selection, otherwise uses median Shap Values
  40. def fit(self, partitions: ExpressionPartitions, y=None) -> 'DataExtractor':
  41. '''
  42. trains models on fig stage
  43. :param partitions:
  44. :param y:
  45. :return:
  46. '''
  47. self.models = []
  48. ifolds = partitions.nfold
  49. for i in range(ifolds):
  50. X_train, X_test, y_train, y_test = partitions.split_fold(i)
  51. index_of_categorical = [ind for ind, c in enumerate(X_train.columns) if c in partitions.features.categorical]
  52. model = self.model_factory.regression_model(X_train, X_test, y_train, y_test, index_of_categorical)
  53. self.models.append(model)
  54. return self
  55. def compute_folds(self, partitions: ExpressionPartitions) -> List[Fold]:
  56. '''
  57. Subfunction to compute weight_of_features, shap_values_out_of_fold, metrics_out_of_fold
  58. :param partitions:
  59. :return:
  60. '''
  61. folds = partitions.nfold
  62. #shap_values_out_of_fold = np.zeros()
  63. #interaction_values_out_of_fold = [[[0 for i in range(len(X.values[0]))] for i in range(len(X.values[0]))] for z in range(len(X))]
  64. #metrics = pd.DataFrame(np.zeros([folds, 3]), columns=["R^2", "MSE", "MAE"])
  65. #.sum(axis=0)
  66. assert len(self.models) == folds, "for each bootstrap there should be a model"
  67. result = []
  68. for i in range(folds):
  69. X_test = partitions.partitions_x[i]
  70. y_test = partitions.partitions_y[i]
  71. # get trained model and record accuracy metrics
  72. model = self.models[i] #just using already trained model
  73. fold_predictions = model.predict(X_test, num_iteration=model.best_iteration)
  74. explainer = shap.TreeExplainer(model)
  75. shap_values = explainer.shap_values(partitions.X)
  76. f = Fold(model.feature_importance(importance_type='gain'),
  77. pd.DataFrame(data = shap_values, index=partitions.X.index, columns=partitions.X.columns),
  78. Metrics.calculate(y_test, fold_predictions), partitions.validation_species[i]
  79. )
  80. result.append(f)
  81. #interaction_values = explainer.shap_interaction_values(X)
  82. #shap_values_out_of_fold = np.add(shap_values_out_of_fold, shap_values)
  83. #interaction_values_out_of_fold = np.add(interaction_values_out_of_fold, interaction_values)
  84. return result
  85. def transform(self, partitions: ExpressionPartitions) -> FeatureResults:
  86. folds = self.compute_folds(partitions)
  87. fold_shap_values = [f.shap_values for f in folds]
  88. # calculate shap values out of fold
  89. mean_shap_values = np.mean(fold_shap_values, axis=0)
  90. #mean_metrics = metrics.mean(axis=0)
  91. #print("MEAN metrics = "+str(mean_metrics))
  92. shap_values_transposed = mean_shap_values.T
  93. fold_number = partitions.nfold
  94. X_transposed = partitions.X_T.values
  95. gain_score_name = 'gain_score_to_'+partitions.features.to_predict if self.select_by_gain else 'shap_absolute_sum_to_'+partitions.features.to_predict
  96. kendal_tau_name = 'kendall_tau_to_'+partitions.features.to_predict
  97. # get features that have stable weight across self.bootstraps
  98. output_features_by_weight = []
  99. for i, column in enumerate(folds[0].shap_dataframe.columns):
  100. non_zero_cols = 0
  101. cols = []
  102. for f in folds:
  103. weight = f.feature_weights[i] if self.select_by_gain else folds[0].shap_absolute_sum[column]
  104. cols.append(weight)
  105. if weight!= 0:
  106. non_zero_cols += 1
  107. if non_zero_cols == fold_number:
  108. if 'ENSG' in partitions.X.columns[i]: #TODO: change from hard-coded ENSG checkup to something more meaningful
  109. output_features_by_weight.append({
  110. 'ensembl_id': partitions.X.columns[i],
  111. gain_score_name: np.mean(cols),
  112. #'name': partitions.X.columns[i], #ensemble_data.gene_name_of_gene_id(X.columns[i]),
  113. kendal_tau_name: kendalltau(shap_values_transposed[i], X_transposed[i], nan_policy='omit')[0]
  114. })
  115. selected_features = pd.DataFrame(output_features_by_weight)
  116. selected_features = selected_features.set_index("ensembl_id")
  117. if isinstance(partitions.data.genes_meta, pd.DataFrame):
  118. selected_features = partitions.data.genes_meta.drop(columns=["species"])\
  119. .join(selected_features, how="inner") \
  120. .sort_values(by=[gain_score_name], ascending=False)
  121. return FeatureResults(selected_features, folds, partitions)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...