Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

bce_dice_loss.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import List, Optional

import torch

from super_gradients.common.object_names import Losses
from super_gradients.common.registry.registry import register_loss
from super_gradients.training.losses.bce_loss import BCE
from super_gradients.training.losses.dice_loss import BinaryDiceLoss
  7. @register_loss(name=Losses.BCE_DICE_LOSS, deprecated_name="bce_dice_loss")
  8. class BCEDiceLoss(torch.nn.Module):
  9. """
  10. Binary Cross Entropy + Dice Loss
  11. Weighted average of BCE and Dice loss
  12. :param loss_weights: List of size 2 s.t loss_weights[0], loss_weights[1] are the weights for BCE, Dice respectively.
  13. :param logits: Whether to use logits or not.
  14. """
  15. def __init__(self, loss_weights: List[float] = [0.5, 0.5], logits: bool = True):
  16. super(BCEDiceLoss, self).__init__()
  17. self.loss_weights = loss_weights
  18. self.bce = BCE()
  19. self.dice = BinaryDiceLoss(apply_sigmoid=logits)
  20. def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  21. """
  22. :param input: Network's raw output shaped (N,1,H,W)
  23. :param target: Ground truth shaped (N,H,W)
  24. """
  25. return self.loss_weights[0] * self.bce(input, target) + self.loss_weights[1] * self.dice(input, target)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...