Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#316 Feature/sg 118 refactor build model

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-118_refactor_build_model
68 changed files with 860 additions and 965 deletions
  1. 1
    1
      src/super_gradients/common/sg_loggers/abstract_sg_logger.py
  2. 13
    14
      src/super_gradients/examples/deci_platform_logger_example/deci_platform_logger_example.py
  3. 1
    1
      src/super_gradients/examples/legacy/cifar_resnet/cifar_example.py
  4. 1
    1
      src/super_gradients/examples/legacy/imagenet_efficientnet/efficientnet_example.py
  5. 1
    1
      src/super_gradients/examples/legacy/imagenet_mobilenetv3/mobilenetv3_imagenet_example.py
  6. 1
    1
      src/super_gradients/examples/legacy/imagenet_regnetY800/regnetY800_example.py
  7. 1
    1
      src/super_gradients/examples/legacy/imagenet_repvgg/imagenet_repvgg_example.py
  8. 1
    1
      src/super_gradients/examples/legacy/imagenet_resnet/imagenet_resnet_example.py
  9. 2
    1
      src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml
  10. 4
    3
      src/super_gradients/recipes/cifar10_resnet.yaml
  11. 2
    3
      src/super_gradients/recipes/cityscapes_ddrnet.yaml
  12. 3
    2
      src/super_gradients/recipes/cityscapes_regseg48.yaml
  13. 2
    4
      src/super_gradients/recipes/cityscapes_stdc_base.yaml
  14. 2
    1
      src/super_gradients/recipes/coco2017_ssd_lite_mobilenet_v2.yaml
  15. 3
    2
      src/super_gradients/recipes/coco2017_yolox.yaml
  16. 5
    4
      src/super_gradients/recipes/coco_segmentation_shelfnet_lw.yaml
  17. 3
    3
      src/super_gradients/recipes/imagenet_efficientnet.yaml
  18. 3
    3
      src/super_gradients/recipes/imagenet_mobilenetv2.yaml
  19. 3
    3
      src/super_gradients/recipes/imagenet_mobilenetv3.yaml
  20. 3
    2
      src/super_gradients/recipes/imagenet_regnetY.yaml
  21. 3
    3
      src/super_gradients/recipes/imagenet_repvgg.yaml
  22. 3
    3
      src/super_gradients/recipes/imagenet_resnet50.yaml
  23. 20
    4
      src/super_gradients/recipes/imagenet_resnet50_kd.yaml
  24. 4
    2
      src/super_gradients/recipes/imagenet_vit_base.yaml
  25. 3
    3
      src/super_gradients/recipes/test_resnet.yaml
  26. 1
    0
      src/super_gradients/recipes/training_hyperparams/default_train_params.yaml
  27. 54
    14
      src/super_gradients/training/kd_trainer/kd_trainer.py
  28. 1
    0
      src/super_gradients/training/models/__init__.py
  29. 92
    0
      src/super_gradients/training/models/model_factory.py
  30. 4
    1
      src/super_gradients/training/params.py
  31. 58
    19
      src/super_gradients/training/sg_trainer/sg_trainer.py
  32. 3
    3
      src/super_gradients/training/utils/callbacks.py
  33. 1
    1
      src/super_gradients/training/utils/checkpoint_utils.py
  34. 0
    2
      tests/deci_core_unit_test_suite_runner.py
  35. 4
    2
      tests/end_to_end_tests/cifar10_trainer_test.py
  36. 3
    3
      tests/end_to_end_tests/external_dataset_e2e.py
  37. 16
    36
      tests/end_to_end_tests/trainer_test.py
  38. 6
    4
      tests/integration_tests/conversion_callback_test.py
  39. 3
    4
      tests/integration_tests/deci_lab_export_test.py
  40. 3
    3
      tests/integration_tests/ema_train_integration_test.py
  41. 12
    9
      tests/integration_tests/lr_test.py
  42. 184
    180
      tests/integration_tests/pretrained_models_test.py
  43. 12
    12
      tests/integration_tests/qat_integration_test.py
  44. 0
    43
      tests/test-data-interface.py
  45. 1
    2
      tests/unit_tests/__init__.py
  46. 3
    3
      tests/unit_tests/dataset_statistics_test.py
  47. 3
    3
      tests/unit_tests/detection_utils_test.py
  48. 7
    13
      tests/unit_tests/early_stop_test.py
  49. 1
    1
      tests/unit_tests/factories_test.py
  50. 3
    3
      tests/unit_tests/forward_pass_prep_fn_test.py
  51. 18
    15
      tests/unit_tests/initialize_with_dataloaders_test.py
  52. 37
    88
      tests/unit_tests/kd_ema_test.py
  53. 95
    202
      tests/unit_tests/kd_trainer_test.py
  54. 0
    85
      tests/unit_tests/load_checkpoint_from_direct_path_test.py
  55. 19
    7
      tests/unit_tests/load_ema_ckpt_test.py
  56. 1
    2
      tests/unit_tests/lr_cooldown_test.py
  57. 7
    11
      tests/unit_tests/lr_warmup_test.py
  58. 1
    2
      tests/unit_tests/phase_context_test.py
  59. 3
    3
      tests/unit_tests/phase_delegates_test.py
  60. 7
    8
      tests/unit_tests/pretrained_models_unit_test.py
  61. 3
    3
      tests/unit_tests/save_ckpt_test.py
  62. 43
    48
      tests/unit_tests/strictload_enum_test.py
  63. 23
    19
      tests/unit_tests/test_without_train_test.py
  64. 1
    2
      tests/unit_tests/train_logging_test.py
  65. 32
    34
      tests/unit_tests/train_with_intialized_param_args_test.py
  66. 3
    4
      tests/unit_tests/train_with_precise_bn_test.py
  67. 1
    2
      tests/unit_tests/update_param_groups_unit_test.py
  68. 3
    2
      tests/unit_tests/vit_unit_test.py
@@ -21,7 +21,7 @@ class AbstractSGLogger(ABC):
     @abstractmethod
     @abstractmethod
     def add(self, tag: str, obj: Any, global_step: int = None):
     def add(self, tag: str, obj: Any, global_step: int = None):
         """
         """
-        A generic function for adding any type of data to the SGLogger. By default, this function is not called by the SGModel, BaseSGLogger
+        A generic function for adding any type of data to the SGLogger. By default, this function is not called by the Trainer, BaseSGLogger
         does nothing with this type of data. But if you need to pass a data type which is not supported by any of the following abstract methods, use this
         does nothing with this type of data. But if you need to pass a data type which is not supported by any of the following abstract methods, use this
         method.
         method.
         """
         """
Discard
@@ -11,17 +11,16 @@ dataset = Cifar10DatasetInterface(dataset_params={"batch_size": 256, "val_batch_
 trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 trainer.build_model("resnet18")
 trainer.build_model("resnet18")
 
 
-trainer.train(
-    training_params={"max_epochs": 20,
-                     "lr_updates": [5, 10, 15],
-                     "lr_decay_factor": 0.1,
-                     "lr_mode": "step",
-                     "initial_lr": 0.1,
-                     "loss": "cross_entropy",
-                     "optimizer": "SGD",
-                     "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
-                     "train_metrics_list": [Accuracy(), Top5()],
-                     "valid_metrics_list": [Accuracy(), Top5()],
-                     "metric_to_watch": "Accuracy",
-                     "greater_metric_to_watch_is_better": True,
-                     "sg_logger": "deci_platform_sg_logger"})
+trainer.train(training_params={"max_epochs": 20,
+                               "lr_updates": [5, 10, 15],
+                               "lr_decay_factor": 0.1,
+                               "lr_mode": "step",
+                               "initial_lr": 0.1,
+                               "loss": "cross_entropy",
+                               "optimizer": "SGD",
+                               "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
+                               "train_metrics_list": [Accuracy(), Top5()],
+                               "valid_metrics_list": [Accuracy(), Top5()],
+                               "metric_to_watch": "Accuracy",
+                               "greater_metric_to_watch_is_better": True,
+                               "sg_logger": "deci_platform_sg_logger"})
Discard
@@ -19,7 +19,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -23,7 +23,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -22,7 +22,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -21,7 +21,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -16,7 +16,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -34,7 +34,7 @@ def train(cfg: DictConfig) -> None:
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
     cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
 
     # TRAIN
     # TRAIN
-    cfg.trainer .train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -4,4 +4,5 @@ external_checkpoint_path: # checkpoint path that is not located in super_gradien
 source_ckpt_folder_name: # dirname for checkpoint loading
 source_ckpt_folder_name: # dirname for checkpoint loading
 strict_load: # key matching strictness for loading checkpoint's weights
 strict_load: # key matching strictness for loading checkpoint's weights
   _target_: super_gradients.training.sg_trainer.StrictLoad
   _target_: super_gradients.training.sg_trainer.StrictLoad
-  value: True
+  value: True
+pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
Discard
@@ -18,9 +18,10 @@ dataset_interface:
 
 
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: $(resume}
+
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
 ckpt_root_dir:
 ckpt_root_dir:
Discard
@@ -100,9 +100,8 @@ arch_params:
 load_checkpoint: False
 load_checkpoint: False
 checkpoint_params:
 checkpoint_params:
   load_checkpoint: ${load_checkpoint}
   load_checkpoint: ${load_checkpoint}
-  external_checkpoint_path:
+  checkpoint_path:
   load_backbone: True
   load_backbone: True
-  load_weights_only: True
   strict_load: no_key_matching
   strict_load: no_key_matching
 
 
 architecture: ddrnet_23
 architecture: ddrnet_23
@@ -112,7 +111,7 @@ model_checkpoints_location: local
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 multi_gpu:
 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'
   value: 'DDP'
 
 
 
 
Discard
@@ -85,8 +85,10 @@ dataset_interface:
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
 
 
+resume: False
 
 
 training_hyperparams:
 training_hyperparams:
+  resume: ${resume}
   max_epochs: 800
   max_epochs: 800
   lr_mode: poly
   lr_mode: poly
   initial_lr: 0.02   # for effective batch_size=16
   initial_lr: 0.02   # for effective batch_size=16
@@ -121,8 +123,7 @@ training_hyperparams:
 
 
   _convert_: all
   _convert_: all
 load_checkpoint: False
 load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
 ckpt_root_dir:
 ckpt_root_dir:
Discard
@@ -46,10 +46,8 @@ arch_params:
   use_aux_heads: True
   use_aux_heads: True
   sync_bn: True
   sync_bn: True
 
 
-load_checkpoint: False
 checkpoint_params:
 checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
-  external_checkpoint_path:
+  checkpoint_path:
   load_backbone: True
   load_backbone: True
   load_weights_only: True
   load_weights_only: True
   strict_load: no_key_matching
   strict_load: no_key_matching
@@ -61,6 +59,6 @@ model_checkpoints_location: local
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 multi_gpu:
 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'
   value: 'DDP'
 
 
Discard
@@ -52,8 +52,9 @@ arch_params:
 dataset_interface:
 dataset_interface:
   coco2017_detection:
   coco2017_detection:
     dataset_params: ${dataset_params}
     dataset_params: ${dataset_params}
-
+resume: False
 training_hyperparams:
 training_hyperparams:
+  resume: ${resume}
   criterion_params:
   criterion_params:
     alpha: 1.0
     alpha: 1.0
     dboxes: ${dboxes}
     dboxes: ${dboxes}
Discard
@@ -27,8 +27,9 @@ data_loader_num_workers: 8
 model_checkpoints_location: local
 model_checkpoints_location: local
 
 
 load_checkpoint: False
 load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 architecture: yolox_s
 architecture: yolox_s
 
 
Discard
@@ -21,13 +21,14 @@ dataset_interface:
 
 
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
-load_checkpoint: True
 checkpoint_params:
 checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
   strict_load: True
   strict_load: True
-  load_weights_only: True
   load_backbone: True
   load_backbone: True
-  source_ckpt_folder_name: resnet_34
+  checkpoint_path:
+
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: coco_segmentation_21_subclass_shelfnet34
 experiment_name: coco_segmentation_21_subclass_shelfnet34
 
 
Discard
@@ -31,9 +31,9 @@ dataset_interface:
 
 
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: efficientnet_b0_imagenet
 experiment_name: efficientnet_b0_imagenet
 
 
Discard
@@ -32,9 +32,9 @@ dataset_interface:
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: mobileNetv2_training
 experiment_name: mobileNetv2_training
 
 
Discard
@@ -15,9 +15,9 @@ dataset_interface:
 data_loader_num_workers: 16
 data_loader_num_workers: 16
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: mobileNetv3_large_training
 experiment_name: mobileNetv3_large_training
 
 
Discard
@@ -50,8 +50,9 @@ data_loader_num_workers: 8
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
 load_checkpoint: False
 load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: regnetY800_imagenet
 experiment_name: regnetY800_imagenet
 
 
Discard
@@ -29,9 +29,9 @@ data_loader_num_workers: 8
 
 
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: repvgg_a0_imagenet_reproduce_fix
 experiment_name: repvgg_a0_imagenet_reproduce_fix
 
 
Discard
@@ -40,9 +40,9 @@ dataset_interface:
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: resnet50_imagenet
 experiment_name: resnet50_imagenet
 
 
Discard
@@ -17,7 +17,10 @@ defaults:
   - arch_params: default_arch_params
   - arch_params: default_arch_params
   - checkpoint_params: default_checkpoint_params
   - checkpoint_params: default_checkpoint_params
 
 
+
+resume: False
 training_hyperparams:
 training_hyperparams:
+  resume: ${resume}
   loss: kd_loss
   loss: kd_loss
   criterion_params:
   criterion_params:
     distillation_loss_coeff: 0.8
     distillation_loss_coeff: 0.8
@@ -40,6 +43,22 @@ teacher_arch_params:
   image_size: [224, 224]
   image_size: [224, 224]
   patch_size: [16, 16]
   patch_size: [16, 16]
 
 
+teacher_checkpoint_params:
+  load_backbone: False # whether to load only backbone part of checkpoint
+  checkpoint_path: # checkpoint path that is not located in super_gradients/checkpoints
+  strict_load: # key matching strictness for loading checkpoint's weights
+    _target_: super_gradients.training.sg_trainer.StrictLoad
+    value: True
+  pretrained_weights: imagenet
+
+student_checkpoint_params:
+  load_backbone: False # whether to load only backbone part of checkpoint
+  checkpoint_path: # checkpoint path that is not located in super_gradients/checkpoints
+  strict_load: # key matching strictness for loading checkpoint's weights
+    _target_: super_gradients.training.sg_trainer.StrictLoad
+    value: True
+  pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
+
 dataset_params:
 dataset_params:
   batch_size: 192
   batch_size: 192
   val_batch_size: 256
   val_batch_size: 256
@@ -61,10 +80,7 @@ dataset_interface:
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
-  teacher_pretrained_weights: imagenet
+
 
 
 run_teacher_on_eval: True
 run_teacher_on_eval: True
 
 
Discard
@@ -42,8 +42,10 @@ dataset_interface:
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
 model_checkpoints_location: local
 model_checkpoints_location: local
-load_checkpoint: True
-load_weights_only: True
+
+resume: False
+training_hyperparams:
+  resume: ${resume}
 
 
 experiment_name: vit_base_imagenet1k
 experiment_name: vit_base_imagenet1k
 
 
Discard
@@ -12,9 +12,9 @@ dataset_interface:
 
 
 data_loader_num_workers: 1
 data_loader_num_workers: 1
 
 
-load_checkpoint: False
-checkpoint_params:
-  load_checkpoint: ${load_checkpoint}
+resume: False
+training_hyperparams:
+  resume: $(resume}
 
 
 experiment_name: test
 experiment_name: test
 
 
Discard
@@ -1,3 +1,4 @@
+resume: False # whether to continue training from ckpt with the same experiment name.
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
Discard
@@ -4,14 +4,15 @@ from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 from torch.utils.data import DataLoader
 
 
 from super_gradients.common import MultiGPUMode
 from super_gradients.common import MultiGPUMode
+from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.sg_trainer import Trainer
 from typing import Union, List, Any
 from typing import Union, List, Any
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.training import utils as core_utils
+from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import get_param
+from super_gradients.training.utils import get_param, HpmStruct
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
     load_checkpoint_to_model
     load_checkpoint_to_model
 from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
 from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
@@ -27,11 +28,14 @@ logger = get_logger(__name__)
 
 
 class KDTrainer(Trainer):
 class KDTrainer(Trainer):
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None, train_loader: DataLoader = None,
+                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True,
+                 ckpt_name: str = 'ckpt_latest.pth',
+                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
+                 train_loader: DataLoader = None,
                  valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
                  valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
 
 
-        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint, ckpt_name, post_prediction_callback,
+        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint,
+                         ckpt_name, post_prediction_callback,
                          ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
                          ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
         self.student_architecture = None
         self.student_architecture = None
         self.teacher_architecture = None
         self.teacher_architecture = None
@@ -56,15 +60,22 @@ class KDTrainer(Trainer):
         # CONNECT THE DATASET INTERFACE WITH DECI MODEL
         # CONNECT THE DATASET INTERFACE WITH DECI MODEL
         trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
         trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
 
-        # BUILD NETWORK
-        trainer.build_model(student_architecture=cfg.student_architecture,
-                            teacher_architecture=cfg.teacher_architecture,
-                            arch_params=cfg.arch_params, student_arch_params=cfg.student_arch_params,
-                            teacher_arch_params=cfg.teacher_arch_params,
-                            checkpoint_params=cfg.checkpoint_params, run_teacher_on_eval=cfg.run_teacher_on_eval)
+        student = models.get(cfg.student_architecture, arch_params=cfg.student_arch_params,
+                             strict_load=cfg.student_checkpoint_params.strict_load,
+                             pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
+                             checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
+                             load_backbone=cfg.student_checkpoint_params.load_backbone)
+
+        teacher = models.get(cfg.teacher_architecture, arch_params=cfg.teacher_arch_params,
+                             strict_load=cfg.teacher_checkpoint_params.strict_load,
+                             pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
+                             checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
+                             load_backbone=cfg.teacher_checkpoint_params.load_backbone)
 
 
         # TRAIN
         # TRAIN
-        trainer.train(training_params=cfg.training_hyperparams)
+        trainer.train(training_params=cfg.training_hyperparams, student=student, teacher=teacher,
+                      kd_architecture=cfg.architecture, kd_arch_params=cfg.arch_params,
+                      run_teacher_on_eval=cfg.run_teacher_on_eval)
 
 
     def build_model(self,
     def build_model(self,
                     # noqa: C901 - too complex
                     # noqa: C901 - too complex
@@ -167,7 +178,8 @@ class KDTrainer(Trainer):
         load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")
         load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")
 
 
         # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
         # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
-        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(teacher_architecture, torch.nn.Module)):
+        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(
+                teacher_architecture, torch.nn.Module)):
             raise TeacherKnowledgeException()
             raise TeacherKnowledgeException()
 
 
     def _validate_num_classes(self, student_arch_params, teacher_arch_params):
     def _validate_num_classes(self, student_arch_params, teacher_arch_params):
@@ -229,6 +241,9 @@ class KDTrainer(Trainer):
 
 
         run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)
         run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)
 
 
+        return self._instantiate_kd_net(arch_params, architecture, run_teacher_on_eval, student, teacher)
+
+    def _instantiate_kd_net(self, arch_params, architecture, run_teacher_on_eval, student, teacher):
         if isinstance(architecture, str):
         if isinstance(architecture, str):
             architecture_cls = KD_ARCHITECTURES[architecture]
             architecture_cls = KD_ARCHITECTURES[architecture]
             net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
             net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
@@ -238,7 +253,6 @@ class KDTrainer(Trainer):
                                run_teacher_on_eval=run_teacher_on_eval)
                                run_teacher_on_eval=run_teacher_on_eval)
         else:
         else:
             net = architecture
             net = architecture
-
         return net
         return net
 
 
     def _load_checkpoint_to_model(self):
     def _load_checkpoint_to_model(self):
@@ -313,3 +327,29 @@ class KDTrainer(Trainer):
 
 
         state["net"] = best_net.state_dict()
         state["net"] = best_net.state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
+
+    def train(self, model: KDModule = None, training_params: dict = dict(), student: SgModule = None,
+              teacher: torch.nn.Module = None, kd_architecture: Union[KDModule.__class__, str] = 'kd_module',
+              kd_arch_params: dict = dict(), run_teacher_on_eval=False, *args, **kwargs):
+        """
+        Trains the student network (wrapped in KDModule network).
+
+        :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
+            student and teacher (default=None)
+        :param training_params: dict, Same as in Trainer.train()
+        :param student: SgModule - the student trainer
+        :param teacher: torch.nn.Module- the teacher trainer
+        :param kd_architecture: KDModule architecture to use, currently only 'kd_module' is supported (default='kd_module').
+        :param kd_arch_params: architecture params to pas to kd_architecture constructor.
+        :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
+        """
+        kd_net = self.net or model
+        if kd_net is None:
+            if student is None or teacher is None:
+                raise ValueError("Must pass student and teacher models or net (KDModule).")
+            kd_net = self._instantiate_kd_net(arch_params=HpmStruct(**kd_arch_params),
+                                              architecture=kd_architecture,
+                                              run_teacher_on_eval=run_teacher_on_eval,
+                                              student=student,
+                                              teacher=teacher)
+        super(KDTrainer, self).train(model=kd_net, training_params=training_params)
Discard
@@ -19,3 +19,4 @@ from super_gradients.training.models.segmentation_models.shelfnet import *
 from super_gradients.training.models.classification_models.efficientnet import *
 from super_gradients.training.models.classification_models.efficientnet import *
 
 
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
+from super_gradients.training.models.model_factory import get
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. from super_gradients.common import StrictLoad
  2. from super_gradients.training import utils as core_utils
  3. from super_gradients.training.models import SgModule
  4. from super_gradients.training.models.all_architectures import ARCHITECTURES
  5. from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
  6. from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, load_pretrained_weights, \
  7. read_ckpt_state_dict
  8. from super_gradients.common.abstractions.abstract_logger import get_logger
  9. logger = get_logger(__name__)
  10. def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule:
  11. """
  12. Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
  13. module manipulation (i.e head replacement).
  14. :param name: Defines the model's architecture from models/ALL_ARCHITECTURES
  15. :param arch_params: Architecture's parameters passed to models c'tor.
  16. :param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent")
  17. :return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
  18. """
  19. if pretrained_weights is not None:
  20. if hasattr(arch_params, "num_classes"):
  21. num_classes_new_head = arch_params.num_classes
  22. else:
  23. num_classes_new_head = PRETRAINED_NUM_CLASSES[pretrained_weights]
  24. arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
  25. if isinstance(name, str) and name in ARCHITECTURES.keys():
  26. architecture_cls = ARCHITECTURES[name]
  27. net = architecture_cls(arch_params=arch_params)
  28. else:
  29. raise ValueError(
  30. "Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported "
  31. "nets.")
  32. if pretrained_weights:
  33. load_pretrained_weights(net, name, pretrained_weights)
  34. if num_classes_new_head != arch_params.num_classes:
  35. net.replace_head(new_num_classes=num_classes_new_head)
  36. arch_params.num_classes = num_classes_new_head
  37. return net
  38. def get(name: str, arch_params: dict = {}, num_classes: int = None,
  39. strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
  40. pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
  41. """
  42. :param name: Defines the model's architecture from models/ALL_ARCHITECTURES
  43. :param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from
  44. pretrained_weight's corresponding dataset.
  45. :param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
  46. :param strict_load: See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
  47. (default=NO_KEY_MATCHING to suport SG trained checkpoints)
  48. :param load_backbone: loads the provided checkpoint to model.backbone instead of model.
  49. :param checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
  50. (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
  51. load the checkpoint.
  52. :param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent").
  53. NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
  54. """
  55. if arch_params.get("num_classes") is not None:
  56. logger.warning("Passing num_classes through arch_params is dperecated and will be removed in the next version. "
  57. "Pass num_classes explicitly to models.get")
  58. num_classes = num_classes or arch_params.get("num_classes")
  59. if pretrained_weights is None and num_classes is None:
  60. raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")
  61. if num_classes is not None:
  62. arch_params["num_classes"] = num_classes
  63. arch_params = core_utils.HpmStruct(**arch_params)
  64. net = instantiate_model(name, arch_params, pretrained_weights)
  65. if checkpoint_path:
  66. load_ema_as_net = 'ema_net' in read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
  67. _ = load_checkpoint_to_model(ckpt_local_path=checkpoint_path,
  68. load_backbone=load_backbone,
  69. net=net,
  70. strict=strict_load.value if hasattr(strict_load, "value") else strict_load,
  71. load_weights_only=True,
  72. load_ema_as_net=load_ema_as_net)
  73. return net
Discard
@@ -60,7 +60,10 @@ DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                                "calib_data_loader": None,
                                "calib_data_loader": None,
                                "num_calib_batches": 2,
                                "num_calib_batches": 2,
                                "percentile": 99.99
                                "percentile": 99.99
-                           }
+                           },
+                           "resume": False,
+                           "resume_path": None,
+                           "resume_strict_load": False
                            }
                            }
 
 
 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
Discard
@@ -8,6 +8,7 @@ import hydra
 import numpy as np
 import numpy as np
 import pkg_resources
 import pkg_resources
 import torch
 import torch
+from deprecate import deprecated
 from omegaconf import DictConfig
 from omegaconf import DictConfig
 from torch import nn
 from torch import nn
 from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.data import DataLoader, DistributedSampler
@@ -29,7 +30,7 @@ from super_gradients.common.factories.metrics_factory import MetricsFactory
 from super_gradients.common.sg_loggers import SG_LOGGERS
 from super_gradients.common.sg_loggers import SG_LOGGERS
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
-from super_gradients.training import utils as core_utils
+from super_gradients.training import utils as core_utils, models
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils import sg_trainer_utils
@@ -206,10 +207,17 @@ class Trainer:
         trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
         trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
 
         # BUILD NETWORK
         # BUILD NETWORK
-        trainer.build_model(cfg.architecture, arch_params=cfg.arch_params, checkpoint_params=cfg.checkpoint_params)
+        model = models.get(name=cfg.architecture,
+                           num_classes=cfg.arch_params.num_classes,
+                           arch_params=cfg.arch_params,
+                           strict_load=cfg.checkpoint_params.strict_load,
+                           pretrained_weights=cfg.checkpoint_params.pretrained_weights,
+                           checkpoint_path=cfg.checkpoint_params.checkpoint_path,
+                           load_backbone=cfg.checkpoint_params.load_backbone
+                           )
 
 
         # TRAIN
         # TRAIN
-        trainer.train(training_params=cfg.training_hyperparams)
+        trainer.train(model=model, training_params=cfg.training_hyperparams)
 
 
     def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
     def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
         if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
         if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
@@ -225,7 +233,8 @@ class Trainer:
             if not all([isinstance(train_loader.sampler, DistributedSampler),
             if not all([isinstance(train_loader.sampler, DistributedSampler),
                         isinstance(valid_loader.sampler, DistributedSampler),
                         isinstance(valid_loader.sampler, DistributedSampler),
                         test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
                         test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning("DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
+                logger.warning(
+                    "DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
 
 
         self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
         self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
             HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
             HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
@@ -248,6 +257,7 @@ class Trainer:
         self.dataset_params = self.dataset_interface.get_dataset_params()
         self.dataset_params = self.dataset_interface.get_dataset_params()
 
 
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
+    @deprecated(target=None, deprecated_in='2.3.0', remove_in='3.0.0')
     def build_model(self,  # noqa: C901 - too complex
     def build_model(self,  # noqa: C901 - too complex
                     architecture: Union[str, nn.Module],
                     architecture: Union[str, nn.Module],
                     arch_params={}, checkpoint_params={}, *args, **kwargs):
                     arch_params={}, checkpoint_params={}, *args, **kwargs):
@@ -521,8 +531,32 @@ class Trainer:
     def _save_best_checkpoint(self, epoch, state):
     def _save_best_checkpoint(self, epoch, state):
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
 
+    def _prep_net_for_train(self):
+        if self.arch_params is None:
+            default_arch_params = HpmStruct(sync_bn=False)
+            arch_params = getattr(self.net, "arch_params", default_arch_params)
+            self.arch_params = default_arch_params
+            if arch_params is not None:
+                self.arch_params.override(**arch_params.to_dict())
+
+        # TODO: REMOVE THE BELOW LINE (FOR BACKWARD COMPATIBILITY)
+        if self.checkpoint_params is None:
+            self.checkpoint_params = HpmStruct(load_checkpoint=self.training_params.resume)
+
+        self._net_to_device()
+
+        # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
+        self.update_param_groups = hasattr(self.net.module, 'update_param_groups')
+
+        self.checkpoint = {}
+        self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
+        self.load_ema_as_net = False
+        self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
+        self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self._load_checkpoint_to_model()
+
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
-    def train(self, training_params: dict = dict()):  # noqa: C901
+    def train(self, model: nn.Module = None, training_params: dict = dict(), *args, **kwargs):  # noqa: C901
         """
         """
 
 
         train - Trains the Model
         train - Trains the Model
@@ -531,6 +565,8 @@ class Trainer:
           the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with
           the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with
           the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.
           the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.
 
 
+            :param model: torch.nn.Module, model to train. When none is given will attempt to use self.net
+             (SEE BUILD_MODEL DEPRECATION) (default=None).
 
 
             :param training_params:
             :param training_params:
                 - `max_epochs` : int
                 - `max_epochs` : int
@@ -798,14 +834,16 @@ class Trainer:
         """
         """
         global logger
         global logger
 
 
-        if self.net is None:
-            raise Exception('Model', 'No model found')
         if self.dataset_interface is None and self.train_loader is None:
         if self.dataset_interface is None and self.train_loader is None:
             raise Exception('Data', 'No dataset found')
             raise Exception('Data', 'No dataset found')
 
 
         self.training_params = TrainingParams()
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
         self.training_params.override(**training_params)
 
 
+        if self.net is None:
+            self.net = model
+            self._prep_net_for_train()
+
         # SET RANDOM SEED
         # SET RANDOM SEED
         random_seed(is_ddp=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
         random_seed(is_ddp=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
                     device=self.device, seed=self.training_params.seed)
                     device=self.device, seed=self.training_params.seed)
@@ -945,7 +983,8 @@ class Trainer:
             load_opt_params = False
             load_opt_params = False
 
 
         if isinstance(self.training_params.optimizer, str) or \
         if isinstance(self.training_params.optimizer, str) or \
-                (inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)):
+                (inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer,
+                                                                                torch.optim.Optimizer)):
             self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr,
             self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr,
                                              training_params=self.training_params)
                                              training_params=self.training_params)
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
@@ -964,8 +1003,10 @@ class Trainer:
 
 
         self._initialize_mixed_precision(self.training_params.mixed_precision)
         self._initialize_mixed_precision(self.training_params.mixed_precision)
 
 
-        self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler, InfiniteSampler)) or \
-                                      (hasattr(self.train_loader, "batch_sampler") and isinstance(self.train_loader.batch_sampler.sampler, InfiniteSampler))
+        self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler,
+                                                                                            InfiniteSampler)) or \
+                                      (hasattr(self.train_loader, "batch_sampler") and isinstance(
+                                          self.train_loader.batch_sampler.sampler, InfiniteSampler))
 
 
         self.ckpt_best_name = self.training_params.ckpt_best_name
         self.ckpt_best_name = self.training_params.ckpt_best_name
 
 
@@ -986,9 +1027,8 @@ class Trainer:
                                metric_idx_in_results_tuple=self.metric_idx_in_results_tuple,
                                metric_idx_in_results_tuple=self.metric_idx_in_results_tuple,
                                metric_to_watch=self.metric_to_watch,
                                metric_to_watch=self.metric_to_watch,
                                device=self.device,
                                device=self.device,
-                               context_methods=self._get_context_methods(Phase.PRE_TRAINING)
-                               )
-
+                               context_methods=self._get_context_methods(Phase.PRE_TRAINING),
+                               ema_model=self.ema_model)
         self.phase_callback_handler(Phase.PRE_TRAINING, context)
         self.phase_callback_handler(Phase.PRE_TRAINING, context)
 
 
         try:
         try:
@@ -1479,16 +1519,13 @@ class Trainer:
         lr_dict = {lr_titles[i]: lrs[i] for i in range(len(lrs))}
         lr_dict = {lr_titles[i]: lrs[i] for i in range(len(lrs))}
         self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
         self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
 
 
-    def test(self,  # noqa: C901
-             test_loader: torch.utils.data.DataLoader = None,
-             loss: torch.nn.modules.loss._Loss = None,
-             silent_mode: bool = False,
-             test_metrics_list=None,
+    def test(self, model: nn.Module = None, test_loader: torch.utils.data.DataLoader = None,
+             loss: torch.nn.modules.loss._Loss = None, silent_mode: bool = False, test_metrics_list=None,
              loss_logging_items_names=None, metrics_progress_verbose=False, test_phase_callbacks=None,
              loss_logging_items_names=None, metrics_progress_verbose=False, test_phase_callbacks=None,
              use_ema_net=True) -> tuple:
              use_ema_net=True) -> tuple:
         """
         """
         Evaluates the model on given dataloader and metrics.
         Evaluates the model on given dataloader and metrics.
-
+        :param model: model to perfrom test on. When none is given, will try to use self.net (defalut=None).
         :param test_loader: dataloader to perform test on.
         :param test_loader: dataloader to perform test on.
         :param test_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation.
         :param test_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation.
         :param silent_mode: (bool) controls verbosity
         :param silent_mode: (bool) controls verbosity
@@ -1501,6 +1538,8 @@ class Trainer:
          is ran on self.test_loader with self.test_metrics.
          is ran on self.test_loader with self.test_metrics.
         """
         """
 
 
+        self.net = model or self.net
+
         # IN CASE TRAINING WAS PERFROMED BEFORE TEST- MAKE SURE TO TEST THE EMA MODEL (UNLESS SPECIFIED OTHERWISE BY
         # IN CASE TRAINING WAS PERFROMED BEFORE TEST- MAKE SURE TO TEST THE EMA MODEL (UNLESS SPECIFIED OTHERWISE BY
         # use_ema_net)
         # use_ema_net)
 
 
Discard
@@ -20,7 +20,6 @@ logger = get_logger(__name__)
 try:
 try:
     from deci_lab_client.client import DeciPlatformClient
     from deci_lab_client.client import DeciPlatformClient
     from deci_lab_client.models import ModelBenchmarkState
     from deci_lab_client.models import ModelBenchmarkState
-    from deci_lab_client.models.model_metadata import ModelMetadata
 
 
     _imported_deci_lab_failure = None
     _imported_deci_lab_failure = None
 except (ImportError, NameError, ModuleNotFoundError) as import_err:
 except (ImportError, NameError, ModuleNotFoundError) as import_err:
@@ -63,7 +62,7 @@ class PhaseContext:
                  train_loader=None, valid_loader=None,
                  train_loader=None, valid_loader=None,
                  training_params=None, ddp_silent_mode=None, checkpoint_params=None, architecture=None,
                  training_params=None, ddp_silent_mode=None, checkpoint_params=None, architecture=None,
                  arch_params=None, metric_idx_in_results_tuple=None,
                  arch_params=None, metric_idx_in_results_tuple=None,
-                 metric_to_watch=None, valid_metrics=None, context_methods=None):
+                 metric_to_watch=None, valid_metrics=None, context_methods=None, ema_model=None):
         self.epoch = epoch
         self.epoch = epoch
         self.batch_idx = batch_idx
         self.batch_idx = batch_idx
         self.optimizer = optimizer
         self.optimizer = optimizer
@@ -93,6 +92,7 @@ class PhaseContext:
         self.metric_to_watch = metric_to_watch
         self.metric_to_watch = metric_to_watch
         self.valid_metrics = valid_metrics
         self.valid_metrics = valid_metrics
         self.context_methods = context_methods
         self.context_methods = context_methods
+        self.ema_model = ema_model
 
 
     def update_context(self, **kwargs):
     def update_context(self, **kwargs):
         for attr, attr_val in kwargs.items():
         for attr, attr_val in kwargs.items():
@@ -137,7 +137,7 @@ class ModelConversionCheckCallback(PhaseCallback):
         :param atol (default=1e-05)
         :param atol (default=1e-05)
     """
     """
 
 
-    def __init__(self, model_meta_data: ModelMetadata, **kwargs):
+    def __init__(self, model_meta_data, **kwargs):
         super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
         super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
         self.model_meta_data = model_meta_data
         self.model_meta_data = model_meta_data
 
 
Discard
@@ -62,7 +62,7 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str
     @return:
     @return:
     """
     """
     try:
     try:
-        net.load_state_dict(state_dict['net'], strict=strict)
+        net.load_state_dict(state_dict['net'] if 'net' in state_dict.keys() else state_dict, strict=strict)
     except (RuntimeError, ValueError, KeyError) as ex:
     except (RuntimeError, ValueError, KeyError) as ex:
         if strict == 'no_key_matching':
         if strict == 'no_key_matching':
             adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
             adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
Discard
@@ -6,7 +6,6 @@ from tests.unit_tests import ZeroWdForBnBiasTest, SaveCkptListUnitTest, TestAver
     TestRepVgg, TestWithoutTrainTest, OhemLossTest, EarlyStopTest, SegmentationTransformsTest, \
     TestRepVgg, TestWithoutTrainTest, OhemLossTest, EarlyStopTest, SegmentationTransformsTest, \
     TestConvBnRelu, FactoriesTest, InitializeWithDataloadersTest
     TestConvBnRelu, FactoriesTest, InitializeWithDataloadersTest
 from tests.end_to_end_tests import TestTrainer
 from tests.end_to_end_tests import TestTrainer
-from tests.unit_tests.load_checkpoint_from_direct_path_test import LoadCheckpointFromDirectPathTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
@@ -51,7 +50,6 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVgg))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVgg))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestWithoutTrainTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestWithoutTrainTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(StrictLoadEnumTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(StrictLoadEnumTest))
-        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointFromDirectPathTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainWithInitializedObjectsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainWithInitializedObjectsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(RandomEraseTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(RandomEraseTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(OhemLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(OhemLossTest))
Discard
@@ -1,5 +1,7 @@
 import unittest
 import unittest
 
 
+from super_gradients.training import models
+
 import super_gradients
 import super_gradients
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
@@ -12,8 +14,8 @@ class TestCifar10Trainer(unittest.TestCase):
         trainer = Trainer("test", model_checkpoints_location='local')
         trainer = Trainer("test", model_checkpoints_location='local')
         cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
         cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
         trainer.connect_dataset_interface(cifar_10_dataset_interface)
         trainer.connect_dataset_interface(cifar_10_dataset_interface)
-        trainer.build_model("resnet18_cifar", arch_params={'num_classes': 10})
-        trainer.train(training_params={"max_epochs": 1})
+        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
+        trainer.train(model=model, training_params={"max_epochs": 1})
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
@@ -4,7 +4,7 @@ import unittest
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
 import tensorflow.keras as keras
 import tensorflow.keras as keras
-from super_gradients.training import MultiGPUMode
+from super_gradients.training import MultiGPUMode, models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface, \
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface, \
     ImageNetDatasetInterface
     ImageNetDatasetInterface
@@ -127,8 +127,8 @@ class TestExternalDatasetInterface(unittest.TestCase):
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(dataset_interface=self.test_external_dataset_interface,
         trainer.connect_dataset_interface(dataset_interface=self.test_external_dataset_interface,
                                           data_loader_num_workers=8)
                                           data_loader_num_workers=8)
-        trainer.build_model("resnet50", arch_params)
-        trainer.train(training_params=train_params)
+        model = models.get("resnet50", arch_params)
+        trainer.train(model=model, training_params=train_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
@@ -1,6 +1,8 @@
 import shutil
 import shutil
 import unittest
 import unittest
 
 
+from super_gradients.training import models
+
 import super_gradients
 import super_gradients
 import torch
 import torch
 import os
 import os
@@ -37,53 +39,31 @@ class TestTrainer(unittest.TestCase):
     def get_classification_trainer(name=''):
     def get_classification_trainer(name=''):
         trainer = Trainer(name, model_checkpoints_location='local')
         trainer = Trainer(name, model_checkpoints_location='local')
         dataset_params = {"batch_size": 4}
         dataset_params = {"batch_size": 4}
-        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
+        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params, image_size=224)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
-        trainer.build_model("resnet18_cifar")
-        return trainer
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        return trainer, model
 
 
     def test_train(self):
     def test_train(self):
-        trainer = self.get_classification_trainer(self.folder_names[0])
-        trainer.train(training_params=self.training_params)
+        trainer, model = self.get_classification_trainer(self.folder_names[0])
+        trainer.train(model=model, training_params=self.training_params)
 
 
     def test_save_load(self):
     def test_save_load(self):
-        trainer = self.get_classification_trainer(self.folder_names[1])
-        trainer.train(training_params=self.training_params)
-        trainer.build_model("resnet18_cifar", checkpoint_params={'load_checkpoint': True})
-
-    def test_load_only_weights_from_ckpt(self):
-        # Create a checkpoint with 100% accuracy
-        trainer = self.get_classification_trainer(self.folder_names[2])
-        params = self.training_params.copy()
+        trainer, model = self.get_classification_trainer(self.folder_names[1])
+        trainer.train(model=model, training_params=self.training_params)
 
 
-        params['max_epochs'] = 3
-        trainer.train(training_params=params)
-        # Build a model that continues the training
-        trainer = self.get_classification_trainer(self.folder_names[3])
-        trainer.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": False,
-                                                                 "source_ckpt_folder_name": self.folder_names[2]}
-                            )
-        self.assertTrue(trainer.best_metric > -1)
-        self.assertTrue(trainer.start_epoch != 0)
-        # start_epoch is not initialized, adding to max_epochs
-        self.training_params['max_epochs'] += 3
-        trainer.train(training_params=self.training_params)
-        # Build a model that loads the weights and starts from scratch
-        trainer = self.get_classification_trainer(self.folder_names[4])
-        trainer.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": True,
-                                                                 "source_ckpt_folder_name": self.folder_names[2]}
-                            )
-        self.assertTrue(trainer.best_metric == -1)
-        self.assertTrue(trainer.start_epoch == 0)
-        self.training_params['max_epochs'] += 3
-        trainer.train(training_params=self.training_params)
+        resume_training_params = self.training_params.copy()
+        resume_training_params["resume"] = True
+        resume_training_params["max_epochs"] = 2
+        trainer, model = self.get_classification_trainer(self.folder_names[1])
+        trainer.train(model=model, training_params=resume_training_params)
 
 
     def test_checkpoint_content(self):
     def test_checkpoint_content(self):
         """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
         """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
-        trainer = self.get_classification_trainer(self.folder_names[5])
+        trainer, model = self.get_classification_trainer(self.folder_names[5])
         params = self.training_params.copy()
         params = self.training_params.copy()
         params["save_ckpt_epoch_list"] = [1]
         params["save_ckpt_epoch_list"] = [1]
-        trainer.train(training_params=params)
+        trainer.train(model=model, training_params=params)
         ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
         ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
         ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
         ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
         for ckpt_path in ckpt_paths:
         for ckpt_path in ckpt_paths:
Discard
@@ -2,6 +2,8 @@ import unittest
 from enum import Enum
 from enum import Enum
 import re
 import re
 
 
+from super_gradients.training import models
+
 from super_gradients import (
 from super_gradients import (
     Trainer,
     Trainer,
     ClassificationTestDatasetInterface,
     ClassificationTestDatasetInterface,
@@ -72,9 +74,9 @@ class ConversionCallbackTest(unittest.TestCase):
             dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
             dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
 
 
             trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
             trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
-            trainer.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
+            model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             try:
             try:
-                trainer.train(train_params)
+                trainer.train(model=model, training_params=train_params)
             except Exception as e:
             except Exception as e:
                 self.fail(f"Model training didn't succeed due to {e}")
                 self.fail(f"Model training didn't succeed due to {e}")
             else:
             else:
@@ -105,7 +107,7 @@ class ConversionCallbackTest(unittest.TestCase):
             dataset = SegmentationTestDatasetInterface(dataset_params={"batch_size": 10})
             dataset = SegmentationTestDatasetInterface(dataset_params={"batch_size": 10})
             trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
             trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
             trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
             trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
-            trainer.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
+            model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
 
 
             phase_callbacks = [
             phase_callbacks = [
                 ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
                 ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
@@ -129,7 +131,7 @@ class ConversionCallbackTest(unittest.TestCase):
             train_params.update(custom_config)
             train_params.update(custom_config)
 
 
             try:
             try:
-                trainer.train(train_params)
+                trainer.train(model=model, training_params=train_params)
             except Exception as e:
             except Exception as e:
                 self.fail(f"Model training didn't succeed for {architecture} due to {e}")
                 self.fail(f"Model training didn't succeed for {architecture} due to {e}")
             else:
             else:
Discard
@@ -13,9 +13,6 @@ class DeciLabUploadTest(unittest.TestCase):
         self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
         self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
         dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
         dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
         self.trainer.connect_dataset_interface(dataset)
         self.trainer.connect_dataset_interface(dataset)
-        net = ResNet18(num_classes=5, arch_params={})
-        self.optimizer = SGD(params=net.parameters(), lr=0.1)
-        self.trainer.build_model(net)
 
 
     def test_train_with_deci_lab_integration(self):
     def test_train_with_deci_lab_integration(self):
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
@@ -42,6 +39,7 @@ class DeciLabUploadTest(unittest.TestCase):
                                                   model_meta_data=model_meta_data,
                                                   model_meta_data=model_meta_data,
                                                   optimization_request_form=optimization_request_form)
                                                   optimization_request_form=optimization_request_form)
 
 
+        net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": self.optimizer,
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": self.optimizer,
                         "criterion_params": {},
                         "criterion_params": {},
@@ -49,8 +47,9 @@ class DeciLabUploadTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True,
                         "greater_metric_to_watch_is_better": True,
                         "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
                         "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
+        self.optimizer = SGD(params=net.parameters(), lr=0.1)
 
 
-        self.trainer.train(train_params)
+        self.trainer.train(model=net, training_params=train_params)
 
 
         # CLEANUP
         # CLEANUP
 
 
Discard
@@ -1,5 +1,5 @@
 from super_gradients import ClassificationTestDatasetInterface
 from super_gradients import ClassificationTestDatasetInterface
-from super_gradients.training import MultiGPUMode
+from super_gradients.training import MultiGPUMode, models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 import unittest
 import unittest
@@ -27,7 +27,7 @@ class EMAIntegrationTest(unittest.TestCase):
                                device='cpu', multi_gpu=MultiGPUMode.OFF)
                                device='cpu', multi_gpu=MultiGPUMode.OFF)
         dataset_interface = ClassificationTestDatasetInterface({"batch_size": 32})
         dataset_interface = ClassificationTestDatasetInterface({"batch_size": 32})
         self.trainer.connect_dataset_interface(dataset_interface, 8)
         self.trainer.connect_dataset_interface(dataset_interface, 8)
-        self.trainer.build_model("resnet18_cifar")
+        self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls) -> None:
     def tearDownClass(cls) -> None:
@@ -65,7 +65,7 @@ class EMAIntegrationTest(unittest.TestCase):
         self.trainer.test = CallWrapper(self.trainer.test, check_before=before_test)
         self.trainer.test = CallWrapper(self.trainer.test, check_before=before_test)
         self.trainer._train_epoch = CallWrapper(self.trainer._train_epoch, check_before=before_train_epoch)
         self.trainer._train_epoch = CallWrapper(self.trainer._train_epoch, check_before=before_train_epoch)
 
 
-        self.trainer.train(training_params=training_params)
+        self.trainer.train(model=self.model, training_params=training_params)
 
 
         self.assertIsNotNone(self.trainer.ema_model)
         self.assertIsNotNone(self.trainer.ema_model)
 
 
Discard
@@ -1,6 +1,9 @@
 import shutil
 import shutil
 import unittest
 import unittest
 import os
 import os
+
+from super_gradients.training import models
+
 from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 
 
@@ -30,33 +33,33 @@ class LRTest(unittest.TestCase):
         dataset_params = {"batch_size": 4}
         dataset_params = {"batch_size": 4}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
-        trainer.build_model("resnet18_cifar")
-        return trainer
+        model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
+        return trainer, model
 
 
     def test_function_lr(self):
     def test_function_lr(self):
-        trainer = self.get_trainer(self.folder_name)
+        trainer, model = self.get_trainer(self.folder_name)
 
 
         def test_lr_function(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwargs):
         def test_lr_function(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwargs):
             return initial_lr * (1 - ((epoch * iters_per_epoch + iter) / (max_epoch * iters_per_epoch)))
             return initial_lr * (1 - ((epoch * iters_per_epoch + iter) / (max_epoch * iters_per_epoch)))
 
 
         # test if we are able that lr_function supports functions with this structure
         # test if we are able that lr_function supports functions with this structure
         training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
         training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
-        trainer.train(training_params=training_params)
+        trainer.train(model=model, training_params=training_params)
 
 
         # test that we assert lr_function is callable
         # test that we assert lr_function is callable
         training_params = {**self.training_params, "lr_mode": "function"}
         training_params = {**self.training_params, "lr_mode": "function"}
         with self.assertRaises(AssertionError):
         with self.assertRaises(AssertionError):
-            trainer.train(training_params=training_params)
+            trainer.train(model=model, training_params=training_params)
 
 
     def test_cosine_lr(self):
     def test_cosine_lr(self):
-        trainer = self.get_trainer(self.folder_name)
+        trainer, model = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
         training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
-        trainer.train(training_params=training_params)
+        trainer.train(model=model, training_params=training_params)
 
 
     def test_step_lr(self):
     def test_step_lr(self):
-        trainer = self.get_trainer(self.folder_name)
+        trainer, model = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
         training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
-        trainer.train(training_params=training_params)
+        trainer.train(model=model, training_params=training_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
@@ -20,6 +20,7 @@ from super_gradients.training.models.detection_models.yolo_base import YoloPostP
 from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
 from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
 from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
+from super_gradients.training import models
 
 
 
 
 class PretrainedModelsTest(unittest.TestCase):
 class PretrainedModelsTest(unittest.TestCase):
@@ -83,11 +84,13 @@ class PretrainedModelsTest(unittest.TestCase):
                                             'coco_ssd_mobilenet_v1': {'num_classes': 80}}
                                             'coco_ssd_mobilenet_v1': {'num_classes': 80}}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
         self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
 
 
-        from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionMixup, DetectionRandomAffine, \
+        from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionMixup, \
+            DetectionRandomAffine, \
             DetectionTargetsFormatTransform, DetectionPaddedRescale, DetectionHSV, DetectionHorizontalFlip
             DetectionTargetsFormatTransform, DetectionPaddedRescale, DetectionHSV, DetectionHorizontalFlip
 
 
         yolox_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
         yolox_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
-                                  DetectionRandomAffine(degrees=10., translate=0.1, scales=[0.1, 2], shear=2.0, target_size=(640, 640),
+                                  DetectionRandomAffine(degrees=10., translate=0.1, scales=[0.1, 2], shear=2.0,
+                                                        target_size=(640, 640),
                                                         filter_box_candidates=False, wh_thr=0, area_thr=0, ar_thr=0),
                                                         filter_box_candidates=False, wh_thr=0, area_thr=0, ar_thr=0),
                                   DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=1.0, flip_prob=0.5),
                                   DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=1.0, flip_prob=0.5),
                                   DetectionHSV(prob=1.0, hgain=5, sgain=30, vgain=30),
                                   DetectionHSV(prob=1.0, hgain=5, sgain=30, vgain=30),
@@ -95,10 +98,12 @@ class PretrainedModelsTest(unittest.TestCase):
                                   DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                                   DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                                   DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                                   DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
         yolox_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
         yolox_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
-                                DetectionTargetsFormatTransform(max_targets=50, output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
+                                DetectionTargetsFormatTransform(max_targets=50,
+                                                                output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
 
 
         ssd_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
         ssd_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
-                                DetectionRandomAffine(degrees=0., translate=0.1, scales=[0.5, 1.5], shear=.0, target_size=(640, 640),
+                                DetectionRandomAffine(degrees=0., translate=0.1, scales=[0.5, 1.5], shear=.0,
+                                                      target_size=(640, 640),
                                                       filter_box_candidates=True, wh_thr=2, area_thr=0.1, ar_thr=20),
                                                       filter_box_candidates=True, wh_thr=2, area_thr=0.1, ar_thr=20),
                                 DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=0., flip_prob=0.),
                                 DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=0., flip_prob=0.),
                                 DetectionHSV(prob=.0, hgain=5, sgain=30, vgain=30),
                                 DetectionHSV(prob=.0, hgain=5, sgain=30, vgain=30),
@@ -106,7 +111,8 @@ class PretrainedModelsTest(unittest.TestCase):
                                 DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                                 DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                                 DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                                 DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
         ssd_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
         ssd_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
-                              DetectionTargetsFormatTransform(max_targets=50, output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
+                              DetectionTargetsFormatTransform(max_targets=50,
+                                                              output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
 
 
         self.coco_dataset = {
         self.coco_dataset = {
             'yolox': CoCoDetectionDatasetInterface(
             'yolox': CoCoDetectionDatasetInterface(
@@ -129,9 +135,6 @@ class PretrainedModelsTest(unittest.TestCase):
                                 "cache_val_images": False,
                                 "cache_val_images": False,
                                 "with_crowd": True}),
                                 "with_crowd": True}),
 
 
-
-
-
             'ssd_mobilenet': CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
             'ssd_mobilenet': CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
                                                                            "train_subdir": "images/train2017",
                                                                            "train_subdir": "images/train2017",
                                                                            "val_subdir": "images/val2017",
                                                                            "val_subdir": "images/val2017",
@@ -327,9 +330,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
 
 
@@ -337,17 +340,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_resnet34_imagenet(self):
     def test_pretrained_resnet34_imagenet(self):
         trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
 
 
@@ -355,17 +358,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_resnet18_imagenet(self):
     def test_pretrained_resnet18_imagenet(self):
         trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
 
 
@@ -373,17 +376,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
 
 
@@ -391,17 +394,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_regnetY600_imagenet(self):
     def test_pretrained_regnetY600_imagenet(self):
         trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
 
 
@@ -409,17 +412,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_regnetY400_imagenet(self):
     def test_pretrained_regnetY400_imagenet(self):
         trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
 
 
@@ -427,17 +430,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_regnetY200_imagenet(self):
     def test_pretrained_regnetY200_imagenet(self):
         trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
 
 
@@ -445,17 +448,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
 
 
@@ -463,17 +466,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_regseg48_cityscapes(self):
     def test_pretrained_regseg48_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
-        trainer.build_model("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset.val_loader,
+        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
@@ -482,17 +485,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.regseg_transfer_segmentation_train_params)
+        model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.regseg_transfer_segmentation_train_params)
 
 
     def test_pretrained_ddrnet23_cityscapes(self):
     def test_pretrained_ddrnet23_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
-        trainer.build_model("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset.val_loader,
+        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
@@ -501,9 +504,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
-        trainer.build_model("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset.val_loader,
+        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
@@ -512,36 +515,36 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.ddrnet_transfer_segmentation_train_params)
+        model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params)
 
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.ddrnet_transfer_segmentation_train_params)
+        model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params)
 
 
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
         trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
         trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("shelfnet34_lw",
-                            arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
-                            checkpoint_params=self.coco_segmentation_subclass_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_segmentation_dataset.val_loader, test_metrics_list=[IoU(21)],
-                           metrics_progress_verbose=True)[0].cpu().item()
+        model = models.get("shelfnet34_lw",
+                           arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
+                           **self.coco_segmentation_subclass_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_segmentation_dataset.val_loader,
+                           test_metrics_list=[IoU(21)], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
 
 
     def test_pretrained_efficientnet_b0_imagenet(self):
     def test_pretrained_efficientnet_b0_imagenet(self):
         trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
 
 
@@ -549,22 +552,20 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
         trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
-        trainer.build_model("ssd_lite_mobilenet_v2",
-                            arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
+        model = models.get("ssd_lite_mobilenet_v2",
+                           arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
+                           **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
-        res = trainer.test(test_loader=self.coco_dataset['ssd_mobilenet'].val_loader,
-                           test_metrics_list=[
-                               DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)],
-                           metrics_progress_verbose=True)[2]
+        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'].val_loader, test_metrics_list=[
+            DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)], metrics_progress_verbose=True)[2]
         self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
 
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
@@ -574,20 +575,20 @@ class PretrainedModelsTest(unittest.TestCase):
                                           data_loader_num_workers=8)
                                           data_loader_num_workers=8)
         transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
         transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
         transfer_arch_params['num_classes'] = len(self.transfer_detection_dataset.classes)
         transfer_arch_params['num_classes'] = len(self.transfer_detection_dataset.classes)
-        trainer.build_model("ssd_lite_mobilenet_v2",
-                            arch_params=transfer_arch_params,
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_detection_train_params['ssd_lite_mobilenet_v2'])
+        model = models.get("ssd_lite_mobilenet_v2",
+                           arch_params=transfer_arch_params,
+                           **self.coco_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_detection_train_params['ssd_lite_mobilenet_v2'])
 
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
     def test_pretrained_ssd_mobilenet_v1_coco(self):
         trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
         trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
-        trainer.build_model("ssd_mobilenet_v1",
-                            arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
+        model = models.get("ssd_mobilenet_v1",
+                           arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
+                           **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
-        res = trainer.test(test_loader=self.coco_dataset['ssd_mobilenet'].val_loader,
+        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
                                                                num_cls=len(
                                                                num_cls=len(
                                                                    self.coco_dataset['ssd_mobilenet'].coco_classes))],
                                                                    self.coco_dataset['ssd_mobilenet'].coco_classes))],
@@ -598,9 +599,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('yolox_s', model_checkpoints_location='local',
         trainer = Trainer('yolox_s', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
-        trainer.build_model("yolox_s",
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_dataset['yolox'].val_loader,
+        model = models.get("yolox_s",
+                           **self.coco_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                num_cls=80,
                                                                num_cls=80,
                                                                normalize_targets=True)])[2]
                                                                normalize_targets=True)])[2]
@@ -610,9 +611,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('yolox_m', model_checkpoints_location='local',
         trainer = Trainer('yolox_m', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
-        trainer.build_model("yolox_m",
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_dataset['yolox'].val_loader,
+        model = models.get("yolox_m",
+                           **self.coco_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                num_cls=80,
                                                                num_cls=80,
                                                                normalize_targets=True)])[2]
                                                                normalize_targets=True)])[2]
@@ -622,9 +623,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('yolox_l', model_checkpoints_location='local',
         trainer = Trainer('yolox_l', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
-        trainer.build_model("yolox_l",
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_dataset['yolox'].val_loader,
+        model = models.get("yolox_l",
+                           **self.coco_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                num_cls=80,
                                                                num_cls=80,
                                                                normalize_targets=True)])[2]
                                                                normalize_targets=True)])[2]
@@ -634,9 +635,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('yolox_n', model_checkpoints_location='local',
         trainer = Trainer('yolox_n', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
-        trainer.build_model("yolox_n",
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_dataset['yolox'].val_loader,
+        model = models.get("yolox_n",
+                           **self.coco_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                num_cls=80,
                                                                num_cls=80,
                                                                normalize_targets=True)])[2]
                                                                normalize_targets=True)])[2]
@@ -646,9 +647,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('yolox_t', model_checkpoints_location='local',
         trainer = Trainer('yolox_t', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
-        trainer.build_model("yolox_t",
-                            checkpoint_params=self.coco_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.coco_dataset['yolox'].val_loader,
+        model = models.get("yolox_t",
+                           **self.coco_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                num_cls=80,
                                                                num_cls=80,
                                                                normalize_targets=True)])[2]
                                                                normalize_targets=True)])[2]
@@ -659,25 +660,25 @@ class PretrainedModelsTest(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_detection_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_detection_dataset, data_loader_num_workers=8)
-        trainer.build_model("yolox_n", checkpoint_params=self.coco_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_detection_train_params["yolox"])
+        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_detection_train_params["yolox"])
 
 
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_mobilenet_v3_large_imagenet(self):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
         trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
         trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
 
 
@@ -686,17 +687,17 @@ class PretrainedModelsTest(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_mobilenet_v3_small_imagenet(self):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
         trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
         trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
 
 
@@ -705,17 +706,17 @@ class PretrainedModelsTest(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_mobilenet_v2_imagenet(self):
     def test_pretrained_mobilenet_v2_imagenet(self):
         trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
         trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
-        trainer.build_model("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
 
 
@@ -723,9 +724,9 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
-        trainer.build_model("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled50.val_loader,
+        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
@@ -734,17 +735,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.stdc_transfer_segmentation_train_params)
+        model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
 
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
     def test_pretrained_stdc1_seg75_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
-        trainer.build_model("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled75.val_loader,
+        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
@@ -753,17 +754,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.stdc_transfer_segmentation_train_params)
+        model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
 
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
     def test_pretrained_stdc2_seg50_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
-        trainer.build_model("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled50.val_loader,
+        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
@@ -772,17 +773,17 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.stdc_transfer_segmentation_train_params)
+        model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
 
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
     def test_pretrained_stdc2_seg75_cityscapes(self):
         trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
-        trainer.build_model("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled75.val_loader,
+        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                            metrics_progress_verbose=True)[0].cpu().item()
                            metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
@@ -791,56 +792,59 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
         trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
-        trainer.build_model("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
-                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
-        trainer.train(training_params=self.stdc_transfer_segmentation_train_params)
+        model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
+                           **self.cityscapes_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
 
 
     def test_transfer_learning_vit_base_imagenet21k(self):
     def test_transfer_learning_vit_base_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_base',
         trainer = Trainer('imagenet21k_pretrained_vit_base',
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet21k_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet21k_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_transfer_learning_vit_large_imagenet21k(self):
     def test_transfer_learning_vit_large_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_large',
         trainer = Trainer('imagenet21k_pretrained_vit_large',
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet21k_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet21k_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def test_pretrained_vit_base_imagenet(self):
     def test_pretrained_vit_base_imagenet(self):
         trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
-        trainer.build_model("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset_05_mean_std.val_loader, test_metrics_list=[Accuracy()],
-                           metrics_progress_verbose=True)[0].cpu().item()
+        model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = \
+            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std.val_loader,
+                         test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
 
 
     def test_pretrained_vit_large_imagenet(self):
     def test_pretrained_vit_large_imagenet(self):
         trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
-        trainer.build_model("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset_05_mean_std.val_loader, test_metrics_list=[Accuracy()],
-                           metrics_progress_verbose=True)[0].cpu().item()
+        model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = \
+            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std.val_loader,
+                         test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
 
 
     def test_pretrained_beit_base_imagenet(self):
     def test_pretrained_beit_base_imagenet(self):
         trainer = Trainer('imagenet_pretrained_beit_base', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_beit_base', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
-        trainer.build_model("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        res = trainer.test(test_loader=self.imagenet_dataset_05_mean_std.val_loader, test_metrics_list=[Accuracy()],
-                           metrics_progress_verbose=True)[0].cpu().item()
+        model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet_pretrained_ckpt_params)
+        res = \
+            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std.val_loader,
+                         test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
 
 
     def test_transfer_learning_beit_base_imagenet(self):
     def test_transfer_learning_beit_base_imagenet(self):
@@ -848,9 +852,9 @@ class PretrainedModelsTest(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
-        trainer.build_model("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
-                            checkpoint_params=self.imagenet_pretrained_ckpt_params)
-        trainer.train(training_params=self.transfer_classification_train_params)
+        model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
+                           **self.imagenet_pretrained_ckpt_params)
+        trainer.train(model=model, training_params=self.transfer_classification_train_params)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
         if os.path.exists('~/.cache/torch/hub/'):
         if os.path.exists('~/.cache/torch/hub/'):
Discard
@@ -1,7 +1,7 @@
 import unittest
 import unittest
 
 
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
-from super_gradients.training import Trainer, MultiGPUMode
+from super_gradients.training import Trainer, MultiGPUMode, models
 from super_gradients.training.metrics.classification_metrics import Accuracy
 from super_gradients.training.metrics.classification_metrics import Accuracy
 import os
 import os
 from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
 from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
@@ -15,8 +15,8 @@ class QATIntegrationTest(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
-        trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
-        return trainer
+        model = models.get("resnet18", pretrained_weights="imagenet")
+        return trainer, model
 
 
     def _get_train_params(self, qat_params):
     def _get_train_params(self, qat_params):
         train_params = {"max_epochs": 2,
         train_params = {"max_epochs": 2,
@@ -38,7 +38,7 @@ class QATIntegrationTest(unittest.TestCase):
         return train_params
         return train_params
 
 
     def test_qat_from_start(self):
     def test_qat_from_start(self):
-        trainer = self._get_trainer("test_qat_from_start")
+        model, net = self._get_trainer("test_qat_from_start")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -47,10 +47,10 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        trainer.train(training_params=train_params)
+        model.train(model=net, training_params=train_params)
 
 
     def test_qat_transition(self):
     def test_qat_transition(self):
-        trainer = self._get_trainer("test_qat_transition")
+        model, net = self._get_trainer("test_qat_transition")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 1,
             "start_epoch": 1,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -59,10 +59,10 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        trainer.train(training_params=train_params)
+        model.train(model=net, training_params=train_params)
 
 
     def test_qat_from_calibrated_ckpt(self):
     def test_qat_from_calibrated_ckpt(self):
-        trainer = self._get_trainer("generate_calibrated_model")
+        model, net = self._get_trainer("generate_calibrated_model")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -71,11 +71,11 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        trainer.train(training_params=train_params)
+        model.train(model=net, training_params=train_params)
 
 
-        calibrated_model_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
+        calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
 
 
-        trainer = self._get_trainer("test_qat_from_calibrated_ckpt")
+        model, net = self._get_trainer("test_qat_from_calibrated_ckpt")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -85,7 +85,7 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        trainer.train(training_params=train_params)
+        model.train(model=net, training_params=train_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. import torch
  2. from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
  3. from super_gradients.training.sg_trainer import Trainer
  4. from torchvision.models import resnet18
  5. import numpy as np
  6. class TestDatasetInterface(DatasetInterface):
  7. def __init__(self, dataset_params={}, image_size=32, batch_size=5):
  8. super(TestDatasetInterface, self).__init__(dataset_params)
  9. self.trainset = torch.utils.data.TensorDataset(torch.Tensor(np.zeros((batch_size, 3, image_size, image_size))),
  10. torch.LongTensor(np.zeros((batch_size))))
  11. self.testset = self.trainset
  12. def get_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, test_batch_size=None,
  13. distributed_sampler=False):
  14. self.trainset.classes = [0, 1, 2, 3, 4]
  15. return super().get_data_loaders(batch_size_factor=batch_size_factor,
  16. num_workers=num_workers,
  17. train_batch_size=train_batch_size,
  18. test_batch_size=test_batch_size,
  19. distributed_sampler=distributed_sampler)
  20. # ------------------ Loading The Model From Model.py----------------
  21. arch_params = {'num_classes': 1000}
  22. model = resnet18()
  23. trainer = Trainer('Client_model_training',
  24. model_checkpoints_location='local', device='cpu')
  25. # if a torch.nn.Module is provided when building the model, the model will be integrated into deci model class
  26. trainer.build_model(model, arch_params=arch_params)
  27. # ------------------ Loading The Dataset From Dataset.py----------------
  28. dataset = TestDatasetInterface()
  29. trainer.connect_dataset_interface(dataset)
  30. # ------------------ Loading The Loss From Loss.py -----------------
  31. loss = 'cross_entropy'
  32. # ------------------ Training -----------------
  33. train_params = {"max_epochs": 100,
  34. "lr_mode": "step",
  35. "lr_updates": [30, 60, 90, 100],
  36. "lr_decay_factor": 0.1,
  37. "initial_lr": 0.025, "loss": loss}
  38. trainer.train(train_params)
Discard
@@ -1,7 +1,6 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from tests.unit_tests.dataset_interface_test import TestDatasetInterface
 from tests.unit_tests.dataset_interface_test import TestDatasetInterface
 from tests.unit_tests.factories_test import FactoriesTest
 from tests.unit_tests.factories_test import FactoriesTest
-from tests.unit_tests.load_checkpoint_from_direct_path_test import LoadCheckpointFromDirectPathTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
 from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
 from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
 from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
@@ -21,6 +20,6 @@ from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithData
 
 
 __all__ = ['TestDatasetInterface', 'ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
 __all__ = ['TestDatasetInterface', 'ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
            'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
            'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
-           'LoadCheckpointFromDirectPathTest', 'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
+           'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
            'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
            'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
            'FactoriesTest', 'InitializeWithDataloadersTest']
            'FactoriesTest', 'InitializeWithDataloadersTest']
Discard
@@ -3,7 +3,7 @@ import unittest
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 
 
-from super_gradients.training import Trainer
+from super_gradients.training import Trainer, models
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
 from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
     DetectionTargetsFormat
     DetectionTargetsFormat
@@ -57,7 +57,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           post_prediction_callback=YoloPostPredictionCallback())
                           post_prediction_callback=YoloPostPredictionCallback())
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
-        trainer.build_model("yolox_s")
+        model = models.get("yolox_s")
 
 
         training_params = {"max_epochs": 1,  # we dont really need the actual training to run
         training_params = {"max_epochs": 1,  # we dont really need the actual training to run
                            "lr_mode": "cosine",
                            "lr_mode": "cosine",
@@ -74,7 +74,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                            "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                            "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                            "metric_to_watch": "mAP@0.50:0.95",
                            "metric_to_watch": "mAP@0.50:0.95",
                            }
                            }
-        trainer.train(training_params=training_params)
+        trainer.train(model=model, training_params=training_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
@@ -1,7 +1,7 @@
 import os
 import os
 import unittest
 import unittest
 
 
-from super_gradients.training import Trainer, utils as core_utils
+from super_gradients.training import Trainer, utils as core_utils, models
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
@@ -50,14 +50,14 @@ class TestDetectionUtils(unittest.TestCase):
                           model_checkpoints_location='local',
                           model_checkpoints_location='local',
                           post_prediction_callback=YoloPostPredictionCallback())
                           post_prediction_callback=YoloPostPredictionCallback())
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
-        trainer.build_model("yolox_n", checkpoint_params={"pretrained_weights": "coco"})
+        model = models.get("yolox_n", pretrained_weights="coco")
 
 
         # Simulate one iteration of validation subset
         # Simulate one iteration of validation subset
         valid_loader = trainer.valid_loader
         valid_loader = trainer.valid_loader
         batch_i, (imgs, targets) = 0, next(iter(valid_loader))
         batch_i, (imgs, targets) = 0, next(iter(valid_loader))
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
-        output = trainer.net(imgs)
+        output = model(imgs)
         output = trainer.post_prediction_callback(output)
         output = trainer.post_prediction_callback(output)
         # Visualize the batch
         # Visualize the batch
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
Discard
@@ -62,7 +62,6 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
         phase_callbacks = [early_stop_loss]
@@ -72,7 +71,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 5
         excepted_end_epoch = 5
 
 
@@ -86,7 +85,6 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                    verbose=True)
                                    verbose=True)
@@ -98,7 +96,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 6
         excepted_end_epoch = 6
 
 
@@ -110,7 +108,6 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", threshold=0.1, verbose=True)
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", threshold=0.1, verbose=True)
         phase_callbacks = [early_stop_loss]
         phase_callbacks = [early_stop_loss]
@@ -120,7 +117,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 5
         excepted_end_epoch = 5
         # count divided by 2, because loss counter used for both train and eval.
         # count divided by 2, because loss counter used for both train and eval.
@@ -132,7 +129,6 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                    verbose=True)
                                    verbose=True)
@@ -144,7 +140,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 7
         excepted_end_epoch = 7
 
 
@@ -157,7 +153,6 @@ class EarlyStopTest(unittest.TestCase):
         # test Nan value
         # test Nan value
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", check_finite=True,
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", check_finite=True,
                                     verbose=True)
                                     verbose=True)
@@ -168,7 +163,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 2
         excepted_end_epoch = 2
 
 
@@ -187,7 +182,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 3
         excepted_end_epoch = 3
         # count divided by 2, because loss counter used for both train and eval.
         # count divided by 2, because loss counter used for both train and eval.
@@ -200,7 +195,6 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer = Trainer("early_stop_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(self.net)
 
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                    min_delta=0.1, verbose=True)
                                    min_delta=0.1, verbose=True)
@@ -212,7 +206,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
 
 
-        trainer.train(train_params)
+        trainer.train(model=self.net, training_params=train_params)
 
 
         excepted_end_epoch = 5
         excepted_end_epoch = 5
 
 
Discard
@@ -32,7 +32,7 @@ class FactoriesTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
 
 
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
Discard
@@ -1,5 +1,5 @@
 import unittest
 import unittest
-from super_gradients.training import Trainer
+from super_gradients.training import Trainer, models
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
 from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
@@ -36,7 +36,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
         # Define Model
         # Define Model
         trainer = Trainer("ForwardpassPrepFNTest")
         trainer = Trainer("ForwardpassPrepFNTest")
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model("resnet18", arch_params=self.arch_params)
+        model = models.get("resnet18", arch_params=self.arch_params)
 
 
         sizes = []
         sizes = []
         phase_callbacks = [TestInputSizesCallback(sizes)]
         phase_callbacks = [TestInputSizesCallback(sizes)]
@@ -49,7 +49,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                         "pre_prediction_callback": test_forward_pass_prep_fn}
                         "pre_prediction_callback": test_forward_pass_prep_fn}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
         # THE LRS AFTER THE UPDATE
Discard
@@ -1,4 +1,7 @@
 import unittest
 import unittest
+
+from super_gradients.training import models
+
 from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients import Trainer, ClassificationTestDatasetInterface
 import torch
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from torch.utils.data import TensorDataset, DataLoader
@@ -29,7 +32,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
-        trainer.build_model("efficientnet_b0")
+        model = models.get("efficientnet_b0", arch_params={"num_classes": 5})
         train_params = {"max_epochs": 1, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 1, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
                         "optimizer": "SGD",
                         "optimizer": "SGD",
@@ -37,7 +40,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "metric_to_watch": "Accuracy",
                         "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
     def test_initialization_rules(self):
     def test_initialization_rules(self):
         self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
         self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
@@ -63,19 +66,19 @@ class InitializeWithDataloadersTest(unittest.TestCase):
                           train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader,
                           train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader,
                           classes=self.testcase_classes)
                           classes=self.testcase_classes)
 
 
-        trainer.build_model("resnet18")
-        trainer.train(training_params={"max_epochs": 2,
-                                       "lr_updates": [5, 6, 12],
-                                       "lr_decay_factor": 0.01,
-                                       "lr_mode": "step",
-                                       "initial_lr": 0.01,
-                                       "loss": "cross_entropy",
-                                       "optimizer": "SGD",
-                                       "optimizer_params": {"weight_decay": 1e-5, "momentum": 0.9},
-                                       "train_metrics_list": [Accuracy()],
-                                       "valid_metrics_list": [Accuracy()],
-                                       "metric_to_watch": "Accuracy",
-                                       "greater_metric_to_watch_is_better": True})
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        trainer.train(model=model, training_params={"max_epochs": 2,
+                                                    "lr_updates": [5, 6, 12],
+                                                    "lr_decay_factor": 0.01,
+                                                    "lr_mode": "step",
+                                                    "initial_lr": 0.01,
+                                                    "loss": "cross_entropy",
+                                                    "optimizer": "SGD",
+                                                    "optimizer_params": {"weight_decay": 1e-5, "momentum": 0.9},
+                                                    "train_metrics_list": [Accuracy()],
+                                                    "valid_metrics_list": [Accuracy()],
+                                                    "metric_to_watch": "Accuracy",
+                                                    "greater_metric_to_watch_is_better": True})
         self.assertTrue(0 < trainer.best_metric.item() < 1)
         self.assertTrue(0 < trainer.best_metric.item() < 1)
 
 
 
 
Discard
@@ -1,6 +1,8 @@
 import unittest
 import unittest
-from super_gradients.training.sg_trainer import Trainer
-from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
+
+from super_gradients.training import models
+from super_gradients.training import Trainer
+from super_gradients.training.kd_trainer import KDTrainer
 import torch
 import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
@@ -29,96 +31,43 @@ class KDEMATest(unittest.TestCase):
     def test_teacher_ema_not_duplicated(self):
     def test_teacher_ema_not_duplicated(self):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
 
 
-        kd_trainer = KDTrainer("test_teacher_ema_not_duplicated", device='cpu')
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(student_architecture='resnet18',
-                               teacher_architecture='resnet50',
-                               student_arch_params={'num_classes': 1000},
-                               teacher_arch_params={'num_classes': 1000},
-                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                               run_teacher_on_eval=True, )
-
-        kd_trainer.train(self.kd_train_params)
-
-        self.assertTrue(kd_trainer.ema_model.ema.module.teacher is kd_trainer.net.module.teacher)
-        self.assertTrue(kd_trainer.ema_model.ema.module.student is not kd_trainer.net.module.student)
-
-    def test_kd_ckpt_reload_ema(self):
-        """Check that the KD model load correctly from checkpoint when "load_ema_as_net=True"."""
-
-        # Create a KD model and train it
-        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(student_architecture='resnet18',
-                               teacher_architecture='resnet50',
-                               student_arch_params={'num_classes': 1000},
-                               teacher_arch_params={'num_classes': 1000},
-                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                               run_teacher_on_eval=True, )
-        kd_trainer.train(self.kd_train_params)
-        ema_model = kd_trainer.ema_model.ema
-        net = kd_trainer.net
-
-        # Load the trained KD model
-        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(student_architecture='resnet18',
-                               teacher_architecture='resnet50',
-                               student_arch_params={'num_classes': 1000},
-                               teacher_arch_params={'num_classes': 1000},
-                               checkpoint_params={"load_checkpoint": True, "load_ema_as_net": True},
-                               run_teacher_on_eval=True, )
-
-        # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        kd_trainer.train(self.kd_train_params)
-        reloaded_ema_model = kd_trainer.ema_model.ema
-        reloaded_net = kd_trainer.net
-
-        # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
-        self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
-
-        # loaded net != trained net (since load_ema_as_net = True)
-        self.assertTrue(not check_models_have_same_weights(reloaded_net, net))
-
-        # loaded net == trained ema (since load_ema_as_net = True)
-        self.assertTrue(check_models_have_same_weights(reloaded_net, ema_model))
+        kd_model = KDTrainer("test_teacher_ema_not_duplicated", device='cpu')
+        kd_model.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 1000})
+        teacher = models.get('resnet50', arch_params={'num_classes': 1000},
+                             pretrained_weights="imagenet")
 
 
-        # loaded student ema == loaded student net (since load_ema_as_net = True)
-        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
+        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher)
 
 
-        # loaded teacher ema == loaded teacher net (teacher always loads ema)
-        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
+        self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
+        self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
 
 
     def test_kd_ckpt_reload_net(self):
     def test_kd_ckpt_reload_net(self):
-        """Check that the KD model load correctly from checkpoint when "load_ema_as_net=False"."""
-
-        # Create a KD model and train it
-        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(student_architecture='resnet18',
-                               teacher_architecture='resnet50',
-                               student_arch_params={'num_classes': 1000},
-                               teacher_arch_params={'num_classes': 1000},
-                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                               run_teacher_on_eval=True, )
-        kd_trainer.train(self.kd_train_params)
-        ema_model = kd_trainer.ema_model.ema
-        net = kd_trainer.net
-
-        # Load the trained KD model
-        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(student_architecture='resnet18',
-                               teacher_architecture='resnet50',
-                               student_arch_params={'num_classes': 1000},
-                               teacher_arch_params={'num_classes': 1000},
-                               checkpoint_params={"load_checkpoint": True, "load_ema_as_net": False},
-                               run_teacher_on_eval=True, )
-
-        # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        kd_trainer.train(self.kd_train_params)
-        reloaded_ema_model = kd_trainer.ema_model.ema
-        reloaded_net = kd_trainer.net
+        """Check that the KD trainer load correctly from checkpoint when "load_ema_as_net=False"."""
+
+        # Create a KD trainer and train it
+        train_params = self.kd_train_params.copy()
+        kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_model.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 1000})
+        teacher = models.get('resnet50', arch_params={'num_classes': 1000},
+                             pretrained_weights="imagenet")
+
+        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher)
+        ema_model = kd_model.ema_model.ema
+        net = kd_model.net
+
+        # Load the trained KD trainer
+        kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_model.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 1000})
+        teacher = models.get('resnet50', arch_params={'num_classes': 1000},
+                             pretrained_weights="imagenet")
+
+        train_params["resume"] = True
+        kd_model.train(training_params=train_params, student=student, teacher=teacher)
+        reloaded_ema_model = kd_model.ema_model.ema
+        reloaded_net = kd_model.net
 
 
         # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
         # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
         self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
         self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
Discard
@@ -1,37 +1,43 @@
-import unittest
 import os
 import os
-from super_gradients.training.sg_trainer import Trainer
+import unittest
+from copy import deepcopy
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 import torch
-from super_gradients.training.utils.utils import check_models_have_same_weights
+
+from super_gradients.training import models
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
+from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
-    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
-    TeacherKnowledgeException
 from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
 from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
-from super_gradients.training.losses.kd_losses import KDLogitsLoss
-from copy import deepcopy
+from super_gradients.training.models.kd_modules.kd_module import KDModule
+from super_gradients.training.utils.callbacks import PhaseCallback, PhaseContext, Phase
 from super_gradients.training.utils.module_utils import NormalizationAdapter
 from super_gradients.training.utils.module_utils import NormalizationAdapter
+from super_gradients.training.utils.utils import check_models_have_same_weights
+
+
+class PreTrainingNetCollector(PhaseCallback):
+    def __init__(self):
+        super(PreTrainingNetCollector, self).__init__(phase=Phase.PRE_TRAINING)
+        self.net = None
+
+    def __call__(self, context: PhaseContext):
+        self.net = deepcopy(context.net)
+
+
+class PreTrainingEMANetCollector(PhaseCallback):
+    def __init__(self):
+        super(PreTrainingEMANetCollector, self).__init__(phase=Phase.PRE_TRAINING)
+        self.net = None
+
+    def __call__(self, context: PhaseContext):
+        self.net = deepcopy(context.ema_model)
 
 
 
 
 class KDTrainerTest(unittest.TestCase):
 class KDTrainerTest(unittest.TestCase):
     @classmethod
     @classmethod
     def setUp(cls):
     def setUp(cls):
-        cls.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
         cls.dataset_params = {"batch_size": 5}
         cls.dataset_params = {"batch_size": 5}
         cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
         cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
-        cls.sg_trained_teacher.connect_dataset_interface(cls.dataset)
-
-        cls.sg_trained_teacher.build_model('resnet50', arch_params={'num_classes': 5})
-
-        cls.train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
-                            "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
-                            "optimizer": "SGD",
-                            "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
-                            "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
-                            "metric_to_watch": "Accuracy",
-                            "greater_metric_to_watch_is_better": True, "average_best_models": False}
 
 
         cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                                "lr_warmup_epochs": 0, "initial_lr": 0.1,
                                "lr_warmup_epochs": 0, "initial_lr": 0.1,
@@ -43,230 +49,117 @@ class KDTrainerTest(unittest.TestCase):
                                'loss_logging_items_names': ["Loss", "Task Loss", "Distillation Loss"],
                                'loss_logging_items_names': ["Loss", "Task Loss", "Distillation Loss"],
                                "greater_metric_to_watch_is_better": True, "average_best_models": False}
                                "greater_metric_to_watch_is_better": True, "average_best_models": False}
 
 
-    def test_build_kd_module_with_pretrained_teacher(self):
-        kd_trainer = KDTrainer("build_kd_module_with_pretrained_teacher", device='cpu')
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'teacher_pretrained_weights': "imagenet"}
-                             )
-        imagenet_resnet50_trainer = Trainer("pretrained_resnet50")
-        imagenet_resnet50_trainer.build_model('resnet50', arch_params={'num_classes': 1000},
-                                               checkpoint_params={'pretrained_weights': "imagenet"})
-
-        self.assertTrue(check_models_have_same_weights(kd_trainer.net.module.teacher,
-                                                       imagenet_resnet50_trainer.net.module))
-
-    def test_build_kd_module_with_pretrained_student(self):
-        kd_trainer = KDTrainer("build_kd_module_with_pretrained_student", device='cpu')
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'student_pretrained_weights': "imagenet",
-                                                'teacher_pretrained_weights': "imagenet"}
-                             )
-
-        imagenet_resnet18_trainer = Trainer("pretrained_resnet18", device='cpu')
-        imagenet_resnet18_trainer.build_model('resnet18', arch_params={'num_classes': 1000},
-                                               checkpoint_params={'pretrained_weights': "imagenet"})
-
-        self.assertTrue(check_models_have_same_weights(kd_trainer.net.module.student,
-                                                       imagenet_resnet18_trainer.net.module))
-
-    def test_build_kd_module_pretrained_student_with_head_replacement(self):
-        self.sg_trained_teacher.train(self.train_params)
-        teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
-
-        kd_trainer = KDTrainer('test_build_kd_module_student_replace_head', device='cpu')
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                                student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
-                                checkpoint_params={'student_pretrained_weights': "imagenet",
-                                                   "teacher_checkpoint_path": teacher_path}
-                                )
-
-        self.assertTrue(kd_trainer.net.module.student.linear.out_features == 5)
-
-    def test_build_kd_module_with_sg_trained_teacher(self):
-        self.sg_trained_teacher.train(self.train_params)
-        teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
-
-        kd_trainer = KDTrainer('test_build_kd_module_with_sg_trained_teacher', device='cpu')
-
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                                student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
-                                checkpoint_params={"teacher_checkpoint_path": teacher_path}
-                                )
-
-        self.assertTrue(
-            check_models_have_same_weights(self.sg_trained_teacher.net.module, kd_trainer.net.module.teacher))
-
     def test_teacher_sg_module_methods(self):
     def test_teacher_sg_module_methods(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                                student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
-                                checkpoint_params={'teacher_pretrained_weights': "imagenet"}
-                                )
+        student = models.get('resnet18', arch_params={'num_classes': 1000})
+        teacher = models.get('resnet50', arch_params={'num_classes': 1000},
+                             pretrained_weights="imagenet")
+        kd_module = KDModule(arch_params={},
+                             student=student,
+                             teacher=teacher
+                             )
 
 
-        initial_param_groups = kd_trainer.net.module.initialize_param_groups(lr=0.1, training_params={})
-        updated_param_groups = kd_trainer.net.module.update_param_groups(param_groups=initial_param_groups, lr=0.2,
-                                                                          epoch=0, iter=0, training_params={},
-                                                                          total_batch=None)
+        initial_param_groups = kd_module.initialize_param_groups(lr=0.1, training_params={})
+        updated_param_groups = kd_module.update_param_groups(param_groups=initial_param_groups, lr=0.2,
+                                                             epoch=0, iter=0, training_params={},
+                                                             total_batch=None)
 
 
         self.assertTrue(initial_param_groups[0]['lr'] == 0.2 == updated_param_groups[0]['lr'])
         self.assertTrue(initial_param_groups[0]['lr'] == 0.2 == updated_param_groups[0]['lr'])
 
 
-    def test_kd_architecture_kwarg_exception_catching(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        with self.assertRaises(ArchitectureKwargsException):
-            kd_trainer.build_model(teacher_architecture='resnet50',
-                                    student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
-                                    checkpoint_params={'teacher_pretrained_weights': "imagenet"}
-                                    )
-
-    def test_kd_unsupported_kdmodel_arg_exception_catching(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        with self.assertRaises(UnsupportedKDModelArgException):
-            kd_trainer.build_model(student_architecture='resnet18',
-                                    teacher_architecture='resnet50',
-                                    student_arch_params={'num_classes': 1000},
-                                    teacher_arch_params={'num_classes': 1000},
-                                    checkpoint_params={"pretrained_weights": "imagenet"},
-                                    )
-
-    def test_kd_unsupported_model_exception_catching(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        with self.assertRaises(UnsupportedKDArchitectureException):
-            kd_trainer.build_model(student_architecture='resnet18',
-                                    teacher_architecture='resnet50',
-                                    student_arch_params={'num_classes': 1000},
-                                    teacher_arch_params={'num_classes': 1000},
-                                    checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                                    architecture='unsupported_model'
-                                    )
-
-    def test_kd_inconsistent_params_exception_catching(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        with self.assertRaises(InconsistentParamsException):
-            kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                                    student_arch_params={'num_classes': 10}, teacher_arch_params={'num_classes': 1000},
-                                    checkpoint_params={'teacher_pretrained_weights': "imagenet"}
-                                    )
-
-    def test_kd_teacher_knowledge_exception_catching(self):
-        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
-        with self.assertRaises(TeacherKnowledgeException):
-            kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                                    student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000}
-                                    )
-
-    def test_build_external_models(self):
-        kd_trainer = KDTrainer("test_training_with_external_teacher", device='cpu')
-        teacher_model = ResNet50(arch_params={}, num_classes=10)
-        student_model = ResNet18(arch_params={}, num_classes=10)
-        kd_trainer.build_model(student_architecture=student_model, teacher_architecture=teacher_model,
-                             student_arch_params={'num_classes': 10}, teacher_arch_params={'num_classes': 10}
-                             )
-
-        self.assertTrue(
-            check_models_have_same_weights(teacher_model, kd_trainer.net.module.teacher))
-        self.assertTrue(
-            check_models_have_same_weights(student_model, kd_trainer.net.module.student))
-
     def test_train_kd_module_external_models(self):
     def test_train_kd_module_external_models(self):
-        kd_trainer = KDTrainer("test_train_kd_module_external_models", device='cpu')
+        sg_model = KDTrainer("test_train_kd_module_external_models", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         student_model = ResNet18(arch_params={}, num_classes=5)
         student_model = ResNet18(arch_params={}, num_classes=5)
-        kd_trainer.connect_dataset_interface(self.dataset)
-        kd_trainer.build_model(run_teacher_on_eval=True,
-                             student_arch_params={'num_classes': 5},
-                             teacher_arch_params={'num_classes': 5},
-                             student_architecture=deepcopy(student_model),
-                             teacher_architecture=deepcopy(teacher_model),
-                             )
+        sg_model.connect_dataset_interface(self.dataset)
 
 
-        kd_trainer.train(self.kd_train_params)
+        sg_model.train(training_params=self.kd_train_params, student=deepcopy(student_model), teacher=teacher_model)
 
 
         # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
         # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
         self.assertTrue(
         self.assertTrue(
-            check_models_have_same_weights(teacher_model, kd_trainer.net.module.teacher))
+            check_models_have_same_weights(teacher_model, sg_model.net.module.teacher))
 
 
         # STUDENT WEIGHT'S SHOULD NOT REMAIN THE SAME
         # STUDENT WEIGHT'S SHOULD NOT REMAIN THE SAME
         self.assertFalse(
         self.assertFalse(
-            check_models_have_same_weights(student_model, kd_trainer.net.module.student))
+            check_models_have_same_weights(student_model, sg_model.net.module.student))
 
 
-    def test_train_kd_module_pretrained_ckpt(self):
-        kd_trainer = KDTrainer("test_train_kd_module_pretrained_ckpt", device='cpu')
-        teacher_model = ResNet50(arch_params={}, num_classes=5)
-        teacher_path = '/tmp/teacher.pth'
-        torch.save(teacher_model.state_dict(), teacher_path)
+    def test_train_model_with_input_adapter(self):
+        kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter", device='cpu')
         kd_trainer.connect_dataset_interface(self.dataset)
         kd_trainer.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 5})
+        teacher = models.get('resnet50', arch_params={'num_classes': 5},
+                             pretrained_weights="imagenet")
 
 
-        kd_trainer.build_model(student_arch_params={'num_classes': 5},
-                             teacher_arch_params={'num_classes': 5},
-                             student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             checkpoint_params={"teacher_checkpoint_path": teacher_path}
-                             )
-        kd_trainer.train(self.kd_train_params)
-
-    def test_build_model_with_input_adapter(self):
         adapter = NormalizationAdapter(mean_original=[0.485, 0.456, 0.406],
         adapter = NormalizationAdapter(mean_original=[0.485, 0.456, 0.406],
                                        std_original=[0.229, 0.224, 0.225],
                                        std_original=[0.229, 0.224, 0.225],
                                        mean_required=[0.5, 0.5, 0.5],
                                        mean_required=[0.5, 0.5, 0.5],
                                        std_required=[0.5, 0.5, 0.5])
                                        std_required=[0.5, 0.5, 0.5])
-        kd_trainer = KDTrainer("build_kd_module_with_with_input_adapter", device='cpu')
-        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                             arch_params={"teacher_input_adapter": adapter})
+
+        kd_arch_params = {
+            "teacher_input_adapter": adapter}
+        kd_trainer.train(training_params=self.kd_train_params, student=student, teacher=teacher,
+                         kd_arch_params=kd_arch_params)
+
         self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
         self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
 
 
     def test_load_ckpt_best_for_student(self):
     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
         kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
-        teacher_model = ResNet50(arch_params={}, num_classes=5)
-        teacher_path = '/tmp/teacher.pth'
-        torch.save(teacher_model.state_dict(), teacher_path)
         kd_trainer.connect_dataset_interface(self.dataset)
         kd_trainer.connect_dataset_interface(self.dataset)
-
-        kd_trainer.build_model(student_arch_params={'num_classes': 5},
-                             teacher_arch_params={'num_classes': 5},
-                             student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             checkpoint_params={"teacher_checkpoint_path": teacher_path}
-                             )
+        student = models.get('resnet18', arch_params={'num_classes': 5})
+        teacher = models.get('resnet50', arch_params={'num_classes': 5},
+                             pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
-        kd_trainer.train(train_params)
+        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_trainer = Trainer("studnet_trainer")
-        student_trainer.build_model("resnet18", arch_params={'num_classes': 5},
-                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
+        student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
+                                      checkpoint_path=best_student_ckpt)
 
 
         self.assertTrue(
         self.assertTrue(
-            check_models_have_same_weights(student_trainer.net.module, kd_trainer.net.module.student))
+            check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
 
 
     def test_load_ckpt_best_for_student_with_ema(self):
     def test_load_ckpt_best_for_student_with_ema(self):
-        kd_trainer = KDTrainer("test_load_ckpt_best_for_student_with_ema", device='cpu')
-        teacher_model = ResNet50(arch_params={}, num_classes=5)
-        teacher_path = '/tmp/teacher.pth'
-        torch.save(teacher_model.state_dict(), teacher_path)
+        kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
         kd_trainer.connect_dataset_interface(self.dataset)
         kd_trainer.connect_dataset_interface(self.dataset)
-
-        kd_trainer.build_model(student_arch_params={'num_classes': 5},
-                             teacher_arch_params={'num_classes': 5},
-                             student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             checkpoint_params={"teacher_checkpoint_path": teacher_path}
-                             )
+        student = models.get('resnet18', arch_params={'num_classes': 5})
+        teacher = models.get('resnet50', arch_params={'num_classes': 5},
+                             pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
-        kd_trainer.train(train_params)
+        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
         best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
 
 
-        student_trainer = Trainer("studnet_trainer")
-        student_trainer.build_model("resnet18", arch_params={'num_classes': 5},
-                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
+        student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
+                                      checkpoint_path=best_student_ckpt)
+
+        self.assertTrue(
+            check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
+
+    def test_resume_kd_training(self):
+        kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 5})
+        teacher = models.get('resnet50', arch_params={'num_classes': 5},
+                             pretrained_weights="imagenet")
+        train_params = self.kd_train_params.copy()
+        train_params["max_epochs"] = 1
+        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
+        latest_net = deepcopy(kd_trainer.net)
+
+        kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        student = models.get('resnet18', arch_params={'num_classes': 5})
+        teacher = models.get('resnet50', arch_params={'num_classes': 5},
+                             pretrained_weights="imagenet")
+
+        train_params["max_epochs"] = 2
+        train_params["resume"] = True
+        collector = PreTrainingNetCollector()
+        train_params["phase_callbacks"] = [collector]
+        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
+
         self.assertTrue(
         self.assertTrue(
-            check_models_have_same_weights(student_trainer.net.module, kd_trainer.ema_model.ema.module.student))
+            check_models_have_same_weights(collector.net, latest_net))
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  1. import shutil
  2. import tempfile
  3. import unittest
  4. import os
  5. from super_gradients.training import Trainer
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
  10. class Net(nn.Module):
  11. def __init__(self):
  12. super(Net, self).__init__()
  13. self.conv1 = nn.Conv2d(3, 6, 3)
  14. self.pool = nn.MaxPool2d(2, 2)
  15. self.conv2 = nn.Conv2d(6, 16, 3)
  16. self.fc1 = nn.Linear(16 * 3 * 3, 120)
  17. self.fc2 = nn.Linear(120, 84)
  18. self.fc3 = nn.Linear(84, 10)
  19. def forward(self, x):
  20. x = self.pool(F.relu(self.conv1(x)))
  21. x = self.pool(F.relu(self.conv2(x)))
  22. x = x.view(-1, 16 * 3 * 3)
  23. x = F.relu(self.fc1(x))
  24. x = F.relu(self.fc2(x))
  25. x = self.fc3(x)
  26. return x
  27. class LoadCheckpointFromDirectPathTest(unittest.TestCase):
  28. @classmethod
  29. def setUpClass(cls):
  30. cls.temp_working_file_dir = tempfile.TemporaryDirectory(prefix='load_checkpoint_test').name
  31. if not os.path.isdir(cls.temp_working_file_dir):
  32. os.mkdir(cls.temp_working_file_dir)
  33. cls.checkpoint_path = cls.temp_working_file_dir + '/load_checkpoint_test.pth'
  34. # Setup the model
  35. cls.original_torch_net = Net()
  36. # Save the model's checkpoint
  37. torch.save(cls.original_torch_net.state_dict(), cls.checkpoint_path)
  38. @classmethod
  39. def tearDownClass(cls):
  40. if os.path.isdir(cls.temp_working_file_dir):
  41. shutil.rmtree(cls.temp_working_file_dir)
  42. def test_external_checkpoint_loaded_correctly(self):
  43. # Define Model
  44. new_torch_net = Net()
  45. # Make sure we initialized a model with different weights
  46. assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
  47. # Build the Trainer and load the checkpoint
  48. trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')
  49. trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
  50. checkpoint_params={'external_checkpoint_path': self.checkpoint_path,
  51. 'load_checkpoint': True,
  52. 'strict_load': StrictLoad.NO_KEY_MATCHING})
  53. # Assert the weights were loaded correctly
  54. assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
  55. def check_models_have_same_weights(self, model_1, model_2):
  56. model_1, model_2 = model_1.to('cpu'), model_2.to('cpu')
  57. models_differ = 0
  58. for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
  59. if torch.equal(key_item_1[1], key_item_2[1]):
  60. pass
  61. else:
  62. models_differ += 1
  63. if (key_item_1[0] == key_item_2[0]):
  64. print(f'Layer names match but layers have different weights for layers: {key_item_1[0]}')
  65. if models_differ == 0:
  66. return True
  67. else:
  68. return False
  69. if __name__ == '__main__':
  70. unittest.main()
Discard
@@ -1,9 +1,20 @@
 import unittest
 import unittest
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.models import LeNet
 from super_gradients.training.models import LeNet
+from copy import deepcopy
+
+
+class PreTrainingEMANetCollector(PhaseCallback):
+    def __init__(self):
+        super(PreTrainingEMANetCollector, self).__init__(phase=Phase.PRE_TRAINING)
+        self.net = None
+
+    def __call__(self, context: PhaseContext):
+        self.net = deepcopy(context.ema_model)
 
 
 
 
 class LoadCheckpointWithEmaTest(unittest.TestCase):
 class LoadCheckpointWithEmaTest(unittest.TestCase):
@@ -23,21 +34,22 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
         trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
         trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
 
 
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params={'num_classes': 10})
 
 
-        trainer.train(self.train_params)
+        trainer.train(model=net, training_params=self.train_params)
 
 
         ema_model = trainer.ema_model.ema
         ema_model = trainer.ema_model.ema
 
 
+        # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
         net = LeNet()
         net = LeNet()
         trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
         trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
-        trainer.build_model(net, arch_params={'num_classes': 10}, checkpoint_params={'load_checkpoint': True})
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
+        net_collector = PreTrainingEMANetCollector()
+        self.train_params["resume"] = True
+        self.train_params["max_epochs"] = 3
+        self.train_params["phase_callbacks"] = [net_collector]
+        trainer.train(model=net, training_params=self.train_params)
 
 
-        # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        trainer.train(self.train_params)
-
-        reloaded_ema_model = trainer.ema_model.ema
+        reloaded_ema_model = net_collector.net.ema
 
 
         # ASSERT RELOADED EMA MODEL HAS THE SAME WEIGHTS AS THE EMA MODEL SAVED IN FIRST PART OF TRAINING
         # ASSERT RELOADED EMA MODEL HAS THE SAME WEIGHTS AS THE EMA MODEL SAVED IN FIRST PART OF TRAINING
         assert check_models_have_same_weights(ema_model, reloaded_ema_model)
         assert check_models_have_same_weights(ema_model, reloaded_ema_model)
Discard
@@ -17,7 +17,6 @@ class LRCooldownTest(unittest.TestCase):
         net = LeNet()
         net = LeNet()
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -31,7 +30,7 @@ class LRCooldownTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
 
 
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211, 0.4763932022500211, 0.4763932022500211]
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211, 0.4763932022500211, 0.4763932022500211]
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
         # THE LRS AFTER THE UPDATE
Discard
@@ -44,7 +44,6 @@ class LRWarmupTest(unittest.TestCase):
         net = LeNet()
         net = LeNet()
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -58,15 +57,14 @@ class LRWarmupTest(unittest.TestCase):
                         "warmup_mode": "linear_step"}
                         "warmup_mode": "linear_step"}
 
 
         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
         self.assertListEqual(lrs, expected_lrs)
         self.assertListEqual(lrs, expected_lrs)
 
 
     def test_lr_warmup_with_lr_scheduling(self):
     def test_lr_warmup_with_lr_scheduling(self):
-        # Define Model
+        # Define model
         net = LeNet()
         net = LeNet()
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -80,18 +78,17 @@ class LRWarmupTest(unittest.TestCase):
                         "warmup_mode": "linear_step"}
                         "warmup_mode": "linear_step"}
 
 
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
         # THE LRS AFTER THE UPDATE
         self.assertListEqual(lrs, expected_lrs)
         self.assertListEqual(lrs, expected_lrs)
 
 
     def test_warmup_initial_lr(self):
     def test_warmup_initial_lr(self):
-        # Define Model
+        # Define model
         net = LeNet()
         net = LeNet()
         trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
         trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -105,15 +102,14 @@ class LRWarmupTest(unittest.TestCase):
                         "warmup_mode": "linear_step", "initial_lr": 1, "warmup_initial_lr": 4.}
                         "warmup_mode": "linear_step", "initial_lr": 1, "warmup_initial_lr": 4.}
 
 
         expected_lrs = [4., 3., 2., 1., 1.]
         expected_lrs = [4., 3., 2., 1., 1.]
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
         self.assertListEqual(lrs, expected_lrs)
         self.assertListEqual(lrs, expected_lrs)
 
 
     def test_custom_lr_warmup(self):
     def test_custom_lr_warmup(self):
-        # Define Model
+        # Define model
         net = LeNet()
         net = LeNet()
         trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
         trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -127,7 +123,7 @@ class LRWarmupTest(unittest.TestCase):
                         "warmup_mode": ExponentialWarmupLRCallback, "initial_lr": 1., "warmup_initial_lr": 0.1}
                         "warmup_mode": ExponentialWarmupLRCallback, "initial_lr": 1., "warmup_initial_lr": 0.1}
 
 
         expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
         expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
         self.assertListEqual(lrs, expected_lrs)
         self.assertListEqual(lrs, expected_lrs)
 
 
 
 
Discard
@@ -17,7 +17,6 @@ class PhaseContextTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
 
 
         phase_callbacks = [PhaseContextTestCallback(Phase.TRAIN_BATCH_END),
         phase_callbacks = [PhaseContextTestCallback(Phase.TRAIN_BATCH_END),
                            PhaseContextTestCallback(Phase.TRAIN_BATCH_STEP),
                            PhaseContextTestCallback(Phase.TRAIN_BATCH_STEP),
@@ -32,7 +31,7 @@ class PhaseContextTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Top5",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Top5",
                         "greater_metric_to_watch_is_better": True, "phase_callbacks": phase_callbacks}
                         "greater_metric_to_watch_is_better": True, "phase_callbacks": phase_callbacks}
 
 
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
         context_callbacks = list(filter(lambda cb: isinstance(cb, PhaseContextTestCallback), trainer.phase_callbacks))
         context_callbacks = list(filter(lambda cb: isinstance(cb, PhaseContextTestCallback), trainer.phase_callbacks))
 
 
         # CHECK THAT PHASE CONTEXES HAVE THE EXACT INFORMATION THERY'RE SUPPOSE TO HOLD
         # CHECK THAT PHASE CONTEXES HAVE THE EXACT INFORMATION THERY'RE SUPPOSE TO HOLD
Discard
@@ -1,7 +1,8 @@
 import unittest
 import unittest
+
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
-from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
+from super_gradients.training.metrics import Accuracy
 from super_gradients.training.models import LeNet
 from super_gradients.training.models import LeNet
 from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
 from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
 
 
@@ -37,7 +38,6 @@ class ContextMethodsTest(unittest.TestCase):
         net = LeNet()
         net = LeNet()
         trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
         trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         phase_callbacks = []
         phase_callbacks = []
         for phase in Phase:
         for phase in Phase:
@@ -68,7 +68,7 @@ class ContextMethodsTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
 
 
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
         for phase_callback in phase_callbacks:
         for phase_callback in phase_callbacks:
             if isinstance(phase_callback, ContextMethodsCheckerCallback):
             if isinstance(phase_callback, ContextMethodsCheckerCallback):
                 self.assertTrue(phase_callback.result)
                 self.assertTrue(phase_callback.result)
Discard
@@ -1,6 +1,6 @@
 import unittest
 import unittest
 import super_gradients
 import super_gradients
-from super_gradients.training import MultiGPUMode
+from super_gradients.training import MultiGPUMode, models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -19,25 +19,24 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
-        trainer.build_model("resnet50", checkpoint_params={"pretrained_weights": "imagenet"})
-        trainer.test(test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("resnet50", pretrained_weights="imagenet")
+        trainer.test(model=model, test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
                      metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
         trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
-        trainer.build_model("regnetY800", checkpoint_params={"pretrained_weights": "imagenet"})
-        trainer.test(test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("regnetY800", pretrained_weights="imagenet")
+        trainer.test(model=model, test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
                      metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
         trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
         trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
-        trainer.build_model("repvgg_a0", checkpoint_params={"pretrained_weights": "imagenet"},
-                            arch_params={"build_residual_branches": True})
-        trainer.test(test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
+        model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
+        trainer.test(model=model, test_loader=self.test_dataset.val_loader, test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
                      metrics_progress_verbose=True)
 
 
     def tearDown(self) -> None:
     def tearDown(self) -> None:
Discard
@@ -1,6 +1,6 @@
 import unittest
 import unittest
 import os
 import os
-from super_gradients.training import Trainer
+from super_gradients.training import Trainer, models
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 
 
@@ -25,10 +25,10 @@ class SaveCkptListUnitTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 
 
         # Build Model
         # Build Model
-        trainer.build_model("resnet18_cifar")
+        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
 
 
         # Train Model (and save ckpt_epoch_list)
         # Train Model (and save ckpt_epoch_list)
-        trainer.train(training_params=train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
         dir_path = trainer.checkpoints_dir_path
         dir_path = trainer.checkpoints_dir_path
         self.file_names_list = [dir_path + f'/ckpt_epoch_{epoch}.pth' for epoch in train_params["save_ckpt_epoch_list"]]
         self.file_names_list = [dir_path + f'/ckpt_epoch_{epoch}.pth' for epoch in train_params["save_ckpt_epoch_list"]]
Discard
@@ -1,7 +1,7 @@
+import os
 import shutil
 import shutil
 import tempfile
 import tempfile
 import unittest
 import unittest
-import os
 
 
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 
 
+from super_gradients.training import models
 from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
 from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 
 
@@ -46,14 +47,14 @@ class StrictLoadEnumTest(unittest.TestCase):
         cls.checkpoint_diff_keys_path = cls.temp_working_file_dir + '/' + cls.checkpoint_diff_keys_name
         cls.checkpoint_diff_keys_path = cls.temp_working_file_dir + '/' + cls.checkpoint_diff_keys_name
 
 
         # Setup the model
         # Setup the model
-        cls.original_torch_net = Net()
+        cls.original_torch_model = Net()
 
 
         # Save the model's state_dict checkpoint with different keys
         # Save the model's state_dict checkpoint with different keys
-        torch.save(cls.change_state_dict_keys(cls.original_torch_net.state_dict()), cls.checkpoint_diff_keys_path)
+        torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)
 
 
         # Save the model's state_dict checkpoint in Trainer format
         # Save the model's state_dict checkpoint in Trainer format
         cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
         cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
-        cls.trainer.build_model(cls.original_torch_net, arch_params={'num_classes': 10})
+        cls.trainer.build_model(cls.original_torch_model, arch_params={'num_classes': 10})
         # FIXME: after uniting init and build_model we should remove this
         # FIXME: after uniting init and build_model we should remove this
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
                                              training_params=HpmStruct(max_epochs=10),
                                              training_params=HpmStruct(max_epochs=10),
@@ -69,7 +70,7 @@ class StrictLoadEnumTest(unittest.TestCase):
     def change_state_dict_keys(self, state_dict):
     def change_state_dict_keys(self, state_dict):
         new_ckpt_dict = {}
         new_ckpt_dict = {}
         for i, (ckpt_key, ckpt_val) in enumerate(state_dict.items()):
         for i, (ckpt_key, ckpt_val) in enumerate(state_dict.items()):
-            new_ckpt_dict[i] = ckpt_val
+            new_ckpt_dict[str(i)] = ckpt_val
         return new_ckpt_dict
         return new_ckpt_dict
 
 
     def check_models_have_same_weights(self, model_1, model_2):
     def check_models_have_same_weights(self, model_1, model_2):
@@ -91,71 +92,65 @@ class StrictLoadEnumTest(unittest.TestCase):
 
 
     def test_strict_load_on(self):
     def test_strict_load_on(self):
         # Define Model
         # Define Model
-        new_torch_net = Net()
+        model = models.get('resnet18', arch_params={"num_classes": 1000})
+        pretrained_model = models.get('resnet18', arch_params={"num_classes": 1000},
+                                      pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
-        assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
+        assert not self.check_models_have_same_weights(model, pretrained_model)
+
+        pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
+        torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
-        # Build the Trainer and load the checkpoint
-        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
-                          ckpt_name='ckpt_latest_weights_only.pth')
-        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
-                            checkpoint_params={'strict_load': StrictLoad.ON,
-                                               'load_checkpoint': True})
+        model = models.get('resnet18', arch_params={"num_classes": 1000},
+                           checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_off(self):
     def test_strict_load_off(self):
         # Define Model
         # Define Model
-        new_torch_net = Net()
+        model = models.get('resnet18', arch_params={"num_classes": 1000})
+        pretrained_model = models.get('resnet18', arch_params={"num_classes": 1000},
+                                      pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
-        assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
-
-        # Build the Trainer and load the checkpoint
-        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
-                          ckpt_name='ckpt_latest_weights_only.pth')
-        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
-                            checkpoint_params={'strict_load': StrictLoad.OFF,
-                                               'load_checkpoint': True})
-
-        # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
+        assert not self.check_models_have_same_weights(model, pretrained_model)
 
 
-    def test_strict_load_no_key_matching_external_checkpoint(self):
-        # Define Model
-        new_torch_net = Net()
+        pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_off.pth")
+        del pretrained_model.linear
+        torch.save(pretrained_model.state_dict(), pretrained_sd_path)
 
 
-        # Make sure we initialized a model with different weights
-        assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
-
-        # Build the Trainer and load the checkpoint
-        trainer = Trainer(self.experiment_name, model_checkpoints_location='local')
-        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
-                            checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
-                                               'external_checkpoint_path': self.checkpoint_diff_keys_path,
-                                               'load_checkpoint': True})
+        with self.assertRaises(RuntimeError):
+            models.get('resnet18', arch_params={"num_classes": 1000},
+                       checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
+        model = models.get('resnet18', arch_params={"num_classes": 1000},
+                           checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
+        del model.linear
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(model, pretrained_model)
 
 
     def test_strict_load_no_key_matching_sg_checkpoint(self):
     def test_strict_load_no_key_matching_sg_checkpoint(self):
         # Define Model
         # Define Model
-        new_torch_net = Net()
+        model = models.get('resnet18', arch_params={"num_classes": 1000})
+        pretrained_model = models.get('resnet18', arch_params={"num_classes": 1000},
+                                      pretrained_weights="imagenet")
 
 
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
-        assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
+        assert not self.check_models_have_same_weights(model, pretrained_model)
+
+        pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_model_strict_load_soft.pth")
+        torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
 
 
-        # Build the Trainer and load the checkpoint
-        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
-                          ckpt_name='ckpt_latest_weights_only.pth')
-        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
-                            checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
-                                               'load_checkpoint': True})
+        with self.assertRaises(RuntimeError):
+            models.get('resnet18', arch_params={"num_classes": 1000},
+                       checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
 
 
+        model = models.get('resnet18', arch_params={"num_classes": 1000},
+                           checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(model, pretrained_model)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
@@ -5,7 +5,7 @@ from super_gradients import Trainer, \
     ClassificationTestDatasetInterface, \
     ClassificationTestDatasetInterface, \
     SegmentationTestDatasetInterface, DetectionTestDatasetInterface
     SegmentationTestDatasetInterface, DetectionTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
-from super_gradients.training import MultiGPUMode
+from super_gradients.training import MultiGPUMode, models
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.utils.detection_utils import DetectionCollateFN
 from super_gradients.training.utils.detection_utils import DetectionCollateFN
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
@@ -31,8 +31,8 @@ class TestWithoutTrainTest(unittest.TestCase):
         dataset_params = {"batch_size": 4}
         dataset_params = {"batch_size": 4}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
-        trainer.build_model("resnet18_cifar")
-        return trainer
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=''):
     def get_detection_trainer(name=''):
@@ -50,8 +50,8 @@ class TestWithoutTrainTest(unittest.TestCase):
                           post_prediction_callback=YoloPostPredictionCallback())
                           post_prediction_callback=YoloPostPredictionCallback())
         dataset_interface = DetectionTestDatasetInterface(dataset_params=dataset_params)
         dataset_interface = DetectionTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=4)
         trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=4)
-        trainer.build_model('yolox_s')
-        return trainer
+        model = models.get("yolox_s", arch_params={"num_classes": 5})
+        return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=''):
     def get_segmentation_trainer(name=''):
@@ -60,34 +60,38 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
         dataset_interface = SegmentationTestDatasetInterface()
         dataset_interface = SegmentationTestDatasetInterface()
         trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=8)
         trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=8)
-        trainer.build_model('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
-        return trainer
+        model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
+        return trainer, model
 
 
     def test_test_without_train(self):
     def test_test_without_train(self):
-        trainer = self.get_classification_trainer(self.folder_names[0])
-        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
+        trainer, model = self.get_classification_trainer(self.folder_names[0])
+        assert isinstance(trainer.test(model=model, silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
 
 
-        trainer = self.get_detection_trainer(self.folder_names[1])
+        trainer, model = self.get_detection_trainer(self.folder_names[1])
 
 
         test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
         test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
 
 
-        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=test_metrics), tuple)
+        assert isinstance(trainer.test(model=model, silent_mode=True, test_metrics_list=test_metrics), tuple)
 
 
-        trainer = self.get_segmentation_trainer(self.folder_names[2])
-        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
+        trainer, model = self.get_segmentation_trainer(self.folder_names[2])
+        assert isinstance(trainer.test(model=model, silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
 
 
     def test_test_on_valid_loader_without_train(self):
     def test_test_on_valid_loader_without_train(self):
-        trainer = self.get_classification_trainer(self.folder_names[0])
-        assert isinstance(trainer.test(test_loader=trainer.valid_loader, silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
+        trainer, model = self.get_classification_trainer(self.folder_names[0])
+        assert isinstance(trainer.test(model=model, test_loader=trainer.valid_loader, silent_mode=True,
+                                       test_metrics_list=[Accuracy(), Top5()]), tuple)
 
 
-        trainer = self.get_detection_trainer(self.folder_names[1])
+        trainer, model = self.get_detection_trainer(self.folder_names[1])
 
 
         test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
         test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
 
 
-        assert isinstance(trainer.test(test_loader=trainer.valid_loader, silent_mode=True, test_metrics_list=test_metrics), tuple)
+        assert isinstance(
+            trainer.test(model=model, test_loader=trainer.valid_loader, silent_mode=True, test_metrics_list=test_metrics),
+            tuple)
 
 
-        model = self.get_segmentation_trainer(self.folder_names[2])
-        assert isinstance(model.test(test_loader=model.valid_loader, silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
+        trainer, model = self.get_segmentation_trainer(self.folder_names[2])
+        assert isinstance(trainer.test(model=model, test_loader=trainer.valid_loader, silent_mode=True,
+                                       test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
 
 
 
 
 if __name__ == '__main__':
Discard
@@ -17,7 +17,6 @@ class SgTrainerLoggingTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -26,7 +25,7 @@ class SgTrainerLoggingTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True,
                         "greater_metric_to_watch_is_better": True,
                         "save_full_train_log": True}
                         "save_full_train_log": True}
 
 
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
         logfile_path = trainer.log_file.replace('.txt', 'full_train_log.log')
         logfile_path = trainer.log_file.replace('.txt', 'full_train_log.log')
         assert os.path.exists(logfile_path) and os.path.getsize(logfile_path) > 0
         assert os.path.exists(logfile_path) and os.path.getsize(logfile_path) > 0
Discard
@@ -1,15 +1,17 @@
 import unittest
 import unittest
-from super_gradients import Trainer, \
-    ClassificationTestDatasetInterface
-from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
-from super_gradients.training.models import ResNet18
+
+import numpy as np
+import torch
 from torch.optim import SGD
 from torch.optim import SGD
 from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
-from super_gradients.training.utils.callbacks import LRSchedulerCallback, Phase
 from torchmetrics import F1Score
 from torchmetrics import F1Score
-import torch
-import numpy as np
+
+from super_gradients import Trainer, \
+    ClassificationTestDatasetInterface
+from super_gradients.training import models
 from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
 from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
+from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
+from super_gradients.training.utils.callbacks import LRSchedulerCallback, Phase
 
 
 
 
 class TrainWithInitializedObjectsTest(unittest.TestCase):
 class TrainWithInitializedObjectsTest(unittest.TestCase):
@@ -23,15 +25,15 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
-        net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
-                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(), "optimizer": "SGD",
+                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
+                        "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "metric_to_watch": "Accuracy",
                         "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
     def test_train_with_external_optimizer(self):
     def test_train_with_external_optimizer(self):
         trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
         trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
@@ -39,16 +41,15 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
-        net = ResNet18(num_classes=5, arch_params={})
-        optimizer = SGD(params=net.parameters(), lr=0.1)
-        trainer.build_model(net)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        optimizer = SGD(params=model.parameters(), lr=0.1)
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": optimizer,
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": optimizer,
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
     def test_train_with_external_scheduler(self):
     def test_train_with_external_scheduler(self):
         trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
         trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
@@ -57,11 +58,10 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         lr = 0.3
         lr = 0.3
-        net = ResNet18(num_classes=5, arch_params={})
-        optimizer = SGD(params=net.parameters(), lr=lr)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
-        trainer.build_model(net)
 
 
         train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
         train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
                         "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
                         "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
@@ -69,7 +69,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
         assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1
         assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1
 
 
     def test_train_with_external_scheduler_class(self):
     def test_train_with_external_scheduler_class(self):
@@ -78,9 +78,8 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
-        net = ResNet18(num_classes=5, arch_params={})
+        model = models.get("resnet18", arch_params={"num_classes": 5})
         optimizer = SGD  # a class - not an instance
         optimizer = SGD  # a class - not an instance
-        trainer.build_model(net)
 
 
         train_params = {"max_epochs": 2,
         train_params = {"max_epochs": 2,
                         "lr_warmup_epochs": 0, "initial_lr": 0.3, "loss": "cross_entropy", "optimizer": optimizer,
                         "lr_warmup_epochs": 0, "initial_lr": 0.3, "loss": "cross_entropy", "optimizer": optimizer,
@@ -88,7 +87,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
     def test_train_with_reduce_on_plateau(self):
     def test_train_with_reduce_on_plateau(self):
         trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
         trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
@@ -97,11 +96,10 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         lr = 0.3
         lr = 0.3
-        net = ResNet18(num_classes=5, arch_params={})
-        optimizer = SGD(params=net.parameters(), lr=lr)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
+        optimizer = SGD(params=model.parameters(), lr=lr)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
         phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
-        trainer.build_model(net)
 
 
         train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
         train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
                         "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
                         "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
@@ -110,7 +108,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                         "valid_metrics_list": [Accuracy(), Top5(), ToyTestClassificationMetric()],
                         "valid_metrics_list": [Accuracy(), Top5(), ToyTestClassificationMetric()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
         assert lr_scheduler._last_lr[0] == lr * 0.1
         assert lr_scheduler._last_lr[0] == lr * 0.1
 
 
     def test_train_with_external_metric(self):
     def test_train_with_external_metric(self):
@@ -119,23 +117,24 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
-        net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                         "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
     def test_train_with_external_dataloaders(self):
     def test_train_with_external_dataloaders(self):
         trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
         trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
 
 
         batch_size = 5
         batch_size = 5
-        trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))), torch.LongTensor(np.zeros((10))))
+        trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
+                                                  torch.LongTensor(np.zeros((10))))
 
 
-        valset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))), torch.LongTensor(np.zeros((10))))
+        valset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
+                                                torch.LongTensor(np.zeros((10))))
 
 
         classes = [0, 1, 2, 3, 4]
         classes = [0, 1, 2, 3, 4]
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
@@ -144,15 +143,14 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         dataset_interface = DatasetInterface(train_loader=train_loader, val_loader=val_loader, classes=classes)
         dataset_interface = DatasetInterface(train_loader=train_loader, val_loader=val_loader, classes=classes)
         trainer.connect_dataset_interface(dataset_interface)
         trainer.connect_dataset_interface(dataset_interface)
 
 
-        net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
+        model = models.get("resnet18", arch_params={"num_classes": 5})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                         "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
-        trainer.train(train_params)
+        trainer.train(model=model, training_params=train_params)
 
 
 
 
 if __name__ == '__main__':
Discard
@@ -1,4 +1,5 @@
 import unittest
 import unittest
+
 from super_gradients import Trainer, \
 from super_gradients import Trainer, \
     ClassificationTestDatasetInterface
     ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -17,7 +18,6 @@ class TrainWithPreciseBNTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -25,7 +25,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True,
                         "greater_metric_to_watch_is_better": True,
                         "precise_bn": True, "precise_bn_batch_size": 100}
                         "precise_bn": True, "precise_bn_batch_size": 100}
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
     def test_train_with_precise_bn_implicit_size(self):
     def test_train_with_precise_bn_implicit_size(self):
         trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
         trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
@@ -34,7 +34,6 @@ class TrainWithPreciseBNTest(unittest.TestCase):
         trainer.connect_dataset_interface(dataset)
         trainer.connect_dataset_interface(dataset)
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
-        trainer.build_model(net)
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -42,7 +41,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True,
                         "greater_metric_to_watch_is_better": True,
                         "precise_bn": True}
                         "precise_bn": True}
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
 
 
 if __name__ == '__main__':
Discard
@@ -35,7 +35,6 @@ class UpdateParamGroupsTest(unittest.TestCase):
         net = TestNet()
         net = TestNet()
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
         trainer.connect_dataset_interface(self.dataset)
         trainer.connect_dataset_interface(self.dataset)
-        trainer.build_model(net, arch_params=self.arch_params)
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -53,6 +52,6 @@ class UpdateParamGroupsTest(unittest.TestCase):
                         }
                         }
 
 
         expected_lrs = np.array([0.1, 0.2, 0.3])
         expected_lrs = np.array([0.1, 0.2, 0.3])
-        trainer.train(train_params)
+        trainer.train(model=net, training_params=train_params)
 
 
         self.assertTrue(np.allclose(np.array(lrs), expected_lrs, rtol=0.0000001))
         self.assertTrue(np.allclose(np.array(lrs), expected_lrs, rtol=0.0000001))
Discard
@@ -3,6 +3,7 @@ from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.training import models
 
 
 
 
 class TestViT(unittest.TestCase):
 class TestViT(unittest.TestCase):
@@ -23,8 +24,8 @@ class TestViT(unittest.TestCase):
         """
         """
         trainer = Trainer("test_vit_base", device='cpu')
         trainer = Trainer("test_vit_base", device='cpu')
         trainer.connect_dataset_interface(self.dataset, data_loader_num_workers=8)
         trainer.connect_dataset_interface(self.dataset, data_loader_num_workers=8)
-        trainer.build_model('vit_base', load_checkpoint=False)
-        trainer.train(training_params=self.train_params)
+        model = models.get('vit_base', arch_params={"num_classes": 5})
+        trainer.train(model=model, training_params=self.train_params)
 
 
 
 
 if __name__ == '__main__':
Discard