@@ -1,6 +1,8 @@
-from typing import Optional
+from typing import Optional, Tuple
+from typing import Type

import hydra
+import torch

from super_gradients.common import StrictLoad
from super_gradients.common.plugins.deci_client import DeciClient
@@ -16,63 +18,93 @@ from super_gradients.training.utils.checkpoint_utils import (
    load_pretrained_weights_local,
)
from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names

logger = get_logger(__name__)


-def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule:
+def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
+    """
+    Get the corresponding architecture class.
+
+    :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
+    :param pretrained_weights: Describes the dataset of the pretrained weights (for example "imagenet")
+
+    :return:
+        - architecture_cls: Class of the model
+        - arch_params: Might be updated if loading from remote deci lab
+        - pretrained_weights: Might be updated if loading from remote deci lab
+        - is_remote: True if loading from remote deci lab
+    """
+    is_remote = False
+    if not isinstance(model_name, str):
+        raise ValueError("Parameter model_name is expected to be a string.")
+    elif model_name not in ARCHITECTURES.keys():
+        logger.info(f'Required model {model_name} not found in local SuperGradients. Trying to load a model from remote deci lab')
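+        # Fall back to the remote Deci lab: fetch the architecture hyper parameters and matching pretrained weights by model name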
+        deci_client = DeciClient()
+        _arch_params = deci_client.get_model_arch_params(model_name)
+        if _arch_params is None:
+            raise ValueError("Unsupported model name " + str(model_name) + ", see docs or all_architectures.py for supported nets.")
+        _arch_params = hydra.utils.instantiate(_arch_params)
+        _arch_params = HpmStruct(**_arch_params)
+        _arch_params.override(**arch_params.to_dict())
+        model_name, arch_params, is_remote = _arch_params["model_name"], _arch_params, True
+        pretrained_weights = deci_client.get_model_weights(model_name)
+    return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote
+
+
+def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None) -> torch.nn.Module:
    """
    Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
    module manipulation (i.e head replacement).

-    :param name: Defines the model's architecture from models/ALL_ARCHITECTURES
-    :param arch_params: Architecture's parameters passed to models c'tor.
-    :param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent")
-
-    :return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
+    :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
+    :param num_classes: Number of classes (defines the net's structure).
+        If None is given, will try to derive from pretrained_weight's corresponding dataset.
+    :param pretrained_weights: Describes the dataset of the pretrained weights (for example "imagenet")

+    :return: Instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
    """
+    if arch_params is None:
+        arch_params = {}
+    arch_params = core_utils.HpmStruct(**arch_params)

-    if pretrained_weights is not None:
-        if hasattr(arch_params, "num_classes"):
-            num_classes_new_head = arch_params.num_classes
-        else:
-            num_classes_new_head = PRETRAINED_NUM_CLASSES[pretrained_weights]
+    architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights)

-        arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
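+    # Architectures that are not SgModule subclasses are plain torch Modules: they are instantiated directly from kwargs, with no num_classes handling or head replacement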
+    if not issubclass(architecture_cls, SgModule):
+        net = architecture_cls(**arch_params.to_dict(include_schema=False))
+    else:
+        if core_utils.get_param(arch_params, "num_classes"):
+            logger.warning("Passing num_classes through arch_params is deprecated and will be removed in the next version. "
+                           "Pass num_classes explicitly to models.get")
+            num_classes = arch_params.num_classes

-    remote_model = False
-    if isinstance(name, str) and name in ARCHITECTURES.keys():
-        architecture_cls = ARCHITECTURES[name]
-        net = architecture_cls(arch_params=arch_params)
-    elif isinstance(name, str):
-        logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
-        deci_client = DeciClient()
-        _arch_params = deci_client.get_model_arch_params(name)

+        if num_classes is not None:
+            arch_params.override(num_classes=num_classes)

-        if _arch_params is not None:
-            _arch_params = hydra.utils.instantiate(_arch_params)
-            base_name = _arch_params["model_name"]
-            _arch_params = HpmStruct(**_arch_params)
-            architecture_cls = ARCHITECTURES[base_name]
-            _arch_params.override(**arch_params.to_dict())

+        if pretrained_weights is None and num_classes is None:
+            raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

-            net = architecture_cls(arch_params=_arch_params)
-            remote_model = True
-        else:
-            raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
-    else:
-        raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
-    if pretrained_weights:
-        if remote_model:
-            weights_path = deci_client.get_model_weights(name)
-            load_pretrained_weights_local(net, name, weights_path)
-        else:
-            load_pretrained_weights(net, name, pretrained_weights)
-        if num_classes_new_head != arch_params.num_classes:
-            net.replace_head(new_num_classes=num_classes_new_head)
-            arch_params.num_classes = num_classes_new_head

+        if pretrained_weights:
+            num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
+            arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
+
+        # Most of the SG models work with a single param named "arch_params" of type HpmStruct, but a few take **kwargs instead
+        if "arch_params" not in get_callable_param_names(architecture_cls):
+            net = architecture_cls(**arch_params.to_dict(include_schema=False))
+        else:
+            net = architecture_cls(arch_params=arch_params)
+
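+        # Load the pretrained weights, then replace the head if the requested num_classes differs from the pretraining dataset's class count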
+        if pretrained_weights:
+            if is_remote:
+                load_pretrained_weights_local(net, model_name, pretrained_weights)
+            else:
+                load_pretrained_weights(net, model_name, pretrained_weights)
+            if num_classes_new_head != arch_params.num_classes:
+                net.replace_head(new_num_classes=num_classes_new_head)
+                arch_params.num_classes = num_classes_new_head
    return net

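For context, a minimal usage sketch of the refactored flow above (not part of the patch; it assumes the public super_gradients.training.models entry point and an architecture name present in ARCHITECTURES, e.g. "resnet18"):

    from super_gradients.training import models

    # Explicit num_classes, no pretrained weights: the net is built for 10 classes
    net = models.get(model_name="resnet18", num_classes=10)

    # With pretrained weights, the net is first built with the pretraining dataset's class count,
    # the weights are loaded, and the head is then replaced to match num_classes=10
    net = models.get(model_name="resnet18", num_classes=10, pretrained_weights="imagenet")
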
@@ -80,38 +112,21 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int =
        strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
        pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
    """
-    :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
-    :param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from
-                        pretrained_weight's corresponding dataset.
-    :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
-
-    :param strict_load: See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
-                        (default=NO_KEY_MATCHING to suport SG trained checkpoints)
-    :param load_backbone: loads the provided checkpoint to model.backbone instead of model.
-    :param checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
-                            (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                            load the checkpoint.
-    :param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent").
+    :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
+    :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
+    :param num_classes: Number of classes (defines the net's structure).
+        If None is given, will try to derive from pretrained_weight's corresponding dataset.
+    :param strict_load: See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
+        (default=NO_KEY_MATCHING to support SG trained checkpoints)
+    :param checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative (ie: path/to/checkpoint.pth).
+        If provided, will automatically attempt to load the checkpoint.
+    :param pretrained_weights: Describes the dataset of the pretrained weights (for example "imagenet").
+    :param load_backbone: Load the provided checkpoint to model.backbone instead of model.

    NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
-
    """
-    if arch_params is None:
-        arch_params = {}
-
-    if arch_params.get("num_classes") is not None:
-        logger.warning("Passing num_classes through arch_params is dperecated and will be removed in the next version. "
-                       "Pass num_classes explicitly to models.get")
-        num_classes = num_classes or arch_params.get("num_classes")
-
-    if pretrained_weights is None and num_classes is None:
-        raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

-    if num_classes is not None:
-        arch_params["num_classes"] = num_classes
-
-    arch_params = core_utils.HpmStruct(**arch_params)
-    net = instantiate_model(model_name, arch_params, pretrained_weights)
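+    # arch_params defaults, num_classes resolution and pretrained-weights loading are handled inside instantiate_model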
+    net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights)

    if load_backbone and not checkpoint_path:
        raise ValueError("Please set checkpoint_path when load_backbone=True")