# source: https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss
#
# proposed in:
# Fidon, L. et al. (2018). Generalised Wasserstein Dice Score for Imbalanced Multi-class
# Segmentation Using Holistic Convolutional Networks. In: Crimi, A., Bakas, S., Kuijf, H.,
# Menze, B., Reyes, M. (eds) Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic
# Brain Injuries. BrainLes 2017. Lecture Notes in Computer Science, vol 10670. Springer,
# Cham. https://doi.org/10.1007/978-3-319-75238-9_6

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

SUPPORTED_WEIGHTING = ["default", "GDL"]


class GeneralizedWassersteinDiceLoss(_Loss):
    """
    Generalized Wasserstein Dice Loss [1] in PyTorch.

    Optionally, one can use a weighting method for the class-specific
    sum of errors similar to the one used in the Generalized Dice Loss [2].
    For this behaviour, use weighting_mode='GDL'. The exact formula of the
    Wasserstein Dice loss in this case can be found in the Appendix of [3].

    References:
    ===========
    [1] "Generalised Wasserstein Dice Score for Imbalanced Multi-class
        Segmentation using Holistic Convolutional Networks."
        Fidon L. et al. MICCAI BrainLes (2017).
    [2] "Generalised Dice overlap as a deep learning loss function
        for highly unbalanced segmentations."
        Sudre C. et al. MICCAI DLMIA (2017).
    [3] "Comparative study of deep learning methods for the automatic
        segmentation of lung, lesion and lesion type in CT scans of
        COVID-19 patients."
        Tilborghs S. et al. arXiv preprint arXiv:2007.15546 (2020).
    """

    def __init__(self, dist_matrix, weighting_mode="default", reduction="mean"):
        """
        :param dist_matrix: 2d tensor or 2d numpy array; matrix of distances
            between the classes. It must have dimension C x C, where C is the
            number of classes.
        :param weighting_mode: str; indicates how to weight the class-specific
            sum of errors.
            'default' corresponds to the GWDL used in the original paper [1],
            'GDL' corresponds to the GWDL used in [2].
        :param reduction: str; reduction mode.

        References:
        ===========
        [1] "Generalised Wasserstein Dice Score for Imbalanced Multi-class
            Segmentation using Holistic Convolutional Networks."
            Fidon L. et al. MICCAI BrainLes (2017).
        [2] "Comparative study of deep learning methods for the automatic
            segmentation of lung, lesion and lesion type in CT scans of
            COVID-19 patients."
            Tilborghs S. et al. arXiv preprint arXiv:2007.15546 (2020).
        """
        super(GeneralizedWassersteinDiceLoss, self).__init__(reduction=reduction)
        assert weighting_mode in SUPPORTED_WEIGHTING, (
            "weighting_mode must be in %s" % str(SUPPORTED_WEIGHTING)
        )
        self.M = dist_matrix
        if isinstance(self.M, np.ndarray):
            self.M = torch.from_numpy(self.M)
        if torch.max(self.M) != 1:
            print(
                "Normalize the maximum of the distance matrix "
                "used in the Generalized Wasserstein Dice Loss to 1."
            )
            self.M = self.M / torch.max(self.M)
        self.num_classes = self.M.size(0)
        self.alpha_mode = weighting_mode
        self.reduction = reduction

    def forward(self, input, target):
        """
        Compute the Generalized Wasserstein Dice loss between the input
        and target tensors.

        :param input: tensor; the score maps (before softmax).
            The expected shape of input is (N, C, H, W, D) in 3d
            and (N, C, H, W) in 2d.
        :param target: tensor; the target segmentation.
            The expected shape of target is (N, H, W, D) or (N, 1, H, W, D) in 3d
            and (N, H, W) or (N, 1, H, W) in 2d.
        :return: scalar tensor. Loss function value.
        """
        epsilon = np.spacing(1)  # machine epsilon, to avoid division by zero
        # Convert the target segmentation to long if needed
        target = target.long()
        # Aggregate the spatial dimensions
        flat_input = input.view(input.size(0), input.size(1), -1)  # b,c,s
        flat_target = target.view(target.size(0), -1)  # b,s
        # Apply the softmax to the input score maps
        probs = F.softmax(flat_input, dim=1)  # b,c,s
        # Compute the Wasserstein distance map
        wass_dist_map = self.wasserstein_distance_map(probs, flat_target)
        # Compute the alpha weights used for the generalized true positives
        alpha = self.compute_alpha_generalized_true_positives(flat_target)
        # Compute the Generalized Wasserstein Dice loss
        if self.alpha_mode == "GDL":
            # Use GDL-style alpha weights (i.e. normalize by the volume of
            # each class); contrary to [1], alpha is also used here in the
            # generalized all-error term.
            true_pos = self.compute_generalized_true_positive(
                alpha, flat_target, wass_dist_map
            )
            denom = self.compute_denominator(alpha, flat_target, wass_dist_map)
        else:  # default, as in [1]
            # (i.e. alpha=1 for all foreground classes and 0 for the background)
            true_pos = self.compute_generalized_true_positive(
                alpha, flat_target, wass_dist_map
            )
            all_error = torch.sum(wass_dist_map, dim=1)
            denom = 2 * true_pos + all_error
        wass_dice = (2.0 * true_pos + epsilon) / (denom + epsilon)
        wass_dice_loss = 1.0 - wass_dice
        if self.reduction == "sum":
            return wass_dice_loss.sum()
        elif self.reduction == "none":
            return wass_dice_loss
        else:  # default is mean reduction
            return wass_dice_loss.mean()

    def wasserstein_distance_map(self, flat_proba, flat_target):
        """
        Compute the voxel-wise Wasserstein distance (eq. 6 in [1]) between
        the flattened prediction and the flattened labels (ground truth),
        with respect to the distance matrix M on the label space.

        References:
        ===========
        [1] "Generalised Wasserstein Dice Score for Imbalanced Multi-class
            Segmentation using Holistic Convolutional Networks."
            Fidon L. et al. MICCAI BrainLes (2017).
        """
        # Turn the distance matrix into a map of identical matrices
        M_extended = torch.clone(self.M).to(flat_proba.device)
        M_extended = torch.unsqueeze(M_extended, dim=0)  # C,C -> 1,C,C
        M_extended = torch.unsqueeze(M_extended, dim=3)  # 1,C,C -> 1,C,C,1
        M_extended = M_extended.expand(
            (
                flat_proba.size(0),
                M_extended.size(1),
                M_extended.size(2),
                flat_proba.size(2),
            )
        )
        # Expand the feature dimensions of the target
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)  # b,s -> b,1,s
        flat_target_extended = flat_target_extended.expand(  # b,1,s -> b,C,s
            (flat_target.size(0), M_extended.size(1), flat_target.size(1))
        )
        flat_target_extended = torch.unsqueeze(
            flat_target_extended, dim=1
        )  # b,C,s -> b,1,C,s
        # Extract the vector of class distances for the ground-truth label at each voxel
        M_extended = torch.gather(
            M_extended, dim=1, index=flat_target_extended
        )  # b,C,C,s -> b,1,C,s
        M_extended = torch.squeeze(M_extended, dim=1)  # b,1,C,s -> b,C,s
        # Compute the Wasserstein distance map
        wasserstein_map = M_extended * flat_proba
        # Sum over the classes
        wasserstein_map = torch.sum(wasserstein_map, dim=1)  # b,C,s -> b,s
        return wasserstein_map

    def compute_generalized_true_positive(
        self, alpha, flat_target, wasserstein_distance_map
    ):
        # Extend alpha to a map and select the value at each voxel according to flat_target
        alpha_extended = torch.unsqueeze(alpha, dim=2)  # b,C -> b,C,1
        alpha_extended = alpha_extended.expand(  # b,C,1 -> b,C,s
            (flat_target.size(0), self.num_classes, flat_target.size(1))
        )
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)  # b,s -> b,1,s
        alpha_extended = torch.gather(
            alpha_extended, index=flat_target_extended, dim=1
        )  # b,C,s -> b,1,s
        # Compute the generalized true positives as in eq. 9 of [1].
        # The distance map is unsqueezed to b,1,s so that the product with
        # alpha_extended (b,1,s) stays per-sample; broadcasting b,1,s against
        # b,s would silently mix batch elements.
        generalized_true_pos = torch.sum(
            alpha_extended * (1.0 - wasserstein_distance_map.unsqueeze(dim=1)),
            dim=[1, 2],
        )
        return generalized_true_pos

    def compute_denominator(self, alpha, flat_target, wasserstein_distance_map):
        # Extend alpha to a map and select the value at each voxel according to flat_target
        alpha_extended = torch.unsqueeze(alpha, dim=2)  # b,C -> b,C,1
        alpha_extended = alpha_extended.expand(  # b,C,1 -> b,C,s
            (flat_target.size(0), self.num_classes, flat_target.size(1))
        )
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)  # b,s -> b,1,s
        alpha_extended = torch.gather(
            alpha_extended, index=flat_target_extended, dim=1
        )  # b,C,s -> b,1,s
        # Compute the alpha-weighted denominator of the GDL-style loss
        # (cf. the Appendix of [3] in the class docstring); the distance map
        # is unsqueezed to b,1,s for per-sample broadcasting, as above.
        denominator = torch.sum(
            alpha_extended * (2.0 - wasserstein_distance_map.unsqueeze(dim=1)),
            dim=[1, 2],
        )
        return denominator

    def compute_alpha_generalized_true_positives(self, flat_target):
        """
        Compute the weights alpha_l of eq. 9 in [1].

        References:
        ===========
        [1] "Generalised Wasserstein Dice Score for Imbalanced Multi-class
            Segmentation using Holistic Convolutional Networks."
            Fidon L. et al. MICCAI BrainLes (2017).
        """
        if self.alpha_mode == "GDL":
            # Define alpha as in the Generalized Dice Loss,
            # i.e. the inverse of the volume of each class.
            # Convert the target to a one-hot class encoding.
            one_hot = (
                F.one_hot(flat_target, num_classes=self.num_classes)  # b,s,C
                .permute(0, 2, 1)  # b,s,C -> b,C,s
                .float()
            )
            volumes = torch.sum(one_hot, dim=2)  # b,C
            alpha = 1.0 / (volumes + 1.0)
        else:  # default, i.e. as in [1]
            # alpha weights are 0 for the background and 1 otherwise
            alpha_np = np.ones((flat_target.size(0), self.num_classes))  # b,C
            alpha_np[:, 0] = 0.0
            # Keep alpha on the same device as the target (the original code
            # moved alpha to CUDA whenever CUDA was available, which breaks
            # for CPU inputs and multi-GPU setups).
            alpha = torch.from_numpy(alpha_np).float().to(flat_target.device)
        return alpha
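

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how this loss might be instantiated and
# called on a toy 3-class, 2d problem. The distance matrix values and tensor
# shapes below are illustrative assumptions, not from the original repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Distance matrix on the label space: background (0), class 1, class 2.
    # Confusing class 1 with class 2 is penalized half as much (0.5) as
    # confusing either with the background (1.0).
    M = np.array(
        [
            [0.0, 1.0, 1.0],
            [1.0, 0.0, 0.5],
            [1.0, 0.5, 0.0],
        ],
        dtype=np.float32,
    )

    # Dummy 2d batch: score maps of shape (N, C, H, W), labels of shape (N, H, W)
    scores = torch.randn(2, 3, 8, 8)
    labels = torch.randint(low=0, high=3, size=(2, 8, 8))

    # Default weighting, as in the original paper
    loss_fn = GeneralizedWassersteinDiceLoss(dist_matrix=M)
    print("default GWDL:", loss_fn(scores, labels).item())

    # GDL-style weighting of the class-specific sums of errors
    gdl_loss_fn = GeneralizedWassersteinDiceLoss(dist_matrix=M, weighting_mode="GDL")
    print("GDL-weighted GWDL:", gdl_loss_fn(scores, labels).item())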