ddrnet_loss.py 1.9 KB

import torch
from typing import Union

from super_gradients.training.losses.ohem_ce_loss import OhemCELoss


class DDRNetLoss(OhemCELoss):
    def __init__(self,
                 threshold: float = 0.7,
                 ohem_percentage: float = 0.1,
                 weights: Union[list, tuple] = (1.0, 0.4),  # tuple avoids a shared mutable default
                 ignore_label=255):
        """
        This loss is an extension of the OHEM (Online Hard Example Mining) Cross Entropy loss,
        as defined in the paper:
        Accurate Semantic Segmentation of Road Scenes ( https://arxiv.org/pdf/2101.06085.pdf )

        :param threshold: threshold for the hard example mining algorithm
        :param ohem_percentage: minimum percentage of total pixels for the hard example mining algorithm
            (taking only the largest losses)
        :param weights: weight for each input of the loss. This loss supports multiple outputs (as in DDRNet
            with an auxiliary head); the loss of each head can be weighted.
        :param ignore_label: target label to be ignored
        """
        super().__init__(threshold=threshold, mining_percent=ohem_percentage, ignore_lb=ignore_label)
        self.weights = weights

    def forward(self, predictions_list: Union[list, tuple, torch.Tensor],
                targets: torch.Tensor):
        # A single tensor is treated as a one-element sequence of predictions.
        if isinstance(predictions_list, torch.Tensor):
            predictions_list = (predictions_list,)

        assert len(predictions_list) == len(self.weights), \
            "number of predictions must be the same as number of loss weights"

        losses = []
        unweighted_losses = []
        # Compute the OHEM cross entropy loss per head, then scale it by that head's weight.
        for predictions, weight in zip(predictions_list, self.weights):
            unweighted_loss = super().forward(predictions, targets)
            unweighted_losses.append(unweighted_loss)
            losses.append(unweighted_loss * weight)

        total_loss = sum(losses)
        # The weighted total is appended last, so the detached tensor holds
        # [loss_head_0, ..., loss_head_n, total_loss] for logging.
        unweighted_losses.append(total_loss)

        return total_loss, torch.stack(unweighted_losses, dim=0).detach()
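
The way `forward` combines the per-head losses can be sketched with plain Python floats, without `torch` or `super_gradients`. This is only an illustration of the arithmetic: `combine_head_losses` is a hypothetical helper that is not part of the library, and the input values are made up for the example.

```python
from typing import List, Sequence, Tuple


def combine_head_losses(unweighted_losses: Sequence[float],
                        weights: Sequence[float] = (1.0, 0.4)) -> Tuple[float, List[float]]:
    """Hypothetical helper mirroring the weighting step of DDRNetLoss.forward.

    Returns the weighted total and a logging list laid out as
    [loss_head_0, ..., loss_head_n, total].
    """
    if len(unweighted_losses) != len(weights):
        raise ValueError("number of losses must match number of weights")

    # Scale each head's loss by its weight and sum to get the training loss.
    total = sum(loss * weight for loss, weight in zip(unweighted_losses, weights))

    # The logged items are the unweighted per-head losses followed by the total.
    logged = list(unweighted_losses) + [total]
    return total, logged


if __name__ == "__main__":
    # Made-up values: main head loss 0.5, auxiliary head loss 1.0.
    total, logged = combine_head_losses([0.5, 1.0])
    print(total, logged)
```

With the default weights, the auxiliary head contributes 40% of its loss to the total, while the logged per-head entries remain unweighted.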