gdl.py 781 B

# Alternative implementation to the default GDL (generalized Dice loss)
import torch


class GeneralizedDiceLoss(torch.nn.Module):
    """Generalized Dice loss with inverse squared class-volume weights."""

    def __init__(self):
        super().__init__()

    def forward(self, inp, targ):
        # Both tensors are assumed to be (N, C, H, W); move classes to the last axis.
        inp = inp.permute(0, 2, 3, 1)
        targ = targ.permute(0, 2, 3, 1)
        # Per-class weight: inverse of the squared ground-truth volume.
        # A class absent from targ gets a very large weight (1 / 1e-9).
        w = 1.0 / (torch.sum(targ, (0, 1, 2)) ** 2 + 1e-9)
        # Weighted overlap between prediction and target, summed over classes.
        numerator = torch.sum(w * torch.sum(targ * inp, (0, 1, 2)))
        # Weighted volume of prediction plus target, summed over classes.
        denominator = torch.sum(w * torch.sum(targ + inp, (0, 1, 2)))
        # Smooth the whole ratio so an all-zero prediction and target give dice = 1.
        dice = (2.0 * numerator + 1e-9) / (denominator + 1e-9)
        return 1.0 - dice
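
The loss above can be sanity-checked by hand on a tiny input. The following is a minimal, dependency-free sketch of the same computation, not part of gdl.py: the name `gdl_reference`, the per-class flat-list layout, and the placement of the smoothing term `eps` are assumptions made for illustration.

```python
def gdl_reference(pred, targ, eps=1e-9):
    """Generalized Dice loss over per-class flat lists (hypothetical helper).

    pred[c][i] is the predicted probability of class c at pixel i;
    targ[c][i] is the one-hot ground truth for the same class and pixel.
    """
    numerator = 0.0
    denominator = 0.0
    for p_c, t_c in zip(pred, targ):
        # Class weight: inverse of the squared ground-truth volume.
        w = 1.0 / (sum(t_c) ** 2 + eps)
        # Weighted overlap and weighted total volume for this class.
        numerator += w * sum(p * t for p, t in zip(p_c, t_c))
        denominator += w * sum(p + t for p, t in zip(p_c, t_c))
    # Assumed smoothing: eps is added once to each side of the ratio.
    dice = (2.0 * numerator + eps) / (denominator + eps)
    return 1.0 - dice


# Two classes, four pixels; class 0 covers three pixels, class 1 covers one.
targ = [[1.0, 1.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0]]
perfect = [row[:] for row in targ]                 # prediction equals target
inverted = [[1.0 - t for t in row] for row in targ]  # every pixel wrong
uniform = [[0.5] * 4, [0.5] * 4]                   # maximally unsure

loss_perfect = gdl_reference(perfect, targ)    # close to 0
loss_inverted = gdl_reference(inverted, targ)  # close to 1
loss_uniform = gdl_reference(uniform, targ)    # strictly between 0 and 1
```

Because each class is weighted by the inverse of its squared volume, the single-pixel class contributes as much to the score as the three-pixel class, which is the point of the generalized variant.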