1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
|
- # source: https://github.com/LIVIAETS/boundary-loss
- # paper: https://doi.org/10.1016/j.media.2020.101851
- # license: unspecified as of 2021-12-06
- # only selected code from repo
- import logging
- from functools import partial
- from typing import Any, Callable, cast, Iterable, List, Set, Tuple, TypeVar, Union
- from scipy.ndimage import distance_transform_edt as eucl_distance
- import numpy as np
- import torch
- import torch.sparse
- from torch import einsum, Tensor
# Module-level logger; the loss classes report their configuration at DEBUG level.
logger = logging.getLogger(__name__)

# Small additive constant used to keep divisions and logarithms finite.
EPS: float = 1e-10

# Generic type variables; T is constrained to torch tensors or NumPy arrays.
A, B = TypeVar("A"), TypeVar("B")
T = TypeVar("T", Tensor, np.ndarray)
# Soft-size helpers
- def soft_size(a: Tensor) -> Tensor:
- return torch.einsum("bk...->bk", a)[..., None]
def batch_soft_size(a: Tensor) -> Tensor:
    """Sum ``a`` over the first axis and every axis after the second.

    An input of shape ``(b, k, *rest)`` yields a tensor of shape ``(k, 1)``.
    """
    totals = torch.einsum("bk...->k", a)
    return totals[..., None]
- # Assert utils
def uniq(a: Tensor) -> Set:
    """Return the distinct values of ``a`` as a Python set (computed on CPU)."""
    return {value for value in torch.unique(a.cpu()).numpy()}
def sset(a: Tensor, sub: Iterable) -> bool:
    """Return True when every distinct value of ``a`` also appears in ``sub``."""
    values = uniq(a)
    return values.issubset(sub)
def eq(a: Tensor, b) -> bool:
    """Check that ``a`` equals ``b`` element-wise everywhere.

    ``b`` may be a tensor or anything ``torch.eq`` accepts (e.g. a scalar).
    The result is a 0-d boolean tensor, truthy exactly when all elements match.
    """
    matches = torch.eq(a, b)
    return matches.all()
# The simplex check is DISABLED by default: it kept failing at random during
# training (cause unknown, possibly fp16). Pass ``check=True`` to run it.
def simplex(t: Tensor, axis=1, *, check: bool = False) -> bool:
    """Return whether ``t`` sums to one along ``axis``.

    Args:
        t: tensor to inspect.
        axis: dimension whose entries should sum to one.
        check: when False (the default, preserving the previous behavior) the
            check is skipped and True is returned unconditionally.

    Returns:
        True if the check is skipped, otherwise whether every sum along
        ``axis`` is close to 1 (``torch.allclose`` default tolerances).
    """
    if not check:
        return True
    # Sum in t's own dtype, then compare in float32 against a ones tensor.
    sums = t.sum(axis).type(torch.float32)
    return torch.allclose(sums, torch.ones_like(sums))
def one_hot(t: Tensor, axis=1) -> bool:
    """Return True when ``t`` passes ``simplex`` along ``axis`` and holds only 0/1 values.

    NOTE: ``simplex`` returns True unless its check is enabled, so by default
    only the 0/1 value test is effective here.
    """
    if not simplex(t, axis):
        return False
    return sset(t, [0, 1])
# Metrics
def meta_dice(sum_str: str, label: Tensor, pred: Tensor, smooth: float = EPS) -> Tensor:
    """Dice coefficient between two one-hot tensors of identical shape.

    Args:
        sum_str: einsum reduction deciding which axes are summed, e.g.
            ``"bk...->bk"`` (keep first two axes) or ``"bk...->k"``.
        label: one-hot reference tensor.
        pred: one-hot prediction tensor.
        smooth: additive smoothing in numerator and denominator.

    Returns:
        float32 tensor ``(2 * |label & pred| + smooth) / (|label| + |pred| + smooth)``.
    """
    assert label.shape == pred.shape
    assert one_hot(label)
    assert one_hot(pred)

    overlap: Tensor = einsum(sum_str, [intersection(label, pred)]).type(torch.float32)
    label_size = einsum(sum_str, [label])
    pred_size = einsum(sum_str, [pred])
    total: Tensor = (label_size + pred_size).type(torch.float32)

    return (2 * overlap + smooth) / (total + smooth)


# Dice kept separately for each entry of the first two axes.
dice_coef = partial(meta_dice, "bk...->bk")
# Dice per entry of the second axis, pooled over the first axis; used for 3d dice.
dice_batch = partial(meta_dice, "bk...->k")
def intersection(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise AND of two same-shaped tensors whose values are all 0 or 1."""
    assert a.shape == b.shape
    for operand in (a, b):
        assert sset(operand, [0, 1])

    both = a & b
    assert sset(both, [0, 1])
    return both
def union(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise OR of two same-shaped tensors whose values are all 0 or 1."""
    assert a.shape == b.shape
    for operand in (a, b):
        assert sset(operand, [0, 1])

    either = a | b
    assert sset(either, [0, 1])
    return either
def inter_sum(a: Tensor, b: Tensor) -> Tensor:
    """Count positions set in both ``a`` and ``b``, summed over all axes after the first two.

    Returns a float32 tensor of shape ``(b, k)`` for inputs of shape ``(b, k, *rest)``.
    """
    overlap = intersection(a, b).type(torch.float32)
    return einsum("bk...->bk", overlap)
def union_sum(a: Tensor, b: Tensor) -> Tensor:
    """Count positions set in ``a`` or ``b``, summed over all axes after the first two.

    Returns a float32 tensor of shape ``(b, k)`` for inputs of shape ``(b, k, *rest)``.
    """
    combined = union(a, b).type(torch.float32)
    return einsum("bk...->bk", combined)
- # switch between representations
def probs2class(probs: Tensor) -> Tensor:
    """Collapse per-class scores into a label map via argmax over axis 1.

    ``probs`` has shape ``(b, K, *spatial)``; the result has shape ``(b, *spatial)``.
    """
    batch, _, *spatial = probs.shape
    assert simplex(probs)

    labels = probs.argmax(dim=1)
    assert labels.shape == (batch, *spatial)
    return labels
def class2one_hot(seg: Tensor, K: int) -> Tensor:
    """One-hot encode a batched label map.

    Args:
        seg: label map of shape ``(b, *spatial)`` (first axis is the batch);
            every value must lie in ``range(K)``. Any integer dtype is
            accepted: the index is cast to int64 for ``scatter_``.
        K: number of classes.

    Returns:
        int32 tensor of shape ``(b, K, *spatial)`` on ``seg``'s device.

    NOTE: unbatched input is deliberately not special-cased. A rank-3 tensor
    could be either a batch of 2D maps or a single 3D volume, so callers must
    add the batch axis themselves.
    """
    assert sset(seg, list(range(K))), (uniq(seg), K)
    b, *img_shape = seg.shape

    # scatter_ only accepts an int64 index; .long() returns int64 tensors
    # unchanged and converts int32/uint8 label maps, which previously crashed.
    index = seg[:, None, ...].long()
    res = torch.zeros((b, K, *img_shape), dtype=torch.int32, device=seg.device)
    res.scatter_(1, index, 1)

    assert res.shape == (b, K, *img_shape)
    assert one_hot(res)
    return res
def np_class2one_hot(seg: np.ndarray, K: int) -> np.ndarray:
    """NumPy front-end for ``class2one_hot``: encode ``seg`` and return a NumPy array."""
    # Work on a private copy: torch.from_numpy shares memory with its argument.
    labels = torch.from_numpy(seg.copy())
    labels = labels.type(torch.int64)
    encoded = class2one_hot(labels, K)
    return encoded.numpy()
def probs2one_hot(probs: Tensor) -> Tensor:
    """Hard one-hot encoding of class scores: argmax over axis 1, then re-expand.

    The output has the same shape as ``probs`` (``(b, K, *spatial)``).
    """
    _, n_classes, *_ = probs.shape
    assert simplex(probs)

    labels = probs2class(probs)
    encoded = class2one_hot(labels, n_classes)

    assert encoded.shape == probs.shape
    assert one_hot(encoded)
    return encoded
def one_hot2dist(
    seg: np.ndarray,
    resolution: Union[Tuple[float, ...], None] = None,
    dtype=None,
) -> np.ndarray:
    """Convert a one-hot mask into per-class signed distance maps.

    For each class ``k`` present in ``seg[k]``:
    - positions outside the class get their Euclidean distance to the class (> 0);
    - positions inside get ``-(distance to the outside - 1)`` (<= 0).
    Channels of absent classes are left at 0.

    Args:
        seg: one-hot array whose first axis indexes classes (``seg[k]`` is the
            mask of class ``k``).
        resolution: optional per-axis spacing, forwarded as ``sampling`` to
            ``scipy.ndimage.distance_transform_edt``.
        dtype: output dtype. When None, ``seg``'s dtype is kept if it is
            floating point; otherwise float64 is used.

    Returns:
        Array with the same shape as ``seg``.
    """
    assert one_hot(torch.tensor(seg), axis=0)

    if dtype is None:
        # zeros_like would otherwise inherit an integer/bool dtype from seg,
        # truncating fractional distances and mangling the negative ones.
        dtype = seg.dtype if seg.dtype.kind == "f" else np.float64
    res = np.zeros_like(seg, dtype=dtype)

    for k, channel in enumerate(seg):
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        posmask = channel.astype(bool)
        if not posmask.any():
            # Leave absent (negative) classes blank: since the mask is one-hot,
            # another class supervises those pixels.
            continue
        negmask = ~posmask
        res[k] = (
            eucl_distance(negmask, sampling=resolution) * negmask
            - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        )
    return res
class CrossEntropy:
    """Cross-entropy between predicted probabilities and a one-hot target.

    Keyword args:
        idc: class indices (channels) to keep; selected with fancy indexing so
            only these classes contribute to the loss.

    Inputs are laid out as ``(batch, class, *spatial)``; any number of spatial
    dimensions is accepted (2D images as well as 3D volumes).
    """

    def __init__(self, **kwargs):
        # Class indices used to filter the channels of probs/target.
        self.idc: List[int] = kwargs["idc"]
        logger.debug(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        """Return -sum(target * log(probs)) / (sum(target) + EPS) over the selected classes."""
        assert simplex(probs) and simplex(target)
        # Cast before adding EPS: 1e-10 underflows to 0 in float16, which would
        # leave log(0) = -inf. It also keeps both einsum operands in one dtype.
        pc: Tensor = probs[:, self.idc, ...].type(torch.float32)
        mask: Tensor = target[:, self.idc, ...].type(torch.float32)
        log_p: Tensor = (pc + EPS).log()
        # "..." makes the full contraction independent of the spatial rank.
        loss = -einsum("bk...,bk...->", mask, log_p)
        # Normalise by the number of selected target pixels; EPS guards an empty mask.
        loss /= mask.sum() + EPS
        return loss
class GeneralizedDice:
    """Generalized Dice loss: per-class terms weighted by inverse squared target size.

    Keyword args:
        idc: class indices (channels) to keep; selected with fancy indexing.

    Inputs are laid out as ``(batch, class, *spatial)``; any number of spatial
    dimensions is accepted. Returns the mean over the batch of
    ``1 - 2 * sum_k(w * inter) / sum_k(w * union)``.
    """

    def __init__(self, **kwargs):
        # Class indices used to filter the channels of probs/target.
        self.idc: List[int] = kwargs["idc"]
        logger.debug(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)
        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        # w = 1 / (|target_k|^2 + EPS). Modified from upstream, which used
        # 1 / (|target_k| + EPS)^2: with EPS added after squaring, a class absent
        # from the target gets weight 1/EPS instead of 1/EPS**2.
        w: Tensor = 1 / (einsum("bk...->bk", tc) ** 2 + EPS)
        intersection: Tensor = w * einsum("bk...,bk...->bk", pc, tc)
        union: Tensor = w * (einsum("bk...->bk", pc) + einsum("bk...->bk", tc))

        # One value per sample: sum the weighted class terms, then take the ratio.
        divided: Tensor = 1 - 2 * (einsum("bk->b", intersection) + EPS) / (
            einsum("bk->b", union) + EPS
        )
        loss = divided.mean()
        return loss
class DiceLoss:
    """Soft Dice loss averaged over samples and selected classes.

    Keyword args:
        idc: class indices (channels) to keep; selected with fancy indexing.

    Inputs are laid out as ``(batch, class, *spatial)``; any number of spatial
    dimensions is accepted.
    """

    def __init__(self, **kwargs):
        # Class indices used to filter the channels of probs/target.
        self.idc: List[int] = kwargs["idc"]
        logger.debug(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)
        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        # Per-(sample, class) overlap and total mass; "..." sums every spatial axis.
        intersection: Tensor = einsum("bk...,bk...->bk", pc, tc)
        union: Tensor = einsum("bk...->bk", pc) + einsum("bk...->bk", tc)

        # EPS in numerator and denominator makes an empty class score a perfect 0 loss.
        divided: Tensor = 1 - (2 * intersection + EPS) / (union + EPS)
        loss = divided.mean()
        return loss
class SurfaceLoss:
    """Boundary (surface) loss: mean of probabilities weighted by distance maps.

    Keyword args:
        idc: class indices (channels) to keep; selected with fancy indexing.

    ``dist_maps`` must have the same ``(batch, class, *spatial)`` layout as
    ``probs`` (presumably produced by ``one_hot2dist``; confirm with callers).
    Any number of spatial dimensions is accepted.
    """

    def __init__(self, **kwargs):
        # Class indices used to filter the channels of probs/dist_maps.
        self.idc: List[int] = kwargs["idc"]
        logger.debug(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
        assert simplex(probs)
        # Presumably guards against passing the one-hot target instead of its
        # distance maps.
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        # Element-wise product; "..." accepts any number of spatial dimensions.
        multiplied = einsum("bk...,bk...->bk...", pc, dc)
        loss = multiplied.mean()
        return loss


# Alias: the same loss under the name used in the paper's title.
BoundaryLoss = SurfaceLoss
class FocalLoss:
    """Focal loss on the selected classes of a one-hot target.

    Keyword args:
        idc: class indices (channels) to keep; selected with fancy indexing.
        gamma: focusing exponent applied to ``(1 - p)``.

    Inputs are laid out as ``(batch, class, *spatial)``; any number of spatial
    dimensions is accepted.
    """

    def __init__(self, **kwargs):
        # Class indices used to filter the channels of probs/target.
        self.idc: List[int] = kwargs["idc"]
        self.gamma: float = kwargs["gamma"]
        logger.debug(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        """Return -sum((1 - p)^gamma * target * log(p)) / (sum(target) + EPS)."""
        assert simplex(probs) and simplex(target)
        # Cast before adding EPS: 1e-10 underflows to 0 in float16, which would
        # leave log(0) = -inf. It also keeps all einsum operands in one dtype.
        masked_probs: Tensor = probs[:, self.idc, ...].type(torch.float32)
        log_p: Tensor = (masked_probs + EPS).log()
        mask: Tensor = target[:, self.idc, ...].type(torch.float32)

        # (1 - p)^gamma shrinks the contribution of pixels whose target class
        # already has a high predicted probability.
        w: Tensor = (1 - masked_probs) ** self.gamma
        loss = -einsum("bk...,bk...,bk...->", w, mask, log_p)
        # Normalise by the number of selected target pixels; EPS guards an empty mask.
        loss /= mask.sum() + EPS
        return loss
|