test_dataloader_adapter_non_regression.py

import os
import numpy as np
import unittest
import tempfile
import shutil

from data_gradients.managers.detection_manager import DetectionAnalysisManager
from data_gradients.managers.segmentation_manager import SegmentationAnalysisManager
from data_gradients.managers.classification_manager import ClassificationAnalysisManager
from data_gradients.utils.data_classes.image_channels import ImageChannels

from super_gradients.training.dataloaders.dataloaders import coco2017_val, cityscapes_stdc_seg50_val, cifar10_val
from super_gradients.training.dataloaders.adapters import (
    DetectionDataloaderAdapterFactory,
    SegmentationDataloaderAdapterFactory,
    ClassificationDataloaderAdapterFactory,
)


class DataloaderAdapterNonRegressionTest(unittest.TestCase):
    def setUp(self) -> None:
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)

    def test_adapter_on_coco2017_val(self):
        # We use Validation set because it does not include augmentation (which is random and makes it impossible to compare results)
        loader = coco2017_val(
            dataset_params={"max_num_samples": 500, "with_crowd": False},
            dataloader_params={"collate_fn": "DetectionCollateFN"},
        )  # `max_num_samples` To make it faster

        analyzer = DetectionAnalysisManager(
            report_title="coco2017_val",
            log_dir=self.tmp_dir,
            train_data=loader,
            val_data=loader,
            class_names=loader.dataset.classes,
            image_channels=ImageChannels.from_str("RGB"),
            batches_early_stop=20,
            use_cache=True,  # With this we will be asked about the data information only once
            bbox_format="cxcywh",
            is_label_first=True,
        )
        analyzer.run()

        adapted_loader = DetectionDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer.data_config.cache_path)

        for (adapted_images, adapted_targets), (images, targets) in zip(adapted_loader, loader):
            assert np.isclose(adapted_targets, targets).all()
            assert np.isclose(adapted_images, images).all()
        os.remove(analyzer.data_config.cache_path)

    def test_adapter_on_cityscapes_stdc_seg50_val(self):
        # We use Validation set because it does not include augmentation (which is random and makes it impossible to compare results)
        loader = cityscapes_stdc_seg50_val()

        analyzer = SegmentationAnalysisManager(
            report_title="cityscapes_stdc_seg50_val",
            log_dir=self.tmp_dir,
            train_data=loader,
            val_data=loader,
            class_names=loader.dataset.classes + ["<unknown>"],
            image_channels=ImageChannels.from_str("RGB"),
            batches_early_stop=1,
            use_cache=True,  # With this we will be asked about the data information only once
        )
        analyzer.run()

        adapted_loader = SegmentationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer.data_config.cache_path)

        for (adapted_images, adapted_targets), (images, targets) in zip(adapted_loader, loader):
            assert np.isclose(adapted_targets, targets).all()
            assert np.isclose(adapted_images, images).all()
        os.remove(analyzer.data_config.cache_path)

    def test_adapter_on_cifar10_val(self):
        # We use Validation set because it does not include augmentation (which is random and makes it impossible to compare results)
        loader = cifar10_val(dataset_params={"transforms": ["ToTensor"]})

        analyzer = ClassificationAnalysisManager(
            report_title="test_python_classification",
            log_dir=self.tmp_dir,
            train_data=loader,
            val_data=loader,
            class_names=list(range(10)),
            image_channels=ImageChannels.from_str("RGB"),
            batches_early_stop=20,
            use_cache=True,  # With this we will be asked about the data information only once
        )
        analyzer.run()

        adapted_loader = ClassificationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer.data_config.cache_path)

        for (adapted_images, adapted_targets), (images, targets) in zip(adapted_loader, loader):
            assert np.isclose(adapted_targets, targets).all()
            assert np.isclose(adapted_images, images).all()
        os.remove(analyzer.data_config.cache_path)

    def test_ddp_python_based_adapter(self):
        # setup_device(num_gpus=3)

        # We use Validation set because it does not include augmentation (which is random and makes it impossible to compare results)
        loader = cifar10_val(dataset_params={"transforms": ["ToTensor"]})

        analyzer = ClassificationAnalysisManager(
            report_title="test_python_classification",
            log_dir=self.tmp_dir,
            train_data=loader,
            val_data=loader,
            class_names=list(range(10)),
            image_channels=ImageChannels.from_str("RGB"),
            batches_early_stop=20,
            use_cache=True,  # With this we will be asked about the data information only once
        )
        analyzer.run()

        adapted_loader = ClassificationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer.data_config.cache_path)

        for (adapted_images, adapted_targets), (images, targets) in zip(adapted_loader, loader):
            assert np.isclose(adapted_targets, targets).all()
            assert np.isclose(adapted_images, images).all()
        os.remove(analyzer.data_config.cache_path)

if __name__ == "__main__":
    unittest.main()
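
The adapter workflow these non-regression tests exercise is the same in all four cases: run an analysis manager over a dataloader once with use_cache=True, then hand the cached data config to the matching adapter factory. Below is a minimal sketch of that flow outside of unittest, using the same DetectionAnalysisManager and DetectionDataloaderAdapterFactory APIs imported above; build_my_detection_loader, the class list, and the log directory are hypothetical placeholders, not part of the test file.

# Sketch (assumption): end-to-end adapter usage mirroring the tests above.
from data_gradients.managers.detection_manager import DetectionAnalysisManager
from data_gradients.utils.data_classes.image_channels import ImageChannels
from super_gradients.training.dataloaders.adapters import DetectionDataloaderAdapterFactory

loader = build_my_detection_loader()  # hypothetical: any torch DataLoader yielding (images, targets)

# Step 1: analyze the loader once; use_cache=True persists the inferred data config to disk.
analyzer = DetectionAnalysisManager(
    report_title="my_detection_dataset",           # hypothetical report name
    log_dir="analysis_logs",                       # hypothetical output directory
    train_data=loader,
    val_data=loader,
    class_names=["car", "person"],                 # hypothetical class list
    image_channels=ImageChannels.from_str("RGB"),
    use_cache=True,
    bbox_format="cxcywh",
    is_label_first=True,
)
analyzer.run()

# Step 2: wrap the same loader with the factory, pointing it at the cached config.
adapted_loader = DetectionDataloaderAdapterFactory.from_dataloader(
    dataloader=loader,
    config_path=analyzer.data_config.cache_path,
)

In the tests above, the built-in validation loaders already produce batches in the expected format, so adapting them should leave the data numerically unchanged; that is what the np.isclose assertions verify before each test removes the cache file.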