#405 Feature/sg 292 update register model

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-292-update_register_model
@@ -82,7 +82,61 @@ python main.py --config-name=my_recipe.yaml
 
 
 ### B. Model
-Coming soon
+
+```python
+import omegaconf
+import hydra
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from super_gradients import Trainer, init_trainer
+from super_gradients.common.registry import register_model
+
+
+@register_model('my_conv_net')  # will be registered as "my_conv_net"
+class MyConvNet(nn.Module):
+    def __init__(self, num_classes: int):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, num_classes)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+@hydra.main(config_path="recipes")
+def main(cfg: omegaconf.DictConfig) -> None:
+    Trainer.train_from_config(cfg)
+
+
+init_trainer()
+main()
+```
+
+*recipes/my_recipe.yaml* 
+```yaml
+... # Other recipe params
+
+architecture: my_conv_net
+```
+
+*Launch the script*
+```bash
+python main.py --config-name=my_recipe.yaml
+```
+
 
 ### C. Loss
 
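Once `my_conv_net` is registered, it can also be instantiated by name outside of a recipe. A minimal usage sketch, assuming the standard `super_gradients.training.models` import path; since `MyConvNet` is a plain `nn.Module` rather than an `SgModule`, its constructor arguments go through `arch_params`:

```python
from super_gradients.training import models

# "my_conv_net" resolves through the registry populated by @register_model;
# arch_params entries are forwarded to MyConvNet.__init__ as keyword arguments.
model = models.get('my_conv_net', arch_params={'num_classes': 10})
```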
@@ -1,6 +1,8 @@
-from typing import Optional
+from typing import Optional, Tuple
+from typing import Type
 
 import hydra
+import torch
 
 from super_gradients.common import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient
@@ -16,63 +18,93 @@ from super_gradients.training.utils.checkpoint_utils import (
     load_pretrained_weights_local,
 )
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
 
 logger = get_logger(__name__)
 
 
-def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule:
+def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
+    """
+    Get the corresponding architecture class.
+
+    :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params:         Architecture hyperparameters, e.g. block, num_blocks, etc.
+    :param pretrained_weights:  Describes the dataset of the pretrained weights (for example "imagenet")
+
+    :return:
+        - architecture_cls:     Class of the model
+        - arch_params:          Might be updated if loading from remote deci lab
+        - pretrained_weights:   Might be updated if loading from remote deci lab
+        - is_remote:            True if loading from remote deci lab
+    """
+    is_remote = False
+    if not isinstance(model_name, str):
+        raise ValueError("Parameter model_name is expected to be a string.")
+    elif model_name not in ARCHITECTURES.keys():
+        logger.info(f'Required model {model_name} not found in local SuperGradients. Trying to load a model from remote deci lab')
+        deci_client = DeciClient()
+        _arch_params = deci_client.get_model_arch_params(model_name)
+        if _arch_params is None:
+            raise ValueError("Unsupported model name " + str(model_name) + ", see docs or all_architectures.py for supported nets.")
+        _arch_params = hydra.utils.instantiate(_arch_params)
+        _arch_params = HpmStruct(**_arch_params)
+        _arch_params.override(**arch_params.to_dict())
+        model_name, arch_params, is_remote = _arch_params.model_name, _arch_params, True
+        pretrained_weights = deci_client.get_model_weights(model_name)
+    return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote
+
+
+def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None) -> torch.nn.Module:
     """
     """
     Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
     Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
         module manipulation (i.e head replacement).
         module manipulation (i.e head replacement).
 
 
-    :param name: Defines the model's architecture from models/ALL_ARCHITECTURES
-    :param arch_params: Architecture's parameters passed to models c'tor.
-    :param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent")
-
-    :return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
+    :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params:         Architecture hyperparameters, e.g. block, num_blocks, etc.
+    :param num_classes:         Number of classes (defines the net's structure).
+                                    If None is given, will try to derive it from the pretrained_weights' corresponding dataset.
+    :param pretrained_weights:  Describes the dataset of the pretrained weights (for example "imagenet")
 
+    :return:                    Instantiated model, i.e. torch.nn.Module
     """
     """
+    if arch_params is None:
+        arch_params = {}
+    arch_params = core_utils.HpmStruct(**arch_params)
 
-    if pretrained_weights is not None:
-        if hasattr(arch_params, "num_classes"):
-            num_classes_new_head = arch_params.num_classes
-        else:
-            num_classes_new_head = PRETRAINED_NUM_CLASSES[pretrained_weights]
+    architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights)
 
-        arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
+    if not issubclass(architecture_cls, SgModule):
+        net = architecture_cls(**arch_params.to_dict(include_schema=False))
+    else:
+        if core_utils.get_param(arch_params, "num_classes"):
+            logger.warning("Passing num_classes through arch_params is deprecated and will be removed in the next version. "
+                           "Pass num_classes explicitly to models.get")
+            num_classes = arch_params.num_classes
 
-    remote_model = False
-    if isinstance(name, str) and name in ARCHITECTURES.keys():
-        architecture_cls = ARCHITECTURES[name]
-        net = architecture_cls(arch_params=arch_params)
-    elif isinstance(name, str):
-        logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
-        deci_client = DeciClient()
-        _arch_params = deci_client.get_model_arch_params(name)
+        if num_classes is not None:
+            arch_params.override(num_classes=num_classes)
 
-        if _arch_params is not None:
-            _arch_params = hydra.utils.instantiate(_arch_params)
-            base_name = _arch_params["model_name"]
-            _arch_params = HpmStruct(**_arch_params)
-            architecture_cls = ARCHITECTURES[base_name]
-            _arch_params.override(**arch_params.to_dict())
+        if pretrained_weights is None and num_classes is None:
+            raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")
 
-            net = architecture_cls(arch_params=_arch_params)
-            remote_model = True
-        else:
-            raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
-    else:
-        raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
-    if pretrained_weights:
-        if remote_model:
-            weights_path = deci_client.get_model_weights(name)
-            load_pretrained_weights_local(net, name, weights_path)
-        else:
-            load_pretrained_weights(net, name, pretrained_weights)
-        if num_classes_new_head != arch_params.num_classes:
-            net.replace_head(new_num_classes=num_classes_new_head)
-            arch_params.num_classes = num_classes_new_head
+        if pretrained_weights:
+            num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
+            arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
 
+        # Most SG models take a single param named "arch_params" of type HpmStruct, but a few take **kwargs instead
+        if "arch_params" not in get_callable_param_names(architecture_cls):
+            net = architecture_cls(**arch_params.to_dict(include_schema=False))
+        else:
+            net = architecture_cls(arch_params=arch_params)
+
+        if pretrained_weights:
+            if is_remote:
+                load_pretrained_weights_local(net, model_name, pretrained_weights)
+            else:
+                load_pretrained_weights(net, model_name, pretrained_weights)
+            if num_classes_new_head != arch_params.num_classes:
+                net.replace_head(new_num_classes=num_classes_new_head)
+                arch_params.num_classes = num_classes_new_head
     return net
 
 
@@ -80,38 +112,21 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int =
         strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
         pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
     """
-    :param model_name:               Defines the model's architecture from models/ALL_ARCHITECTURES
-    :param num_classes:        Number of classes (defines the net's structure). If None is given, will try to derrive from
-                                pretrained_weight's corresponding dataset.
-    :param arch_params:                Architecture hyper parameters. e.g.: block, num_blocks, etc.
-
-    :param strict_load:                See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
-     (default=NO_KEY_MATCHING to suport SG trained checkpoints)
-    :param load_backbone:              loads the provided checkpoint to model.backbone instead of model.
-    :param checkpoint_path:   The path to the external checkpoint to be loaded. Can be absolute or relative
-                                       (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                                       load the checkpoint.
-    :param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent").
+    :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params:         Architecture hyperparameters, e.g. block, num_blocks, etc.
+    :param num_classes:         Number of classes (defines the net's structure).
+                                    If None is given, will try to derive it from the pretrained_weights' corresponding dataset.
+    :param strict_load:         See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
+                                    (default=NO_KEY_MATCHING to support SG trained checkpoints)
+    :param checkpoint_path:     The path to the external checkpoint to be loaded. Can be absolute or relative (i.e. path/to/checkpoint.pth).
+                                    If provided, will automatically attempt to load the checkpoint.
+    :param pretrained_weights:  Describes the dataset of the pretrained weights (for example "imagenet").
+    :param load_backbone:       Loads the provided checkpoint into model.backbone instead of model.
 
     NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
     """
     """
-    if arch_params is None:
-        arch_params = {}
-
-    if arch_params.get("num_classes") is not None:
-        logger.warning("Passing num_classes through arch_params is dperecated and will be removed in the next version. "
-                       "Pass num_classes explicitly to models.get")
-    num_classes = num_classes or arch_params.get("num_classes")
-
-    if pretrained_weights is None and num_classes is None:
-        raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")
 
-    if num_classes is not None:
-        arch_params["num_classes"] = num_classes
-
-    arch_params = core_utils.HpmStruct(**arch_params)
-    net = instantiate_model(model_name, arch_params, pretrained_weights)
+    net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights)
 
     if load_backbone and not checkpoint_path:
         raise ValueError("Please set checkpoint_path when load_backbone=True")
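With this change, `num_classes` becomes an explicit argument of `models.get`, and passing it through `arch_params` is deprecated. A minimal sketch of the new call pattern, assuming the usual `super_gradients.training.models` import path (the model and dataset names are illustrative):

```python
from super_gradients.training import models

# Pretrained weights load with their original head size (PRETRAINED_NUM_CLASSES);
# since num_classes=10 differs, replace_head() swaps in a new 10-class head.
model = models.get('resnet18', num_classes=10, pretrained_weights="imagenet")

# Still accepted, but logs a deprecation warning:
legacy_model = models.get('resnet18', arch_params={'num_classes': 10})
```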
@@ -5,7 +5,7 @@ import time
 from dataclasses import dataclass
 from multiprocessing import Process
 from pathlib import Path
-from typing import Tuple, Union, Dict, List, Sequence
+from typing import Tuple, Union, Dict, Sequence
 import random
 
 import inspect
@@ -337,16 +337,24 @@ def log_uncaught_exceptions(logger):
     sys.excepthook = handle_exception
 
 
-def parse_args(cfg, arg_names: Union[List[str], callable]) -> dict:
+def parse_args(cfg, arg_names: Union[Sequence[str], callable]) -> dict:
     """
     """
     parse args from a config.
     parse args from a config.
     unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature
     unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature
     """
     """
     if not isinstance(arg_names, Sequence):
-        arg_names = list(inspect.signature(arg_names).parameters.keys())
+        arg_names = get_callable_param_names(arg_names)
 
     kwargs_dict = {}
     for arg_name in arg_names:
         if hasattr(cfg, arg_name) and getattr(cfg, arg_name) is not None:
             kwargs_dict[arg_name] = getattr(cfg, arg_name)
     return kwargs_dict
+
+
+def get_callable_param_names(obj: callable) -> Tuple[str, ...]:
+    """Get the param names of a given callable (function, class, ...)
+    :param obj: Object to inspect
+    :return: Param names of that object
+    """
+    return tuple(inspect.signature(obj).parameters)
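`get_callable_param_names` just reads the callable's signature, which is what lets `instantiate_model` detect whether an architecture class accepts `arch_params`. A quick sketch with a hypothetical callable:

```python
from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names


def build_model(arch_params, num_classes=10, **kwargs):  # hypothetical example
    ...


# Parameter names come back in signature order, including **kwargs by name.
print(get_callable_param_names(build_model))  # ('arch_params', 'num_classes', 'kwargs')
print("arch_params" in get_callable_param_names(build_model))  # True
```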
@@ -47,8 +47,15 @@ class HpmStruct:
     def override(self, **entries):
         recursive_override(self.__dict__, entries)
 
-    def to_dict(self):
-        return self.__dict__
+    def to_dict(self, include_schema=True) -> dict:
+        """Convert this HpmStruct instance into a dict.
+        :param include_schema: If True, also return the field "schema"
+        :return: Dict representation of this HpmStruct instance.
+        """
+        out_dict = self.__dict__.copy()
+        if not include_schema:
+            out_dict.pop("schema")
+        return out_dict
 
     def validate(self):
         """
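A quick sketch of the new flag, assuming `HpmStruct` is importable from `super_gradients.training.utils` and always carries an internal `schema` attribute (which the unconditional `pop("schema")` above implies):

```python
from super_gradients.training.utils import HpmStruct  # import path assumed

params = HpmStruct(num_classes=10, dropout=0.1)

# Default keeps the internal "schema" field (assumed to default to None).
print(params.to_dict())                      # {'num_classes': 10, 'dropout': 0.1, 'schema': None}

# include_schema=False strips it, so the dict can be splatted into a model constructor.
print(params.to_dict(include_schema=False))  # {'num_classes': 10, 'dropout': 0.1}
```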
@@ -39,7 +39,7 @@ class TestWithoutTrainTest(unittest.TestCase):
 
     @staticmethod
     def get_segmentation_trainer(name=''):
-        shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
+        shelfnet_lw_arch_params = {"num_classes": 5}
         trainer = Trainer(name)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         return trainer, model