Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ohem_ce_loss.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. import torch
  2. from torch import nn
  3. from torch.nn.modules.loss import _Loss
  4. class OhemCELoss(_Loss):
  5. """
  6. OhemCELoss - Online Hard Example Mining Cross Entropy Loss
  7. """
  8. def __init__(self,
  9. threshold: float,
  10. mining_percent: float = 0.1,
  11. ignore_lb: int = -100,
  12. num_pixels_exclude_ignored: bool = True):
  13. """
  14. :param threshold: Sample below probability threshold, is considered hard.
  15. :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the
  16. samples.
  17. i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:
  18. num_pixels_exclude_ignored=False => num_mining = 100 * 0.1 = 10
  19. num_pixels_exclude_ignored=True => num_mining = (100 - 30) * 0.1 = 7
  20. """
  21. super().__init__()
  22. assert 0 <= mining_percent <= 1, "mining percent should be a value from 0 to 1"
  23. self.thresh = -torch.log(torch.tensor(threshold, dtype=torch.float))
  24. self.mining_percent = mining_percent
  25. self.ignore_lb = ignore_lb
  26. self.num_pixels_exclude_ignored = num_pixels_exclude_ignored
  27. self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
  28. def forward(self, logits, labels):
  29. loss = self.criteria(logits, labels).view(-1)
  30. if self.num_pixels_exclude_ignored:
  31. # remove ignore label elements
  32. loss = loss[labels.view(-1) != self.ignore_lb]
  33. # num pixels in a batch -> num_pixels = batch_size * width * height - ignore_pixels
  34. num_pixels = loss.numel()
  35. else:
  36. num_pixels = labels.numel()
  37. # if all pixels are ignore labels, return empty loss tensor
  38. if num_pixels == 0:
  39. return torch.tensor([0.]).requires_grad_(True)
  40. num_mining = int(self.mining_percent * num_pixels)
  41. # in case mining_percent=1, prevent out of bound exception
  42. num_mining = min(num_mining, num_pixels - 1)
  43. self.thresh = self.thresh.to(logits.device)
  44. loss, _ = torch.sort(loss, descending=True)
  45. if loss[num_mining] > self.thresh:
  46. loss = loss[loss > self.thresh]
  47. else:
  48. loss = loss[:num_mining]
  49. return torch.mean(loss)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...