Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#360 get training_params

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-267_get_training_params_from_recipe
@@ -24,6 +24,7 @@ from super_gradients.training.datasets.segmentation_datasets import CityscapesDa
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.training.utils.distributed_training_utils import wait_for_the_master, get_local_rank
 from super_gradients.training.utils.distributed_training_utils import wait_for_the_master, get_local_rank
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training.utils.utils import override_default_params_without_nones
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -83,7 +84,7 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
         default_dataloader_params["sampler"] = {"DistributedSampler": {}}
         default_dataloader_params["sampler"] = {"DistributedSampler": {}}
         default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
         default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
 
 
-    dataloader_params = _override_default_params_without_nones(dataloader_params, default_dataloader_params)
+    dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
     if get_param(dataloader_params, "batch_sampler"):
     if get_param(dataloader_params, "batch_sampler"):
         sampler = dataloader_params.pop("sampler")
         sampler = dataloader_params.pop("sampler")
         batch_size = dataloader_params.pop("batch_size")
         batch_size = dataloader_params.pop("batch_size")
@@ -92,13 +93,6 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
     return dataloader_params
     return dataloader_params
 
 
 
 
-def _override_default_params_without_nones(params, default_params):
-    for key, val in default_params.items():
-        if key not in params.keys() or params[key] is None:
-            params[key] = val
-    return params
-
-
 def _instantiate_sampler(dataset, dataloader_params):
 def _instantiate_sampler(dataset, dataloader_params):
     sampler_name = list(dataloader_params["sampler"].keys())[0]
     sampler_name = list(dataloader_params["sampler"].keys())[0]
     dataloader_params["sampler"][sampler_name]["dataset"] = dataset
     dataloader_params["sampler"][sampler_name]["dataset"] = dataset
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
  1. from .training_hyperparams import cifar10_resnet_train_params, cityscapes_ddrnet_train_params, \
  2. cityscapes_regseg48_train_params, \
  3. cityscapes_stdc_base_train_params, \
  4. cityscapes_stdc_seg50_train_params, \
  5. cityscapes_stdc_seg75_train_params, \
  6. coco2017_ssd_lite_mobilenet_v2_train_params, \
  7. coco2017_yolox_train_params, \
  8. coco_segmentation_shelfnet_lw_train_params, \
  9. imagenet_efficientnet_train_params, \
  10. imagenet_mobilenetv2_train_params, \
  11. imagenet_mobilenetv3_base_train_params, \
  12. imagenet_mobilenetv3_large_train_params, \
  13. imagenet_mobilenetv3_small_train_params, \
  14. imagenet_regnetY_train_params, \
  15. imagenet_repvgg_train_params, \
  16. imagenet_resnet50_train_params, \
  17. imagenet_resnet50_kd_train_params, \
  18. imagenet_vit_base_train_params, \
  19. imagenet_vit_large_train_params, \
  20. get
  21. __all__ = ["cifar10_resnet_train_params", "cityscapes_ddrnet_train_params",
  22. "cityscapes_regseg48_train_params",
  23. "cityscapes_stdc_base_train_params",
  24. "cityscapes_stdc_seg50_train_params",
  25. "cityscapes_stdc_seg75_train_params",
  26. "coco2017_ssd_lite_mobilenet_v2_train_params",
  27. "coco2017_yolox_train_params",
  28. "coco_segmentation_shelfnet_lw_train_params",
  29. "imagenet_efficientnet_train_params",
  30. "imagenet_mobilenetv2_train_params",
  31. "imagenet_mobilenetv3_base_train_params",
  32. "imagenet_mobilenetv3_large_train_params",
  33. "imagenet_mobilenetv3_small_train_params",
  34. "imagenet_regnetY_train_params",
  35. "imagenet_repvgg_train_params",
  36. "imagenet_resnet50_train_params",
  37. "imagenet_resnet50_kd_train_params",
  38. "imagenet_vit_base_train_params",
  39. "imagenet_vit_large_train_params",
  40. "get"]
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. import hydra
  2. import pkg_resources
  3. from hydra import compose, initialize_config_dir
  4. from hydra.core.global_hydra import GlobalHydra
  5. from super_gradients.training.utils.utils import override_default_params_without_nones
  6. from super_gradients.common.abstractions.abstract_logger import get_logger
  7. from typing import Dict
  8. logger = get_logger(__name__)
  9. def get(config_name, overriding_params: Dict = None) -> Dict:
  10. """
  11. Class for creating training hyper parameters dictionary, taking defaults from yaml
  12. files in src/super_gradients/recipes.
  13. :param overriding_params: Dict, dictionary like object containing entries to override in the recipe's training
  14. hyper parameters dictionary.
  15. :param config_name: yaml config filename in recipes (for example coco2017_yolox).
  16. """
  17. if overriding_params is None:
  18. overriding_params = dict()
  19. GlobalHydra.instance().clear()
  20. with initialize_config_dir(config_dir=pkg_resources.resource_filename("super_gradients.recipes", "")):
  21. cfg = compose(config_name=config_name)
  22. cfg = hydra.utils.instantiate(cfg)
  23. training_params = cfg.training_hyperparams
  24. training_params = override_default_params_without_nones(overriding_params, training_params)
  25. return training_params
  26. def cifar10_resnet_train_params(overriding_params: Dict = None):
  27. return get("cifar10_resnet", overriding_params)
  28. def cityscapes_ddrnet_train_params(overriding_params: Dict = None):
  29. return get("cityscapes_ddrnet", overriding_params)
  30. def cityscapes_regseg48_train_params(overriding_params: Dict = None):
  31. return get("cityscapes_regseg48", overriding_params)
  32. def cityscapes_stdc_base_train_params(overriding_params: Dict = None):
  33. return get("cityscapes_stdc_base", overriding_params)
  34. def cityscapes_stdc_seg50_train_params(overriding_params: Dict = None):
  35. return get("cityscapes_stdc_seg50", overriding_params)
  36. def cityscapes_stdc_seg75_train_params(overriding_params: Dict = None):
  37. return get("cityscapes_stdc_seg75", overriding_params)
  38. def coco2017_ssd_lite_mobilenet_v2_train_params(overriding_params: Dict = None):
  39. return get("coco2017_ssd_lite_mobilenet_v2", overriding_params)
  40. def coco2017_yolox_train_params(overriding_params: Dict = None):
  41. return get("coco2017_yolox", overriding_params)
  42. def coco_segmentation_shelfnet_lw_train_params(overriding_params: Dict = None):
  43. return get("coco_segmentation_shelfnet_lw", overriding_params)
  44. def imagenet_efficientnet_train_params(overriding_params: Dict = None):
  45. return get("imagenet_efficientnet", overriding_params)
  46. def imagenet_mobilenetv2_train_params(overriding_params: Dict = None):
  47. return get("imagenet_mobilenetv2", overriding_params)
  48. def imagenet_mobilenetv3_base_train_params(overriding_params: Dict = None):
  49. return get("imagenet_mobilenetv3_base", overriding_params)
  50. def imagenet_mobilenetv3_large_train_params(overriding_params: Dict = None):
  51. return get("imagenet_mobilenetv3_large", overriding_params)
  52. def imagenet_mobilenetv3_small_train_params(overriding_params: Dict = None):
  53. return get("imagenet_mobilenetv3_small", overriding_params)
  54. def imagenet_regnetY_train_params(overriding_params: Dict = None):
  55. return get("imagenet_regnetY", overriding_params)
  56. def imagenet_repvgg_train_params(overriding_params: Dict = None):
  57. return get("imagenet_repvgg", overriding_params)
  58. def imagenet_resnet50_train_params(overriding_params: Dict = None):
  59. return get("imagenet_resnet50", overriding_params)
  60. def imagenet_resnet50_kd_train_params(overriding_params: Dict = None):
  61. return get("imagenet_resnet50_kd", overriding_params)
  62. def imagenet_vit_base_train_params(overriding_params: Dict = None):
  63. return get("imagenet_vit_base", overriding_params)
  64. def imagenet_vit_large_train_params(overriding_params: Dict = None):
  65. return get("imagenet_vit_large", overriding_params)
Discard
@@ -2,7 +2,7 @@ import math
 import time
 import time
 from functools import lru_cache
 from functools import lru_cache
 from pathlib import Path
 from pathlib import Path
-from typing import Mapping, Optional, Tuple, Union, List
+from typing import Mapping, Optional, Tuple, Union, List, Dict
 from zipfile import ZipFile
 from zipfile import ZipFile
 import os
 import os
 from jsonschema import validate
 from jsonschema import validate
@@ -444,3 +444,16 @@ def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
     """Get the image size of an image at a specific path"""
     """Get the image size of an image at a specific path"""
     with open(img_path, 'rb') as f:
     with open(img_path, 'rb') as f:
         return exif_size(Image.open(f))
         return exif_size(Image.open(f))
+
+
+def override_default_params_without_nones(params: Dict, default_params: Dict) -> Dict:
+    """
+    Helper method for overriding default dictionary's entries excluding entries with None values.
+    :param params: dict, output dictionary which will take the defaults.
+    :param default_params: dict, dictionary for the defaults.
+    :return: dict, params after manipulation,
+    """
+    for key, val in default_params.items():
+        if key not in params.keys() or params[key] is None:
+            params[key] = val
+    return params
Discard
@@ -4,7 +4,7 @@ import unittest
 from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
 from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
 from tests.unit_tests import ZeroWdForBnBiasTest, SaveCkptListUnitTest, TestAverageMeter, \
 from tests.unit_tests import ZeroWdForBnBiasTest, SaveCkptListUnitTest, TestAverageMeter, \
     TestRepVgg, TestWithoutTrainTest, OhemLossTest, EarlyStopTest, SegmentationTransformsTest, \
     TestRepVgg, TestWithoutTrainTest, OhemLossTest, EarlyStopTest, SegmentationTransformsTest, \
-    TestConvBnRelu, FactoriesTest, InitializeWithDataloadersTest
+    TestConvBnRelu, FactoriesTest, InitializeWithDataloadersTest, TrainingParamsTest
 from tests.end_to_end_tests import TestTrainer
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
@@ -76,6 +76,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainingParamsTest))
 
 
     def _add_modules_to_end_to_end_tests_suite(self):
     def _add_modules_to_end_to_end_tests_suite(self):
         """
         """
Discard
@@ -16,10 +16,10 @@ from tests.unit_tests.segmentation_transforms_test import SegmentationTransforms
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.conv_bn_relu_test import TestConvBnRelu
 from tests.unit_tests.conv_bn_relu_test import TestConvBnRelu
 from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
 from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
-
+from tests.unit_tests.training_params_factory_test import TrainingParamsTest
 
 
 __all__ = ['TestDatasetInterface', 'ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
 __all__ = ['TestDatasetInterface', 'ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
            'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
            'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
            'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
            'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
            'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
            'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
-           'FactoriesTest', 'InitializeWithDataloadersTest']
+           'FactoriesTest', 'InitializeWithDataloadersTest', 'TrainingParamsTest']
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
  1. import unittest
  2. from super_gradients.training import training_hyperparams
  3. class TrainingParamsTest(unittest.TestCase):
  4. def test_get_train_params(self):
  5. train_params = training_hyperparams.coco2017_yolox_train_params()
  6. self.assertTrue(train_params["loss"] == "yolox_loss")
  7. self.assertTrue(train_params["max_epochs"] == 300)
  8. def test_get_train_params_with_overrides(self):
  9. train_params = training_hyperparams.coco2017_yolox_train_params(overriding_params={"max_epochs": 5})
  10. self.assertTrue(train_params["loss"] == "yolox_loss")
  11. self.assertTrue(train_params["max_epochs"] == 5)
  12. if __name__ == '__main__':
  13. unittest.main()
Discard