#443 Feature/SG 344 - ActivationsTypeFactory

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-344-activations-factory
@@ -5,6 +5,15 @@ from super_gradients.common.factories.metrics_factory import MetricsFactory
 from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
 from super_gradients.common.factories.samplers_factory import SamplersFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 
-
-__all__ = ["CallbacksFactory", "ListFactory", "LossesFactory", "MetricsFactory", "OptimizersTypeFactory", "SamplersFactory", "TransformsFactory"]
+__all__ = [
+    "CallbacksFactory",
+    "ListFactory",
+    "LossesFactory",
+    "MetricsFactory",
+    "OptimizersTypeFactory",
+    "SamplersFactory",
+    "TransformsFactory",
+    "ActivationsTypeFactory",
+]
from typing import Union, Type, Mapping

from super_gradients.common.factories.base_factory import AbstractFactory
from super_gradients.training.utils.activations_utils import get_builtin_activation_type


class ActivationsTypeFactory(AbstractFactory):
    """
    A special factory for resolving the type of an activation function by name.

    Unlike the other factories, it does not instantiate a module: it returns the type,
    which is instantiated later at the call site.
    """

    def get(self, conf: Union[str, Mapping]) -> Type:
        """
        Get an activation type.

        :param conf: A configuration:
            - a string is assumed to be a type name (not necessarily the real class name,
              but a name registered in the factory);
            - a mapping must have a single key holding the type name, with a dict of
              constructor arguments as its value.
        :raises RuntimeError: if conf is neither a string nor a mapping.
        """
        if isinstance(conf, str):
            return get_builtin_activation_type(conf)
        if isinstance(conf, Mapping):
            (type_name,) = list(conf.keys())
            type_args = conf[type_name]
            return get_builtin_activation_type(type_name, **type_args)
        raise RuntimeError(f"Unsupported conf param type {type(conf)}")
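
For context, a minimal usage sketch of the factory above (not part of the PR; it assumes ActivationsTypeFactory is importable from super_gradients.common.factories, as registered in the first diff):

from torch import nn

from super_gradients.common.factories import ActivationsTypeFactory

factory = ActivationsTypeFactory()

# A plain string resolves straight to the activation class.
relu_cls = factory.get("relu")
assert relu_cls is nn.ReLU

# A single-key mapping resolves to the class with constructor kwargs
# pre-bound via functools.partial.
leaky_cls = factory.get({"leaky_relu": {"negative_slope": 0.05}})
assert isinstance(leaky_cls(), nn.LeakyReLU)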
from functools import partial
from typing import Type, Union, Dict

import torch
from torch import nn


def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type:
    """
    Returns an activation class by its name from the torch.nn namespace. This function supports all
    modules available in torch.nn, as well as their lower-case aliases.
    On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU).

    >>> act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
    >>> act = act_cls()

    Args:
        activation: Activation function name (e.g. ReLU). If None, nn.Identity is returned.
        **kwargs: Extra arguments to pass to the constructor during instantiation (e.g. inplace=True).

    Returns:
        Type of the activation function, ready to be instantiated.
    """
    if activation is None:
        activation_cls = nn.Identity
    else:
        lowercase_aliases: Dict[str, str] = dict((k.lower(), k) for k in torch.nn.__dict__.keys())

        # Register additional aliases
        lowercase_aliases["leaky_relu"] = "LeakyReLU"  # LeakyReLU in snake_case
        lowercase_aliases["swish"] = "SiLU"  # Swish, which is equivalent to SiLU
        lowercase_aliases["none"] = "Identity"

        if activation in lowercase_aliases:
            activation = lowercase_aliases[activation]

        if activation not in torch.nn.__dict__:
            raise KeyError(f"Requested activation function {activation} is not known")

        activation_cls = torch.nn.__dict__[activation]

    # Bind constructor kwargs so the returned type can be instantiated with no extra arguments
    if len(kwargs):
        activation_cls = partial(activation_cls, **kwargs)

    return activation_cls
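
A quick illustration of the alias handling described in the docstring (a sketch, not part of the PR):

from torch import nn

from super_gradients.training.utils.activations_utils import get_builtin_activation_type

# Canonical names, their lower-case forms, and the extra aliases all resolve
# to the corresponding torch.nn classes.
assert get_builtin_activation_type("ReLU") is nn.ReLU
assert get_builtin_activation_type("relu") is nn.ReLU
assert get_builtin_activation_type("swish") is nn.SiLU

# None maps to nn.Identity; constructor kwargs are bound with functools.partial.
assert get_builtin_activation_type(None) is nn.Identity
hardtanh_cls = get_builtin_activation_type("hardtanh", min_val=-2.0, max_val=2.0)
assert isinstance(hardtanh_cls(), nn.Hardtanh)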
@@ -3,9 +3,12 @@ import unittest
 import torch
 
 from super_gradients import Trainer
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories import ActivationsTypeFactory
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
+from torch import nn
 
 
 class FactoriesTest(unittest.TestCase):
@@ -36,6 +39,16 @@ class FactoriesTest(unittest.TestCase):
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)
 
+    def test_activations_factory(self):
+        class DummyModel(nn.Module):
+            @resolve_param("activation_in_head", ActivationsTypeFactory())
+            def __init__(self, activation_in_head):
+                super().__init__()
+                self.activation_in_head = activation_in_head()
+
+        model = DummyModel(activation_in_head="leaky_relu")
+        self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
+
 
 if __name__ == '__main__':
     unittest.main()
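
The new test exercises the string form; for completeness, a hedged sketch of the mapping form with the same DummyModel (hypothetical usage, not part of this PR, assuming resolve_param passes the mapping through to ActivationsTypeFactory.get as shown above):

# The single key names the activation; the nested dict carries constructor kwargs.
model = DummyModel(activation_in_head={"leaky_relu": {"negative_slope": 0.05}})
assert isinstance(model.activation_in_head, nn.LeakyReLU)
assert model.activation_in_head.negative_slope == 0.05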