Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

class_balancing_test.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  1. import numpy as np
  2. import torch
  3. import unittest
  4. from torch.utils.data import Dataset
  5. from super_gradients.training.datasets.balancing_classes_utils import get_repeat_factors
  6. class SingleLabelUnbalancedDataset(Dataset):
  7. def __init__(self, num_classes: int) -> None:
  8. super().__init__()
  9. self.num_classes = num_classes
  10. self.ignore_empty_annotations = False
  11. def __len__(self) -> int:
  12. return self.num_classes
  13. def __getitem__(self, idx: int) -> torch.Tensor:
  14. return torch.tensor([idx] * idx) # no class 0
  15. class MultiLabelUnbalancedDataset(Dataset):
  16. def __init__(self, num_classes: int) -> None:
  17. super().__init__()
  18. self.num_classes = num_classes
  19. self.ignore_empty_annotations = False
  20. def __len__(self) -> int:
  21. return self.num_classes
  22. def __getitem__(self, idx: int) -> torch.Tensor:
  23. return torch.tensor([idx, 0]) # class 0 appears everywhere, other classes appear only once.
  24. class ClassBalancingTest(unittest.TestCase):
  25. def setUp(self) -> None:
  26. self.single_label_dataset = SingleLabelUnbalancedDataset(num_classes=5) # [[], [1], [2,2], [3,3,3], [4,4,4,4]]
  27. self.multi_label_dataset = MultiLabelUnbalancedDataset(num_classes=5) # [[0,0], [1,0], [2,0], [3,0], [4,0]]
  28. def test_without_oversampling(self):
  29. repeat_factors = get_repeat_factors(
  30. index_to_classes=lambda idx: self.single_label_dataset[idx].tolist(),
  31. num_classes=self.single_label_dataset.num_classes,
  32. dataset_length=len(self.single_label_dataset),
  33. ignore_empty_annotations=self.single_label_dataset.ignore_empty_annotations,
  34. oversample_threshold=0.0,
  35. )
  36. expected_mappings = [1.0] * len(self.single_label_dataset)
  37. self.assertListEqual(expected_mappings, repeat_factors)
  38. def test_oversampling_frequent_classes_less_often_than_scarce(self):
  39. repeat_factors = get_repeat_factors(
  40. index_to_classes=lambda idx: self.single_label_dataset[idx].tolist(),
  41. num_classes=self.single_label_dataset.num_classes,
  42. dataset_length=len(self.single_label_dataset),
  43. ignore_empty_annotations=self.single_label_dataset.ignore_empty_annotations,
  44. oversample_threshold=1.0,
  45. )
  46. # reminder: samples = [[], [1], [2,2], [3,3,3], [4,4,4,4]]
  47. self.assertEqual(repeat_factors[0], 1.0) # do not over sample empty annotations
  48. # expected something like [1.0, a, b, c, d], a>b>c>d>1.0
  49. diffs = np.diff(repeat_factors[1:])
  50. self.assertTrue(np.all(diffs < 0.0))
  51. def test_multi_class_over_sampling(self):
  52. """
  53. Interestingly, when we have a class that appears in every sample ([[0,0], [1,0], [2,0], [3,0], [4,0]]),
  54. and other samples have the same frequencies, we are still oversampling samples, but use the same repeat factor for all.
  55. The reason is that originally we have #0 class appearing 6 times, and other classes appear 1 time, which is 6x freq; after resampling,
  56. we have #0 class appearing 14 times, and other classes appear 3 times. Note that lower bound for class #0 is 4x freq, and after resampling it is 4.6x.
  57. """
  58. repeat_factors = get_repeat_factors(
  59. index_to_classes=lambda idx: self.multi_label_dataset[idx].tolist(),
  60. num_classes=self.multi_label_dataset.num_classes,
  61. dataset_length=len(self.multi_label_dataset),
  62. ignore_empty_annotations=self.multi_label_dataset.ignore_empty_annotations,
  63. oversample_threshold=1.0,
  64. )
  65. # reminder: samples = [[0,0], [1,0], [2,0], [3,0], [4,0]]
  66. self.assertEqual(repeat_factors[0], 1.0) # do not over sample the biggest class
  67. # expected something like [1.0, a, b, c, d], a=b=c=d>1.0
  68. diffs = np.diff(repeat_factors[1:])
  69. self.assertTrue(np.all(diffs == 0.0))
  70. def test_no_oversample_below_threshold(self):
  71. repeat_factors = get_repeat_factors(
  72. index_to_classes=lambda idx: self.single_label_dataset[idx].tolist(),
  73. num_classes=self.single_label_dataset.num_classes,
  74. dataset_length=len(self.single_label_dataset),
  75. ignore_empty_annotations=self.single_label_dataset.ignore_empty_annotations,
  76. oversample_threshold=0.5,
  77. )
  78. # reminder: samples = [[], [1], [2,2], [3,3,3], [4,4,4,4]]
  79. # overall we have 5 images, class #1 appears 1/5 (in image 1), #2 appears 2/5 (image 2), #3 appears 3/5 (image 3), #4 appears 4/5 (image 4).
  80. # We will not oversample IMAGES 3 and 4, nor the empty image 0.
  81. oversampled_indices = np.array([False, True, True, False, False])
  82. self.assertTrue(np.all(np.array(repeat_factors)[~oversampled_indices] == 1.0)) # all
  83. # make sure indices that are oversampled are with expected repeat factor
  84. self.assertTrue(np.all(np.diff(np.array(repeat_factors)[oversampled_indices]) < 0.0))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...