class_balancer_test.py

import os
import tempfile
import unittest

import numpy as np
import torch
from torch.utils.data import Dataset

from super_gradients.dataset_interfaces import HasClassesInformation
from super_gradients.training.datasets.samplers.class_balanced_sampler import ClassBalancer


class SingleLabelUnbalancedDataset(Dataset, HasClassesInformation):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.ignore_empty_annotations = False

    def __len__(self) -> int:
        return self.num_classes

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx] * idx)  # no class 0

    def get_sample_classes_information(self, index) -> np.ndarray:
        info = np.zeros(self.num_classes, dtype=int)
        info[index] = index
        return info

    def get_dataset_classes_information(self) -> np.ndarray:
        return np.diag(np.arange(self.num_classes))


class MultiLabelUnbalancedDataset(Dataset, HasClassesInformation):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.ignore_empty_annotations = False

    def __len__(self) -> int:
        return self.num_classes

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx, 0])  # class 0 appears everywhere, other classes appear only once

    def get_sample_classes_information(self, index) -> np.ndarray:
        info = np.zeros(self.num_classes, dtype=int)
        info[index] = 1
        info[0] += 1
        return info

    def get_dataset_classes_information(self) -> np.ndarray:
        diag = np.eye(self.num_classes, dtype=int)
        diag[:, 0] += 1
        return diag


class ClassBalancerTest(unittest.TestCase):
    def setUp(self) -> None:
        self.single_label_dataset = SingleLabelUnbalancedDataset(num_classes=5)  # [[], [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
        self.multi_label_dataset = MultiLabelUnbalancedDataset(num_classes=5)  # [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]]

    def test_without_oversampling(self):
        repeat_factors = ClassBalancer.get_sample_repeat_factors(
            self.single_label_dataset,
            oversample_threshold=0.0,
        )
        expected_mappings = [1] * len(self.single_label_dataset)
        self.assertListEqual(expected_mappings, repeat_factors)

    def test_oversampling_frequent_classes_less_often_than_scarce(self):
        repeat_factors = ClassBalancer.get_sample_repeat_factors(
            self.single_label_dataset,
            oversample_threshold=1.0,
        )
        # reminder: samples = [[], [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
        self.assertEqual(repeat_factors[0], 1.0)  # do not oversample the sample with empty annotations
        # expected something like [1.0, a, b, c, d], with a > b > c > d > 1.0
        diffs = np.diff(repeat_factors[1:])
        self.assertTrue(np.all(diffs < 0.0))

    def test_multi_class_over_sampling(self):
        """
        Interestingly, when one class appears in every sample ([[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]])
        and all other classes have the same frequency, we still oversample, but with the same repeat factor for every sample.
        The reason is that originally class #0 appears 6 times while every other class appears once (6x the frequency);
        after resampling, class #0 appears ~14 times and each other class ~3 times. Note that the lower bound for
        class #0's relative frequency is 4x, and after resampling it is ~4.6x.
        """
        repeat_factors = ClassBalancer.get_sample_repeat_factors(
            self.multi_label_dataset,
            oversample_threshold=1.0,
        )
        # reminder: samples = [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]]
        self.assertEqual(1.0, repeat_factors[0])  # do not oversample the most frequent class
        # expected something like [1.0, a, b, c, d], with a = b = c = d > 1.0
        diffs = np.diff(repeat_factors[1:])
        self.assertTrue(np.all(diffs == 0.0))

    def test_no_oversample_below_threshold(self):
        repeat_factors = ClassBalancer.get_sample_repeat_factors(
            self.single_label_dataset,
            oversample_threshold=0.5,
        )
        # reminder: samples = [[], [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
        # Overall we have 5 images: class #1 has frequency 1/5 (image 1), #2 has 2/5 (image 2),
        # #3 has 3/5 (image 3), #4 has 4/5 (image 4).
        # With a threshold of 0.5 we will not oversample images 3 and 4, nor the empty image 0.
        oversampled_indices = np.array([False, True, True, False, False])
        self.assertTrue(np.all(np.array(repeat_factors)[~oversampled_indices] == 1.0))  # all non-oversampled samples keep a repeat factor of 1.0
        # make sure the oversampled indices have the expected (decreasing) repeat factors
        self.assertTrue(np.all(np.diff(np.array(repeat_factors)[oversampled_indices]) < 0.0))

    def test_precomputed_repeat_factors(self):
        repeat_factors = ClassBalancer.get_sample_repeat_factors(
            self.single_label_dataset,
            oversample_threshold=None,
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            precomputed_file = os.path.join(temp_dir, "precomputed_repeat_factors.json")
            ClassBalancer.precompute_sample_repeat_factors(precomputed_file, self.single_label_dataset)
            loaded_repeat_factors = ClassBalancer.from_precomputed_sample_repeat_factors(precomputed_file)
            np.testing.assert_almost_equal(repeat_factors, loaded_repeat_factors, decimal=3)


if __name__ == "__main__":
    unittest.main()
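
The repeat factors tested above are plain per-sample floats, so it can help to see how such factors are typically turned into an actual resampling of the dataset. The sketch below is a minimal, hypothetical usage example and does not reflect the internals of SuperGradients' own class-balanced sampler: the helper name repeat_factors_to_indices is an illustrative assumption, and the repeat-factor values are made up to mirror the toy multi-label dataset from the test. The integer part of each factor duplicates the index deterministically, and the fractional part is resolved by a random draw (stochastic rounding), which is the standard repeat-factor-sampling trick.

import torch
from torch.utils.data import SubsetRandomSampler


def repeat_factors_to_indices(repeat_factors, generator=None):
    """Hypothetical helper: expand per-sample repeat factors into a list of dataset indices.

    Each index i is repeated floor(r_i) times, plus one extra time with probability frac(r_i),
    so the expected number of copies of sample i equals r_i.
    """
    g = generator if generator is not None else torch.Generator()
    rf = torch.as_tensor(repeat_factors, dtype=torch.float64)
    rands = torch.rand(len(rf), generator=g)
    repeats = torch.floor(rf) + (rands < (rf - torch.floor(rf))).to(torch.float64)
    indices = []
    for idx, n in enumerate(repeats.long().tolist()):
        indices.extend([idx] * n)
    return indices


# Example mirroring the toy multi-label dataset [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]]:
# with repeat factors like [1.0, r, r, r, r] (r > 1), class #0 is seen 2 + 4*r times per epoch
# on average while every other class is seen r times, so the imbalance ratio (2 + 4*r) / r
# approaches 4 from above as r grows -- e.g. r = 3 gives 14 / 3 ~= 4.67, matching the docstring.
repeat_factors = [1.0, 3.0, 3.0, 3.0, 3.0]  # illustrative values, not ClassBalancer's exact output
indices = repeat_factors_to_indices(repeat_factors)
sampler = SubsetRandomSampler(indices)
# The sampler can then be passed to a regular torch.utils.data.DataLoader via its `sampler` argument.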