|
@@ -1,5 +1,5 @@
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-from typing import Tuple, Type, Optional
|
|
|
|
|
|
+from typing import Tuple, Type, Optional, Union
|
|
|
|
|
|
import hydra
|
|
import hydra
|
|
import torch
|
|
import torch
|
|
@@ -79,7 +79,7 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
|
|
|
|
|
|
def instantiate_model(
|
|
def instantiate_model(
|
|
model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None, download_required_code: bool = True
|
|
model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None, download_required_code: bool = True
|
|
-) -> torch.nn.Module:
|
|
|
|
|
|
+) -> Union[SgModule, torch.nn.Module]:
|
|
"""
|
|
"""
|
|
Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
|
|
Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
|
|
module manipulation (i.e head replacement).
|
|
module manipulation (i.e head replacement).
|
|
@@ -147,7 +147,7 @@ def get(
|
|
load_backbone: bool = False,
|
|
load_backbone: bool = False,
|
|
download_required_code: bool = True,
|
|
download_required_code: bool = True,
|
|
checkpoint_num_classes: int = None,
|
|
checkpoint_num_classes: int = None,
|
|
-) -> SgModule:
|
|
|
|
|
|
+) -> Union[SgModule, torch.nn.Module]:
|
|
"""
|
|
"""
|
|
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
|
|
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
|
|
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
|
|
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
|