Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

results.py 6.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
  1. import matplotlib.pyplot as plt
  2. from more_itertools import flatten
  3. from functools import cached_property
  4. from dataclasses import *
  5. import shap
  6. import yspecies
  7. from yspecies.selection import Fold
  8. from yspecies.utils import *
  9. from yspecies.partition import ExpressionPartitions
  10. @dataclass
  11. class FeatureResults:
  12. '''
  13. Feature results class
  14. '''
  15. selected: pd.DataFrame
  16. folds: List[Fold]
  17. #shap_dataframes: List[pd.DataFrame]
  18. #metrics: pd.DataFrame
  19. partitions: ExpressionPartitions = field(default_factory=lambda: None)
  20. @property
  21. def head(self) -> Fold:
  22. return self.folds[0]
  23. @cached_property
  24. def validation_species(self):
  25. return [f.validation_species for f in self.folds]
  26. @cached_property
  27. def metrics(self):
  28. return yspecies.selection.Metrics.combine([f.metrics for f in self.folds]).join(pd.Series(data = self.validation_species, name="validation_species"))
  29. def __repr__(self):
  30. #to fix jupyter freeze (see https://github.com/ipython/ipython/issues/9771 )
  31. return self._repr_html_()
  32. @cached_property
  33. def shap_sums(self):
  34. #TODO: rewrite
  35. shap_positive_sums = pd.DataFrame(np.vstack([np.sum(more_or_value(v, 0.0, 0.0), axis=0) for v in self.shap_values]).T, index=self.partitions.X_T.index)
  36. shap_positive_sums = shap_positive_sums.rename(columns={c:f"plus_shap_{c}" for c in shap_positive_sums.columns})
  37. shap_negative_sums = pd.DataFrame(np.vstack([np.sum(less_or_value(v, 0.0, 0.0), axis=0) for v in self.shap_values]).T, index=self.partitions.X_T.index)
  38. shap_negative_sums = shap_negative_sums.rename(columns={c:f"minus_shap_{c}" for c in shap_negative_sums.columns})
  39. sh_cols = [c for c in flatten(zip(shap_positive_sums, shap_negative_sums))]
  40. shap_sums = shap_positive_sums.join(shap_negative_sums)[sh_cols]
  41. return shap_sums
  42. @cached_property
  43. def stable_shap_dataframe(self) -> pd.DataFrame:
  44. return pd.DataFrame(data=self.stable_shap_values, index=self.head.shap_dataframe.index, columns=self.head.shap_dataframe.columns)
  45. @cached_property
  46. def stable_shap_dataframe_T(self) ->pd.DataFrame:
  47. transposed = self.stable_shap_dataframe.T
  48. transposed.index.name = "ensembl_id"
  49. return transposed
  50. def gene_details(self, symbol: str, samples: pd.DataFrame):
  51. '''
  52. Returns details of the genes (which shap values per each sample)
  53. :param symbol:
  54. :param samples:
  55. :return:
  56. '''
  57. shaped = self.selected_extended[self.selected_extended["symbol"] == symbol]
  58. id = shaped.index[0]
  59. print(f"general info: {shaped.iloc[0][0:3]}")
  60. shaped.index = ["shap_values"]
  61. exp = self.partitions.X_T.loc[self.partitions.X_T.index == id]
  62. exp.index = ["expressions"]
  63. joined = pd.concat([exp, shaped], axis=0)
  64. result = joined.T.join(samples)
  65. result.index.name = "run"
  66. return result
  67. @cached_property
  68. def selected_extended(self):
  69. return self.selected.join(self.stable_shap_dataframe_T, how="left")
  70. @cached_property
  71. def stable_shap_values(self):
  72. return np.mean(self.shap_values, axis=0)
  73. @cached_property
  74. def shap_dataframes(self) -> List[np.ndarray]:
  75. return [f.shap_dataframe for f in self.folds]
  76. @cached_property
  77. def shap_values(self) -> List[np.ndarray]:
  78. return [f.shap_values for f in self.folds]
  79. @cached_property
  80. def feature_names(self):
  81. return self.partitions.data.genes_meta["symbol"].values
  82. def _plot_(self, shap_values: List[np.ndarray] or np.ndarray, gene_names: bool = True, save: Path = None,
  83. max_display=None, title=None, layered_violin_max_num_bins = 20,
  84. plot_type=None, color=None, axis_color="#333333", alpha=1, class_names=None
  85. ):
  86. #shap.summary_plot(shap_values, self.partitions.X, show=False)
  87. feature_names = None if gene_names is False else self.feature_names
  88. shap.summary_plot(shap_values, self.partitions.X, feature_names=feature_names, show=False,
  89. max_display=max_display, title=title, layered_violin_max_num_bins=layered_violin_max_num_bins,
  90. class_names=class_names,
  91. # class_inds=class_inds,
  92. plot_type=plot_type,
  93. color=color, axis_color=axis_color, alpha=alpha
  94. )
  95. fig = plt.gcf()
  96. if save is not None:
  97. from IPython.display import set_matplotlib_formats
  98. set_matplotlib_formats('svg')
  99. plt.savefig(save)
  100. plt.close()
  101. return fig
  102. def plot(self, gene_names: bool = True, save: Path = None,
  103. title=None, max_display=100, layered_violin_max_num_bins = 20,
  104. plot_type=None, color=None, axis_color="#333333", alpha=1, show=True, class_names=None):
  105. return self._plot_(self.stable_shap_values, gene_names, save, title, max_display,
  106. layered_violin_max_num_bins, plot_type, color, axis_color, alpha, class_names)
  107. def plot_folds(self, names: bool = True, save: Path = None, title=None,
  108. max_display=100, layered_violin_max_num_bins = 20,
  109. plot_type=None, color=None, axis_color="#333333", alpha=1):
  110. class_names = ["fold_"+str(i) for i in range(len(self.shap_values))]
  111. return self._plot_(self.shap_values, names, save, title, max_display,
  112. layered_violin_max_num_bins, plot_type, color, axis_color, alpha, class_names = class_names)
  113. def plot_one_fold(self, num: int, names: bool = True, save: Path = None, title=None,
  114. max_display=100, layered_violin_max_num_bins = 20,
  115. plot_type=None, color=None, axis_color="#333333", alpha=1):
  116. assert num < len(self.shap_values), f"there are no shap values for fold {str(num)}!"
  117. return self._plot_(self.shap_values[num], names, save, title, max_display,
  118. layered_violin_max_num_bins, plot_type, color, axis_color, alpha)
  119. def _repr_html_(self):
  120. return f"<table border='2'>" \
  121. f"<caption><h3>Feature selection results</h3><caption>" \
  122. f"<tr style='text-align:center'><th>selected</th><th>metrics</th></tr>" \
  123. f"<tr><td>{self.selected._repr_html_()}</th><th>{self.metrics._repr_html_()}</th></tr>" \
  124. f"</table>"
  125. @cached_property
  126. def selected_shap(self):
  127. return self.selected.join(self.shap_values.T.set_index())
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...