Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_index_mapping_test.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. import torch
  2. import unittest
  3. from torch.utils.data import Dataset
  4. from super_gradients.training.datasets.balancing_classes_utils import IndexMappingDatasetWrapper
  5. class DummyDataset(Dataset):
  6. def __init__(self, num_classes: int) -> None:
  7. super().__init__()
  8. self.num_classes = num_classes
  9. self.ignore_empty_annotations = False
  10. def __len__(self) -> int:
  11. return self.num_classes
  12. def __getitem__(self, idx: int) -> torch.Tensor:
  13. return torch.tensor(idx)
  14. class DatasetIndexMappingTest(unittest.TestCase):
  15. def setUp(self) -> None:
  16. self.dummy_dataset = DummyDataset(num_classes=5)
  17. def test_mapping_indices_that_does_nothing(self):
  18. wrapper = IndexMappingDatasetWrapper(self.dummy_dataset, list(range(len(self.dummy_dataset))))
  19. self.assertEqual(len(wrapper), len(self.dummy_dataset))
  20. for i in range(len(wrapper)):
  21. self.assertEqual(self.dummy_dataset[i], wrapper[i])
  22. def test_mapping_indices_that_samples_only_specific_index(self):
  23. c = 3
  24. i = 1
  25. mapping = [i] * c
  26. wrapper = IndexMappingDatasetWrapper(self.dummy_dataset, mapping)
  27. for j in range(len(wrapper)):
  28. self.assertEqual(self.dummy_dataset[i], wrapper[j])
  29. if __name__ == "__main__":
  30. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...