Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_deprecations.py 6.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
  1. import unittest
  2. import warnings
  3. from typing import Union
  4. from omegaconf import DictConfig
  5. from torch import nn
  6. from super_gradients import Trainer
  7. from super_gradients.common.registry import register_model
  8. from super_gradients.training import models
  9. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  10. from super_gradients.training.metrics import Accuracy, Top5
  11. from super_gradients.training.models import CustomizableDetector, get_arch_params, ResNet18
  12. from super_gradients.training.params import TrainingParams
  13. from super_gradients.training.utils import HpmStruct
  14. from super_gradients.training.utils.utils import arch_params_deprecated
  15. from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform, DetectionHorizontalFlip, DetectionPaddedRescale
  16. @register_model("DummyModel")
  17. class DummyModel(CustomizableDetector):
  18. def __init__(self, arch_params: Union[str, dict, HpmStruct, DictConfig]):
  19. super().__init__(arch_params)
  20. @register_model("DummyModelV2")
  21. class DummyModelV2(nn.Module):
  22. @arch_params_deprecated
  23. def __init__(self, backbone, head, neck):
  24. super().__init__()
  25. class DeprecationsUnitTest(unittest.TestCase):
  26. def test_deprecated_arch_params_inside_base_class_via_direct_call(self):
  27. arch_params = get_arch_params("yolo_nas_l_arch_params")
  28. arch_params = HpmStruct(**arch_params)
  29. model = DummyModel(arch_params)
  30. assert isinstance(model, DummyModel)
  31. def test_deprecated_arch_params_inside_base_class_via_models_get(self):
  32. arch_params = get_arch_params("yolo_nas_l_arch_params")
  33. model = models.get("DummyModel", arch_params=arch_params, num_classes=80)
  34. assert isinstance(model, DummyModel)
  35. def test_deprecated_arch_params_top_level_class_via_direct_call(self):
  36. arch_params = HpmStruct(backbone=dict(), head=dict(), neck=dict())
  37. model = DummyModelV2(arch_params)
  38. assert isinstance(model, DummyModelV2)
  39. def test_deprecated_arch_params_top_level_class_via_models_get(self):
  40. arch_params = dict(backbone=dict(), head=dict(), neck=dict())
  41. model = models.get("DummyModelV2", arch_params=arch_params, num_classes=80)
  42. assert isinstance(model, DummyModelV2)
  43. def test_deprecated_make_divisible(self):
  44. try:
  45. with self.assertWarns(DeprecationWarning):
  46. from super_gradients.training.models import make_divisible # noqa
  47. assert make_divisible(1, 1) == 1
  48. except ImportError:
  49. self.fail("ImportError raised unexpectedly for make_divisible")
  50. def test_deprecated_BasicBlock(self):
  51. try:
  52. with self.assertWarns(DeprecationWarning):
  53. from super_gradients.training.models import BasicBlock, BasicResNetBlock # noqa
  54. assert isinstance(BasicBlock(1, 1, 1), BasicResNetBlock)
  55. except ImportError:
  56. self.fail("ImportError raised unexpectedly for BasicBlock")
  57. def test_deprecated_max_targets(self):
  58. with self.assertWarns(DeprecationWarning):
  59. DetectionTargetsFormatTransform(max_targets=1)
  60. DetectionHorizontalFlip(prob=1.0, max_targets=1)
  61. DetectionPaddedRescale(input_dim=(2, 2), max_targets=1)
  62. def test_moved_Bottleneck_import(self):
  63. try:
  64. with self.assertWarns(DeprecationWarning):
  65. from super_gradients.training.models import Bottleneck as OldBottleneck # noqa
  66. from super_gradients.training.models.classification_models.resnet import Bottleneck
  67. assert isinstance(OldBottleneck(1, 1, 1), Bottleneck)
  68. except ImportError:
  69. self.fail("ImportError raised unexpectedly for Bottleneck")
  70. def test_deprecated_optimizers_dict(self):
  71. try:
  72. with self.assertWarns(DeprecationWarning):
  73. from super_gradients.training.utils.optimizers.all_optimizers import OPTIMIZERS # noqa
  74. except ImportError:
  75. self.fail("ImportError raised unexpectedly for OPTIMIZERS")
  76. def test_deprecated_HpmStruct_import(self):
  77. try:
  78. with self.assertWarns(DeprecationWarning):
  79. from super_gradients.training.models import HpmStruct as OldHpmStruct
  80. from super_gradients.training.utils import HpmStruct
  81. assert isinstance(OldHpmStruct(a=1), HpmStruct)
  82. except ImportError:
  83. self.fail("ImportError raised unexpectedly for HpmStruct")
  84. def test_deprecated_criterion_params(self):
  85. with self.assertWarns(DeprecationWarning):
  86. warnings.simplefilter("always")
  87. train_params = {
  88. "max_epochs": 4,
  89. "lr_decay_factor": 0.1,
  90. "lr_updates": [4],
  91. "lr_mode": "StepLRScheduler",
  92. "lr_warmup_epochs": 0,
  93. "initial_lr": 0.1,
  94. "loss": "CrossEntropyLoss",
  95. "optimizer": "SGD",
  96. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  97. "loss": "CrossEntropyLoss",
  98. "train_metrics_list": [],
  99. "valid_metrics_list": [],
  100. "metric_to_watch": "Accuracy",
  101. "greater_metric_to_watch_is_better": True,
  102. }
  103. train_params = TrainingParams(**train_params)
  104. train_params.override(criterion_params={"ignore_index": 0})
  105. def test_train_with_deprecated_criterion_params(self):
  106. trainer = Trainer("test_train_with_precise_bn_explicit_size")
  107. net = ResNet18(num_classes=5, arch_params={})
  108. train_params = {
  109. "max_epochs": 2,
  110. "lr_updates": [1],
  111. "lr_decay_factor": 0.1,
  112. "lr_mode": "StepLRScheduler",
  113. "lr_warmup_epochs": 0,
  114. "initial_lr": 0.1,
  115. "loss": "CrossEntropyLoss",
  116. "criterion_params": {"ignore_index": -300},
  117. "optimizer": "SGD",
  118. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  119. "train_metrics_list": [Accuracy(), Top5()],
  120. "valid_metrics_list": [Accuracy(), Top5()],
  121. "metric_to_watch": "Accuracy",
  122. "greater_metric_to_watch_is_better": True,
  123. }
  124. trainer.train(
  125. model=net,
  126. training_params=train_params,
  127. train_loader=classification_test_dataloader(batch_size=10),
  128. valid_loader=classification_test_dataloader(batch_size=10),
  129. )
  130. self.assertEqual(trainer.criterion.ignore_index, -300)
  131. if __name__ == "__main__":
  132. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...