Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

class_balanced_sampler_test.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  1. import random
  2. from typing import Dict
  3. import numpy as np
  4. import torch
  5. import unittest
  6. from torch.utils.data import Dataset
  7. from super_gradients.dataset_interfaces import HasClassesInformation
  8. from super_gradients.training import dataloaders
  9. from super_gradients.training.datasets.samplers.class_balanced_sampler import ClassBalancedSampler
  class DummyFreqDataset(Dataset, HasClassesInformation):
      """Synthetic dataset whose samples are arrays of class ids.

      The class ids are expanded from a ``class id -> number of occurrences`` mapping,
      shuffled, and split into ``total_samples`` chunks. Each chunk is one sample.
      """

      def __init__(self, class_id_to_frequency: Dict[int, int], total_samples: int) -> None:
          """Store the configuration and build the per-sample class-id arrays.

          :param class_id_to_frequency: Mapping from class id to how many times that id appears in the dataset.
          :param total_samples:         Number of chunks (samples) the shuffled class ids are split into.
          """
          self.total_samples = total_samples
          self.num_classes = len(class_id_to_frequency)
          self.class_id_to_frequency = class_id_to_frequency
          self.ignore_empty_annotations = True

          # The data source must exist before the parent initializers run.
          self._setup_data_source()
          super().__init__()

      def _setup_data_source(self) -> int:
          """Populate ``self.idx_to_classes`` and return the number of samples it holds.

          :return: Number of chunks produced by the split.
          """
          # Repeat every class id as many times as its configured frequency.
          class_ids = [
              class_id
              for class_id, occurrences in self.class_id_to_frequency.items()
              for _ in range(occurrences)
          ]
          random.shuffle(class_ids)

          # array_split tolerates a length that is not a multiple of total_samples.
          self.idx_to_classes = np.array_split(class_ids, self.total_samples)
          return len(self.idx_to_classes)

      def __len__(self) -> int:
          """Return the number of samples (chunks) in the dataset."""
          return len(self.idx_to_classes)

      def __getitem__(self, index: int):
          """Return the array of class ids stored for sample ``index``."""
          return self.idx_to_classes[index]

      def get_sample_classes_information(self, index: int) -> np.ndarray:
          """Count the occurrences of each class id in sample ``index``.

          :param index: Index of the sample.
          :return:      Per-class counts with at least ``self.num_classes`` entries.
          """
          return np.bincount(self.idx_to_classes[index], minlength=self.num_classes)

      def get_dataset_classes_information(self) -> np.ndarray:
          """Stack the per-sample class counts of the whole dataset, one row per sample.

          :return: Array built by vertically stacking ``get_sample_classes_information`` for every index.
          """
          per_sample_counts = list(map(self.get_sample_classes_information, range(len(self))))
          return np.vstack(per_sample_counts)
  class ClassBalancedSamplerTest(unittest.TestCase):
      """Tests for ``ClassBalancedSampler`` driven by a synthetic class-frequency dataset."""

      @staticmethod
      def _count_sampled_classes(dataset, id_to_freq: Dict[int, int]) -> Dict[int, int]:
          """Iterate ``dataset`` through a default ``ClassBalancedSampler`` and count the class ids drawn.

          :param dataset:    Dataset to sample from; each item is expected to be an array of class ids.
          :param id_to_freq: Mapping whose keys are the class ids to count.
          :return:           Mapping from class id to the number of times it was seen over one pass of the dataloader.
          """
          sampler = ClassBalancedSampler(dataset=dataset)
          dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, sampler=sampler)

          classes_sampled = dict.fromkeys(id_to_freq, 0)
          for batch in dataloader:
              for element in batch:
                  for cls in element:
                      classes_sampled[cls.item()] += 1
          return classes_sampled

      def test_balancing_classes_that_are_with_same_frequency(self):
          """Equally frequent classes should keep (approximately) their original frequency."""
          id_to_freq = {0: 30000, 1: 30000, 2: 30000}
          total_samples = 60000
          dataset = DummyFreqDataset(class_id_to_frequency=id_to_freq, total_samples=total_samples)

          classes_sampled = self._count_sampled_classes(dataset, id_to_freq)

          # Both frequencies are normalized by the number of samples, not by the number of class instances.
          for class_id, sampled_count in classes_sampled.items():
              expected_freq = id_to_freq[class_id] / total_samples
              sampled_freq = sampled_count / total_samples
              self.assertAlmostEqual(expected_freq, sampled_freq, places=1)

      def test_balancing_scarce_classes(self):
          """A scarce class should be seen more often than in the raw data, the frequent ones less often."""
          id_to_freq = {0: 10000, 1: 1000, 2: 10000}
          total_samples = 15000
          dataset = DummyFreqDataset(class_id_to_frequency=id_to_freq, total_samples=total_samples)

          classes_sampled = self._count_sampled_classes(dataset, id_to_freq)

          # Both frequencies are normalized by the number of samples, not by the number of class instances.
          for class_id, sampled_count in classes_sampled.items():
              original_freq = id_to_freq[class_id] / total_samples
              sampled_freq = sampled_count / total_samples
              if class_id == 1:  # over sampled class
                  self.assertGreater(sampled_freq, original_freq)
              else:
                  self.assertLess(sampled_freq, original_freq)

      def test_get_from_config(self):
          """A sampler described by name in ``dataloader_params`` should be built as a ``ClassBalancedSampler``."""
          id_to_freq = {0: 10, 1: 1, 2: 10}
          total_samples = 15
          dataset = DummyFreqDataset(class_id_to_frequency=id_to_freq, total_samples=total_samples)
          dataloader_params = {
              "batch_size": 4,
              "sampler": {"ClassBalancedSampler": {"oversample_threshold": 1.0, "oversample_aggressiveness": 1.5}},
              "drop_last": True,
          }

          dataloader = dataloaders.get(dataset=dataset, dataloader_params=dataloader_params)

          # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
          self.assertIsInstance(dataloader.sampler, ClassBalancedSampler)
  # Run the tests in this module when the file is executed directly as a script.
  if __name__ == "__main__":
      unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...