#313 Feature/SG-187: rename SgModel to Trainer

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-187_rename_sg_model
99 changed files with 2089 additions and 2152 deletions
Changed files (+additions / -deletions):
1. README.md (+3 / -3)
2. documentation/source/user_guide.md (+13 / -13)
3. src/super_gradients/__init__.py (+4 / -3)
4. src/super_gradients/common/data_types/enum/evaluation_type.py (+1 / -1)
5. src/super_gradients/common/sg_loggers/base_sg_logger.py (+3 / -3)
6. src/super_gradients/examples/SG_Walkthrough.ipynb (+4 / -4)
7. src/super_gradients/examples/SG_quickstart_classification.ipynb (+2 / -2)
8. src/super_gradients/examples/SG_quickstart_model_upload_deci_lab.ipynb (+2 / -2)
9. src/super_gradients/examples/cifar10_training_torch_objects/cifar10_training_torch_objects_example.py (+6 / -6)
10. src/super_gradients/examples/ddrnet_imagenet/ddrnet_classification_example.py (+8 / -9)
11. src/super_gradients/examples/deci_lab_export_example/deci_lab_export_example.py (+5 / -5)
12. src/super_gradients/examples/deci_platform_logger_example/deci_platform_logger_example.py (+2 / -2)
13. src/super_gradients/examples/early_stop/early_stop_example.py (+5 / -5)
14. src/super_gradients/examples/legacy/cifar_resnet/cifar_example.py (+3 / -3)
15. src/super_gradients/examples/legacy/darknet53_example.py (+8 / -8)
16. src/super_gradients/examples/legacy/imagenet_efficientnet/efficientnet_example.py (+3 / -3)
17. src/super_gradients/examples/legacy/imagenet_mobilenetv3/mobilenetv3_imagenet_example.py (+3 / -3)
18. src/super_gradients/examples/legacy/imagenet_regnetY800/regnetY800_example.py (+3 / -3)
19. src/super_gradients/examples/legacy/imagenet_repvgg/imagenet_repvgg_example.py (+3 / -3)
20. src/super_gradients/examples/legacy/imagenet_resnet/imagenet_resnet_example.py (+3 / -3)
21. src/super_gradients/examples/legacy/imagenet_resnet_ddp/distributed_training_imagenet.py (+9 / -9)
22. src/super_gradients/examples/legacy/shelfnet_lw_example.py (+8 / -8)
23. src/super_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py (+5 / -5)
24. src/super_gradients/examples/resnet_qat/resnet_qat_example.py (+8 / -8)
25. src/super_gradients/examples/shelfnet_lw_pascal_aug/shelfnet_pascal_aug.py (+7 / -7)
26. src/super_gradients/examples/train_from_kd_recipe_example/train_from_kd_recipe.py (+1 / -1)
27. src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py (+3 / -2)
28. src/super_gradients/examples/user_guide_walkthrough_example/train.py (+7 / -7)
29. src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml (+1 / -1)
30. src/super_gradients/recipes/cifar10_resnet.yaml (+0 / -6)
31. src/super_gradients/recipes/cityscapes_ddrnet.yaml (+0 / -6)
32. src/super_gradients/recipes/cityscapes_regseg48.yaml (+0 / -6)
33. src/super_gradients/recipes/cityscapes_stdc_base.yaml (+0 / -7)
34. src/super_gradients/recipes/coco2017_yolox.yaml (+1 / -6)
35. src/super_gradients/recipes/coco_segmentation_shelfnet_lw.yaml (+1 / -9)
36. src/super_gradients/recipes/imagenet_efficientnet.yaml (+1 / -8)
37. src/super_gradients/recipes/imagenet_mobilenetv2.yaml (+1 / -8)
38. src/super_gradients/recipes/imagenet_mobilenetv3.yaml (+1 / -8)
39. src/super_gradients/recipes/imagenet_regnetY.yaml (+1 / -7)
40. src/super_gradients/recipes/imagenet_repvgg.yaml (+1 / -8)
41. src/super_gradients/recipes/imagenet_resnet50.yaml (+1 / -8)
42. src/super_gradients/recipes/imagenet_resnet50_kd.yaml (+2 / -9)
43. src/super_gradients/recipes/imagenet_vit_base.yaml (+0 / -6)
44. src/super_gradients/recipes/test_resnet.yaml (+0 / -6)
45. src/super_gradients/training/__init__.py (+8 / -5)
46. src/super_gradients/training/datasets/datasets_utils.py (+2 / -2)
47. src/super_gradients/training/exceptions/kd_trainer_exceptions.py (+3 / -3)
48. src/super_gradients/training/exceptions/sg_trainer_exceptions.py (+1 / -1)
49. src/super_gradients/training/kd_model/kd_model.py (+6 / -272)
50. src/super_gradients/training/kd_trainer.py (+0 / -16)
51. src/super_gradients/training/kd_trainer/__init__.py (+5 / -0)
52. src/super_gradients/training/kd_trainer/kd_trainer.py (+315 / -0)
53. src/super_gradients/training/models/segmentation_models/stdc.py (+1 / -1)
54. src/super_gradients/training/sg_model/__init__.py (+3 / -3)
55. src/super_gradients/training/sg_model/sg_model.py (+2 / -992)
56. src/super_gradients/training/sg_trainer/__init__.py (+5 / -0)
57. src/super_gradients/training/sg_trainer/sg_trainer.py (+999 / -0)
58. src/super_gradients/training/trainer.py (+0 / -33)
59. src/super_gradients/training/utils/callbacks.py (+1 / -1)
60. src/super_gradients/training/utils/checkpoint_utils.py (+1 / -1)
61. src/super_gradients/training/utils/distributed_training_utils.py (+2 / -2)
62. src/super_gradients/training/utils/kd_trainer_utils.py (+0 / -0)
63. src/super_gradients/training/utils/quantization_utils.py (+3 / -3)
64. src/super_gradients/training/utils/sg_trainer_utils.py (+21 / -2)
65. tests/deci_core_unit_test_suite_runner.py (+2 / -2)
66. tests/end_to_end_tests/cifar10_trainer_test.py (+5 / -5)
67. tests/end_to_end_tests/external_dataset_e2e.py (+7 / -7)
68. tests/end_to_end_tests/trainer_test.py (+31 / -39)
69. tests/integration_tests/conversion_callback_test.py (+9 / -9)
70. tests/integration_tests/deci_lab_export_test.py (+5 / -5)
71. tests/integration_tests/ema_train_integration_test.py (+11 / -11)
72. tests/integration_tests/lr_test.py (+12 / -12)
73. tests/integration_tests/pretrained_models_test.py (+55 / -55)
74. tests/integration_tests/qat_integration_test.py (+16 / -16)
75. tests/test-data-interface.py (+6 / -6)
76. tests/unit_tests/dataset_interface_test.py (+3 / -3)
77. tests/unit_tests/dataset_statistics_test.py (+7 / -7)
78. tests/unit_tests/detection_utils_test.py (+13 / -13)
79. tests/unit_tests/early_stop_test.py (+29 / -29)
80. tests/unit_tests/factories_test.py (+9 / -9)
81. tests/unit_tests/forward_pass_prep_fn_test.py (+5 / -5)
82. tests/unit_tests/initialize_with_dataloaders_test.py (+31 / -31)
83. tests/unit_tests/kd_ema_test.py (+58 / -58)
84. tests/unit_tests/kd_trainer_test.py (+70 / -70)
85. tests/unit_tests/load_checkpoint_from_direct_path_test.py (+9 / -9)
86. tests/unit_tests/load_ema_ckpt_test.py (+11 / -11)
87. tests/unit_tests/lr_cooldown_test.py (+5 / -5)
88. tests/unit_tests/lr_warmup_test.py (+17 / -17)
89. tests/unit_tests/phase_context_test.py (+6 / -6)
90. tests/unit_tests/phase_delegates_test.py (+6 / -6)
91. tests/unit_tests/pretrained_models_unit_test.py (+4 / -4)
92. tests/unit_tests/save_ckpt_test.py (+6 / -6)
93. tests/unit_tests/strictload_enum_test.py (+37 / -37)
94. tests/unit_tests/test_without_train_test.py (+27 / -27)
95. tests/unit_tests/train_logging_test.py (+7 / -7)
96. tests/unit_tests/train_with_intialized_param_args_test.py (+30 / -29)
97. tests/unit_tests/train_with_precise_bn_test.py (+9 / -9)
98. tests/unit_tests/update_param_groups_unit_test.py (+5 / -5)
99. tests/unit_tests/vit_unit_test.py (+5 / -5)
@@ -124,14 +124,14 @@ The most simple and straightforward way to start training SOTA performance model
 python -m super_gradients.train_from_recipe --config-name=imagenet_regnetY architecture=regnetY800 dataset_interface.data_dir=<YOUR_Imagenet_LOCAL_PATH> ckpt_root_dir=<CHEKPOINT_DIRECTORY>
 ```
 ### Quickly Load Pre-Trained Weights for Your Desired Model with SOTA Performance
-Want to try our pre-trained models on your machine? Import SuperGradients, initialize your SgModel, and load your desired architecture and pre-trained weights from our [SOTA model zoo](#computer-vision-models---pretrained-checkpoints)
-    
+Want to try our pre-trained models on your machine? Import SuperGradients, initialize your Trainer, and load your desired architecture and pre-trained weights from our [SOTA model zoo](#computer-vision-models---pretrained-checkpoints)
+
 ```python
 # The pretrained_weights argument will load a pre-trained architecture on the provided dataset
 # This is an example of loading COCO-2017 pre-trained weights for a YOLOX Nano object detection model
 
 import super_gradients
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 
 trainer = SgModel(experiment_name="yoloxn_coco_experiment",ckpt_root_dir=<CHECKPOINT_DIRECTORY>)
 trainer.build_model(architecture="yolox_n", arch_params={"pretrained_weights": "coco", num_classes": 80})
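For reference, the snippet in the hunk above, written against the renamed class (a sketch only: the checkpoint directory is a placeholder, and the quoting follows the obvious intent of the README line):

```python
import super_gradients
from super_gradients.training import Trainer

# Load COCO-2017 pre-trained weights for a YOLOX Nano object detection model.
trainer = Trainer(experiment_name="yoloxn_coco_experiment", ckpt_root_dir="/path/to/checkpoints")
trainer.build_model(architecture="yolox_n",
                    arch_params={"pretrained_weights": "coco", "num_classes": 80})
```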
@@ -68,11 +68,11 @@ In this section we present the modifications needed in order to launch your trai
 
 #### Integrating Your Training Code: Main components:
 
-<span style="text-decoration:underline;">SgModel </span>- the main class in charge of training, testing, logging and basically everything that has to do with the execution of training code.
+<span style="text-decoration:underline;">Trainer </span>- the main class in charge of training, testing, logging and basically everything that has to do with the execution of training code.
 
-<span style="text-decoration:underline;">DatasetInterface</span> - which is passed as an argument to the SgModel and wraps the training set, validation set and optionally a test set for the SgModel instance to work with accordingly.
+<span style="text-decoration:underline;">DatasetInterface</span> - which is passed as an argument to the Trainer and wraps the training set, validation set and optionally a test set for the Trainer instance to work with accordingly.
 
-<span style="text-decoration:underline;">SgModel.net</span> -The network to be used for training/testing (of torch.nn.Module type).
+<span style="text-decoration:underline;">Trainer.net</span> -The network to be used for training/testing (of torch.nn.Module type).
 
 
 #### Integrating Your Training Code - Complete Walkthrough: Dataset
@@ -241,13 +241,13 @@ class Top5(torchmetrics.Accuracy):
 
 #### Integrating Your Training Code- Complete Walkthrough: Training script
 
-We instantiate an SgModel and a UserDatasetInterface, then call connect_dataset_interface which will initialize the dataloaders and pass additional dataset parameters to the SgModel instance.
+We instantiate an Trainer and a UserDatasetInterface, then call connect_dataset_interface which will initialize the dataloaders and pass additional dataset parameters to the Trainer instance.
 
 
 ```
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 
-sg_model = SgModel(experiment_name='LeNet_cifar10_example')
+sg_model = Trainer(experiment_name='LeNet_cifar10_example')
 dataset_params = {"batch_size": 256}
 dataset = UserDataset(dataset_params)
 sg_model.connect_dataset_interface(dataset)
@@ -255,7 +255,7 @@ sg_model.connect_dataset_interface(dataset)
 ```
 
 
-**Now, we pass a LeNet instance we defined above to the SgModel:**
+**Now, we pass a LeNet instance we defined above to the Trainer:**
 
 
 ```
@@ -589,11 +589,11 @@ You can also connect your training logs into Weights and Biases (WandB) assuming
     **2nd option:**
 
-        Set the “launch_tensorboard_process” flag in your training_params passed to SgModel.train(...), and follow instructions displayed in the shell.
+        Set the “launch_tensorboard_process” flag in your training_params passed to Trainer.train(...), and follow instructions displayed in the shell.
 
 * **To resume training –**
-When building the network- call SgModel.build_model(...arch_params={'load_checkpoint'True...}). Doing so, will load the network’s weights, as well as any relevant information for resuming training (monitored metric values, optimizer states, etc) with the latest checkpoint. For more advanced usage see SgModel.build_model docs in code.
+When building the network- call Trainer.build_model(...arch_params={'load_checkpoint'True...}). Doing so, will load the network’s weights, as well as any relevant information for resuming training (monitored metric values, optimizer states, etc) with the latest checkpoint. For more advanced usage see Trainer.build_model docs in code.
 
 
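A concrete illustration of the two options named in the hunk above (a minimal sketch; the flag and argument names are taken from the guide text, and real training_params also need the usual loss, metric and LR settings shown earlier in the walkthrough):

```python
from super_gradients.training import Trainer

trainer = Trainer(experiment_name="resume_example")

# Resume training: load the network's weights plus optimizer state and
# monitored-metric history from the latest checkpoint, as described above.
trainer.build_model("resnet18", arch_params={"load_checkpoint": True})

# Launch a TensorBoard process alongside training and follow the shell instructions.
training_params = {"launch_tensorboard_process": True}  # plus the usual loss/metrics/lr keys
trainer.train(training_params=training_params)
```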
@@ -613,7 +613,7 @@ When building the network- call SgModel.build_model(...arch_params={'load_checkp
 
 ## Dataset Parameters
 
-dataset_params argument passed to SgModel.build_model().
+dataset_params argument passed to Trainer.build_model().
 
 `batch_size`: int (default=64)
 
@@ -637,13 +637,13 @@ The remote s3 link from which to download the data (optional).
 
 ## Network Architectures
 
-The following architectures are implemented in SuperGradients’ code, and can be initialized by passing their name (i.e string) to SgModel.build_model easily.
+The following architectures are implemented in SuperGradients’ code, and can be initialized by passing their name (i.e string) to Trainer.build_model easily.
 
 For example:
 
 ```
-sg_model = SgModel("resnet50_experiment")
+sg_model = Trainer("resnet50_experiment")
 sg_model.build_model(architecture="resnet50")
 ```
 
@@ -894,7 +894,7 @@ Example- how to load a pretrained model:
 
 
 ```
-sg_model = SgModel("resnet50_experiment")
+sg_model = Trainer("resnet50_experiment")
 
 sg_model.build_model(architecture="resnet50",
                      arch_params={"pretrained_weights": "imagenet", "num_classes": 1000}
@@ -1,14 +1,15 @@
 from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, \
-    TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface, SgModel, KDModel
+    TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface, SgModel, KDModel, \
+    Trainer, KDTrainer
 from super_gradients.common import init_trainer, is_distributed
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.sanity_check import env_sanity_check
 
 __all__ = ['ARCHITECTURES', 'losses', 'utils', 'datasets_utils', 'DataAugmentation',
-           'TestDatasetInterface', 'SgModel', 'KDModel', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
+           'TestDatasetInterface', 'Trainer', 'KDTrainer', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
            'ClassificationTestDatasetInterface', 'init_trainer', 'is_distributed', 'train_from_recipe', 'train_from_kd_recipe',
-           'env_sanity_check']
+           'env_sanity_check', 'KDModel', 'SgModel']
 
 
 env_sanity_check()
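Judging from the updated `__all__` above, the new names are added alongside the old ones rather than replacing them, so existing user code keeps importing during the transition. A quick sketch of what that implies:

```python
# Both spellings are importable after this change; new code should prefer the new names.
from super_gradients.training import Trainer, KDTrainer   # new names introduced by this PR
from super_gradients.training import SgModel, KDModel     # legacy names, still exported
```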
@@ -5,7 +5,7 @@ class EvaluationType(str, Enum):
     """
     EvaluationType
 
-    Passed to SgModel.evaluate(..), and controls which phase callbacks should be triggered (if at all).
+    Passed to Trainer.evaluate(..), and controls which phase callbacks should be triggered (if at all).
 
         Attributes:
             TEST
@@ -15,7 +15,7 @@ from super_gradients.common import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.common.environment.env_helpers import multi_process_safe
-from super_gradients.training.utils import sg_model_utils
+from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.params import TrainingParams
 
 logger = get_logger(__name__)
@@ -89,11 +89,11 @@ class BaseSGLogger(AbstractSGLogger):
 
     @multi_process_safe
     def _launch_tensorboard(self, port):
-        self.tensor_board_process, _ = sg_model_utils.launch_tensorboard_process(self._local_dir, port=port)
+        self.tensor_board_process, _ = sg_trainer_utils.launch_tensorboard_process(self._local_dir, port=port)
 
     @multi_process_safe
     def _init_tensorboard(self, resumed, tb_files_user_prompt):
-        self.tensorboard_writer = sg_model_utils.init_summary_writer(self._local_dir, resumed, tb_files_user_prompt)
+        self.tensorboard_writer = sg_trainer_utils.init_summary_writer(self._local_dir, resumed, tb_files_user_prompt)
 
     @multi_process_safe
     def _make_dir(self):
@@ -11,7 +11,7 @@ Cifar10 training with SuperGradients training with the following initialized tor
 Main purpose is to demonstrate training in SG with minimal abstraction and maximal flexibility
 """
 
-from super_gradients import SgModel
+from super_gradients import Trainer
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training import MultiGPUMode
 from torch.optim import ASGD
@@ -48,10 +48,10 @@ loss_fn = CrossEntropyLoss()
 phase_callbacks = [LRSchedulerCallback(scheduler=rop_lr_scheduler, phase=Phase.VALIDATION_EPOCH_END, metric_name="Accuracy"),
                    LRSchedulerCallback(scheduler=step_lr_scheduler, phase=Phase.TRAIN_EPOCH_END)]
 
-# Bring everything together with SgModel and start training
-model = SgModel("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF,
-                train_loader=train_loader, valid_loader=valid_loader, classes=train_dataset.classes)
-model.build_model(net)
+# Bring everything together with Trainer and start training
+trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF,
+                  train_loader=train_loader, valid_loader=valid_loader, classes=train_dataset.classes)
+trainer.build_model(net)
 
 train_params = {"max_epochs": 300,
                 "phase_callbacks": phase_callbacks,
@@ -65,4 +65,4 @@ train_params = {"max_epochs": 300,
                 "greater_metric_to_watch_is_better": True,
                 "lr_scheduler_step_type": "epoch"}
 
-model.train(training_params=train_params)
+trainer.train(training_params=train_params)
@@ -17,13 +17,12 @@ import torch
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface
 
 import super_gradients
-from super_gradients.training import SgModel, MultiGPUMode
+from super_gradients.training import Trainer, MultiGPUMode
 from super_gradients.training.models import HpmStruct
 import argparse
 
 from super_gradients.training.metrics import Accuracy, Top5
 
-
 parser = argparse.ArgumentParser()
 super_gradients.init_trainer()
 
@@ -60,16 +59,16 @@ dataset_params = {"batch_size": args.batch,
                   "auto_augment_config_string": 'rand-m9-mstd0.5'
                   }
 
-model = SgModel(experiment_name=args.experiment_name,
-                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
-                device='cuda')
+trainer = Trainer(experiment_name=args.experiment_name,
+                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
+                  device='cuda')
 
 dataset = ImageNetDatasetInterface(dataset_params=dataset_params)
 
-model.connect_dataset_interface(dataset, data_loader_num_workers=8 * devices)
+trainer.connect_dataset_interface(dataset, data_loader_num_workers=8 * devices)
 
 arch_params = HpmStruct(**{"num_classes": 1000, "aux_head": False, "classification_mode": True, 'dropout_prob': 0.3})
 
-model.build_model(architecture="ddrnet_23_slim" if args.slim else "ddrnet_23",
-                  arch_params=arch_params)
-model.train(training_params=train_params_ddr)
+trainer.build_model(architecture="ddrnet_23_slim" if args.slim else "ddrnet_23",
+                    arch_params=arch_params)
+trainer.train(training_params=train_params_ddr)
@@ -4,7 +4,7 @@ Deci-lab model export example.
 The main purpose of this code is to demonstrate how to upload the model to the platform, optimize and download it
  after training is complete, using DeciPlatformCallback.
 """
-from super_gradients import SgModel, ClassificationTestDatasetInterface
+from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils.callbacks import DeciLabUploadCallback, ModelConversionCheckCallback
 from deci_lab_client.models import (
@@ -23,15 +23,15 @@ def main(architecture_name: str):
 
     auth_token = "YOUR_API_TOKEN_HERE"
 
-    model = SgModel(
+    trainer = Trainer(
         f"lab_optimization_{architecture_name}_example",
         model_checkpoints_location="local",
         ckpt_root_dir=checkpoint_dir,
     )
     dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
-    model.connect_dataset_interface(dataset, data_loader_num_workers=0)
+    trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
 
-    model.build_model(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
+    trainer.build_model(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
 
     # CREATE META-DATA, AND OPTIMIZATION REQUEST FORM FOR DECI PLATFORM POST TRAINING CALLBACK
     model_name = f"{architecture_name}_for_deci_lab_export_example"
@@ -91,7 +91,7 @@ def main(architecture_name: str):
 
     # RUN TRAINING. ONCE ALL EPOCHS ARE DONE THE OPTIMIZED MODEL FILE WILL BE LOCATED IN THE EXPERIMENT'S
     # CHECKPOINT DIRECTORY
-    model.train(train_params)
+    trainer.train(train_params)
 
 
 if __name__ == "__main__":
@@ -1,12 +1,12 @@
 import os
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces import Cifar10DatasetInterface
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 
 os.environ["DECI_PLATFORM_TOKEN"] = "XXX"  # Replace XXX with your token
 
 
-trainer = SgModel(experiment_name='demo-deci-platform-logger')
+trainer = Trainer(experiment_name='demo-deci-platform-logger')
 dataset = Cifar10DatasetInterface(dataset_params={"batch_size": 256, "val_batch_size": 512})
 trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 trainer.build_model("resnet18")
@@ -1,7 +1,7 @@
 # Cifar10 Classification Training:
 # Reaches ~94.9 Accuracy after 250 Epochs
 import super_gradients
-from super_gradients import SgModel
+from super_gradients import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import Cifar10DatasetInterface
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.utils.early_stopping import EarlyStop
@@ -20,12 +20,12 @@ train_params = {"max_epochs": 250, "lr_updates": [100, 150, 200], "lr_decay_fact
                 "greater_metric_to_watch_is_better": True, "phase_callbacks": [early_stop_acc, early_stop_val_loss]}
 
 # Define Model
-model = SgModel("Callback_Example")
+trainer = Trainer("Callback_Example")
 
 # Connect Dataset
 dataset = Cifar10DatasetInterface()
-model.connect_dataset_interface(dataset, data_loader_num_workers=8)
+trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 
 # Build Model
-model.build_model("resnet18_cifar")
-model.train(training_params=train_params)
+trainer.build_model("resnet18_cifar")
+trainer.train(training_params=train_params)
@@ -13,13 +13,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -1,6 +1,6 @@
 # Darknet53 Backbone Training on HAM10000 Dataset
 from super_gradients.training import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationDatasetInterface
 
 # Define Parameters
@@ -11,10 +11,10 @@ arch_params = {'backbone_mode': False, 'num_classes': 7}
 dataset_params = {"batch_size": 16, "test_batch_size": 16, 'dataset_dir': '/data/HAM10000'}
 
 # Define Model
-model = SgModel("Darknet53_Backbone_HAM10000",
-                model_checkpoints_location='local',
-                device='cuda',
-                multi_gpu=MultiGPUMode.DATA_PARALLEL)
+trainer = Trainer("Darknet53_Backbone_HAM10000",
+                  model_checkpoints_location='local',
+                  device='cuda',
+                  multi_gpu=MultiGPUMode.DATA_PARALLEL)
 
 # Connect Dataset
 dataset = ClassificationDatasetInterface(normalization_mean=(0.7483, 0.5154, 0.5353),
@@ -22,10 +22,10 @@ dataset = ClassificationDatasetInterface(normalization_mean=(0.7483, 0.5154, 0.5
                                          resolution=416,
                                          dataset_params=dataset_params)
 
-model.connect_dataset_interface(dataset, data_loader_num_workers=8)
+trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
 
 # Build Model
-model.build_model("darknet53", arch_params=arch_params)
+trainer.build_model("darknet53", arch_params=arch_params)
 
 # Start Training
-model.train(training_params=train_params)
+trainer.train(training_params=train_params)
@@ -17,13 +17,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -16,13 +16,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -15,13 +15,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -10,13 +10,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -28,13 +28,13 @@ def train(cfg: DictConfig) -> None:
     cfg = hydra.utils.instantiate(cfg)
 
     # CONNECT THE DATASET INTERFACE WITH DECI MODEL
-    cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
+    cfg.trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
 
     # BUILD NETWORK
-    cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
+    cfg.trainer.build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
 
     # TRAIN
-    cfg.sg_model.train(training_params=cfg.training_params)
+    cfg.trainer.train(training_params=cfg.training_params)
 
 
 if __name__ == "__main__":
@@ -26,8 +26,8 @@
 """
 import super_gradients
 import torch.distributed
-from super_gradients.training.sg_model import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training.sg_trainer import MultiGPUMode
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces import ImageNetDatasetInterface
 from super_gradients.common.aws_connection.aws_secrets_manager_connector import AWSSecretsManagerConnector
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
@@ -57,12 +57,12 @@ dataset_params = {"batch_size": 128}
 model_repo_bucket_name = AWSSecretsManagerConnector.get_secret_value_for_secret_key(aws_env='research',
                                                                                     secret_name='training_secrets',
                                                                                     secret_key='S3.MODEL_REPOSITORY_BUCKET_NAME')
-model = SgModel("test_checkpoints_resnet_8_gpus",
-                model_checkpoints_location='s3://' + model_repo_bucket_name,
-                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
-                )
+trainer = Trainer("test_checkpoints_resnet_8_gpus",
+                  model_checkpoints_location='s3://' + model_repo_bucket_name,
+                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
+                  )
 # FOR AWS
 dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params=dataset_params)
-model.connect_dataset_interface(dataset, data_loader_num_workers=8)
-model.build_model("resnet50")
-model.train(training_params=train_params)
+trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
+trainer.build_model("resnet50")
+trainer.train(training_params=train_params)
@@ -10,9 +10,9 @@
 # P.S. - Use the relevant training params dict if you are running on TZAG or on V100
 
 import torch
-from super_gradients.training import SgModel, MultiGPUMode
+from super_gradients.training import Trainer, MultiGPUMode
 from super_gradients.training.datasets import CoCoSegmentationDatasetInterface
-from super_gradients.training.sg_model.sg_model import StrictLoad
+from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 
 model_size_str = '34'
@@ -66,16 +66,16 @@ experiment_name_dataset_suffix = '_coco_seg_' + str(
 
 experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix
 
-model = SgModel(experiment_name,
-                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
-                ckpt_name='ckpt_best.pth')
+trainer = Trainer(experiment_name,
+                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
+                  ckpt_name='ckpt_best.pth')
 
 coco_seg_datasaet_interface = CoCoSegmentationDatasetInterface(dataset_params=coco_seg_dataset_tzag_params,
                                                                cache_labels=False,
                                                                dataset_classes_inclusion_tuples_list=coco_sub_classes_inclusion_tuples_list)
 
-model.connect_dataset_interface(coco_seg_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
-model.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params)
+trainer.connect_dataset_interface(coco_seg_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
+trainer.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params)
 
 print('Training ShelfNet-LW model: ' + experiment_name)
-model.train(training_params=shelfnet_coco_training_params)
+trainer.train(training_params=shelfnet_coco_training_params)
@@ -1,5 +1,5 @@
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import SuperviselyPersonsDatasetInterface
-from super_gradients.training.sg_model import SgModel
+from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.metrics import BinaryIOU
 from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
     PadShortToCropSize, ColorJitterSeg
@@ -19,15 +19,15 @@ dataset_params = {
 
 dataset_interface = SuperviselyPersonsDatasetInterface(dataset_params)
 
-model = SgModel("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")
+trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")
 
 # CONNECTING THE DATASET INTERFACE WILL SET SGMODEL'S CLASSES ATTRIBUTE ACCORDING TO SUPERVISELY
-model.connect_dataset_interface(dataset_interface)
+trainer.connect_dataset_interface(dataset_interface)
 
 # THIS IS WHERE THE MAGIC HAPPENS- SINCE SGMODEL'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
 # TO OUR BINARY SEGMENTATION DATASET
-model.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})
+trainer.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})
 
 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
 train_params = {"max_epochs": 50,
@@ -55,4 +55,4 @@ train_params = {"max_epochs": 50,
                                                                             last_img_idx_in_batch=4)],
                 }
 
-model.train(train_params)
+trainer.train(train_params)
@@ -18,7 +18,7 @@ Finally, once training is over- we trigger a pos-training callback that will exp
 """
 
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface
-from super_gradients.training import SgModel, MultiGPUMode
+from super_gradients.training import Trainer, MultiGPUMode
 from super_gradients.training.metrics.classification_metrics import Accuracy
 
 import super_gradients
@@ -27,12 +27,12 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 super_gradients.init_trainer()
 
 dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params={"batch_size": 128})
-model = SgModel("resnet18_qat_example",
-                model_checkpoints_location='local',
-                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
+trainer = Trainer("resnet18_qat_example",
+                  model_checkpoints_location='local',
+                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
 
-model.connect_dataset_interface(dataset)
-model.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
+trainer.connect_dataset_interface(dataset)
+trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
 
 train_params = {"max_epochs": 1,
                 "lr_mode": "step",
@@ -53,9 +53,9 @@ train_params = {"max_epochs": 1,
                     # statistics method for amax computation (one of [percentile, mse, entropy, max]).
                     "calibrate": True,  # whether to perform calibration.
                     "num_calib_batches": 2,  # number of batches to collect the statistics from.
-                    "percentile": 99.99  # percentile value to use when SgModel,
+                    "percentile": 99.99  # percentile value to use when Trainer,
                 },
                 "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
                 }
 
-model.train(training_params=train_params)
+trainer.train(training_params=train_params)
@@ -2,8 +2,8 @@
 import super_gradients
 import torch
 from super_gradients.training.datasets import PascalAUG2012SegmentationDataSetInterface
-from super_gradients.training import SgModel, MultiGPUMode
-from super_gradients.training.sg_model.sg_model import StrictLoad
+from super_gradients.training import Trainer, MultiGPUMode
+from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 
 super_gradients.init_trainer()
@@ -44,13 +44,13 @@ model_size_str = '34'
 experiment_name_prefix = 'shelfnet_lw_'
 experiment_name_dataset_suffix = '_pascal_aug_encoding_dataset_train_250_epochs_no_batchnorm_decoder'
 experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix
-model = SgModel(experiment_name, model_checkpoints_location='local', multi_gpu=True,
-                ckpt_name='resnet' + model_size_str + '.pth')
+trainer = Trainer(experiment_name, model_checkpoints_location='local', multi_gpu=True,
+                  ckpt_name='resnet' + model_size_str + '.pth')
 
 pascal_aug_datasaet_interface = PascalAUG2012SegmentationDataSetInterface(
     dataset_params=pascal_aug_dataset_params,
     cache_labels=False)
-model.connect_dataset_interface(pascal_aug_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
-model.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params, checkpoint_params=checkpoint_params)
+trainer.connect_dataset_interface(pascal_aug_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
+trainer.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params, checkpoint_params=checkpoint_params)
 print('Training ShelfNet-LW model: ' + experiment_name)
-model.train(training_params=shelfnet_lw_pascal_aug_training_params)
+trainer.train(training_params=shelfnet_lw_pascal_aug_training_params)
@@ -14,7 +14,7 @@ from super_gradients.training.kd_trainer import KDTrainer
 
 @hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""))
 def main(cfg: DictConfig) -> None:
-    KDTrainer.train(cfg)
+    KDTrainer.train_from_config(cfg)
 
 
 if __name__ == "__main__":
@@ -9,12 +9,13 @@ import super_gradients
 from omegaconf import DictConfig
 import hydra
 import pkg_resources
-from super_gradients.training.trainer import Trainer
+
+from super_gradients import Trainer
 
 
 @hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""))
 def main(cfg: DictConfig) -> None:
-    Trainer.train(cfg)
+    Trainer.train_from_config(cfg)
 
 
 if __name__ == "__main__":
@@ -1,4 +1,4 @@
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training import MultiGPUMode
 from dataset import UserDataset
 from model import ResNet, BasicBlock
@@ -11,17 +11,17 @@ def main():
     arch_params = {'num_classes': 10}
     model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=arch_params['num_classes'])
 
-    deci_classification_model = SgModel('client_model_training',
-                                        model_checkpoints_location='local',
-                                        multi_gpu=MultiGPUMode.OFF)
+    trainer = Trainer('client_model_training',
+                      model_checkpoints_location='local',
+                      multi_gpu=MultiGPUMode.OFF)
 
     # if a torch.nn.Module is provided when building the model, the model will be integrated into deci model class
-    deci_classification_model.build_model(model, arch_params=arch_params)
+    trainer.build_model(model, arch_params=arch_params)
 
     # ------------------ Loading The Dataset From Dataset.py----------------
     dataset_params = {"batch_size": 256}
     dataset = UserDataset(dataset_params)
-    deci_classification_model.connect_dataset_interface(dataset)
+    trainer.connect_dataset_interface(dataset)
 
     # ------------------ Loading The Loss From Loss.py -----------------
     loss = LabelSmoothingCrossEntropyLoss()
@@ -48,7 +48,7 @@ def main():
                     "metric_to_watch": "Accuracy",
                     "greater_metric_to_watch_is_better": True}
 
-    deci_classification_model.train(train_params)
+    trainer.train(train_params)
 
 
 if __name__ == '__main__':
@@ -3,5 +3,5 @@ load_backbone: False # whether to load only backbone part of checkpoint
 external_checkpoint_path: # checkpoint path that is not located in super_gradients/checkpoints
 source_ckpt_folder_name: # dirname for checkpoint loading
 strict_load: # key matching strictness for loading checkpoint's weights
-  _target_: super_gradients.training.sg_model.StrictLoad
+  _target_: super_gradients.training.sg_trainer.StrictLoad
   value: True
@@ -25,11 +25,5 @@ checkpoint_params:
 model_checkpoints_location: local
 ckpt_root_dir:
 
-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  multi_gpu: Off
-
 architecture: resnet18_cifar
 
@@ -115,10 +115,4 @@ multi_gpu:
   _target_: super_gradients.training.sg_model.MultiGPUMode
   value: 'DDP'
 
-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
 
@@ -126,9 +126,3 @@ checkpoint_params:
 
 model_checkpoints_location: local
 ckpt_root_dir:
-
-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  multi_gpu: DDP
Discard
@@ -64,10 +64,3 @@ multi_gpu:
   _target_: super_gradients.training.sg_model.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
@@ -38,11 +38,6 @@ experiment_suffix: res${dataset_params.train_image_size}
 experiment_name: ${architecture}_coco2017_${experiment_suffix}

 ckpt_root_dir:
-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  multi_gpu: ${multi_gpu}
-  ckpt_root_dir: ${ckpt_root_dir}
+

@@ -32,17 +32,9 @@ checkpoint_params:
 experiment_name: coco_segmentation_21_subclass_shelfnet34

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

 ckpt_root_dir:

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
-
 architecture: shelfnet34_lw
@@ -41,14 +41,7 @@ model_checkpoints_location: local
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: efficientnet_b0
@@ -42,14 +42,7 @@ experiment_name: mobileNetv2_training
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: mobilenet_v2
@@ -24,14 +24,7 @@ experiment_name: mobileNetv3_large_training
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: mobilenet_v3_large
@@ -58,14 +58,8 @@ experiment_name: regnetY800_imagenet
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'Off'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}

 architecture: regnetY800
@@ -36,14 +36,7 @@ checkpoint_params:
 experiment_name: repvgg_a0_imagenet_reproduce_fix

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: repvgg_a0
@@ -49,14 +49,7 @@ experiment_name: resnet50_imagenet
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: 'DDP'

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: resnet50
@@ -26,7 +26,7 @@ training_hyperparams:
 
 
 arch_params:
   teacher_input_adapter:
-    _target_: super_gradients.training.utils.kd_model_utils.NormalizationAdapter
+    _target_: super_gradients.training.utils.kd_trainer_utils.NormalizationAdapter
     mean_original: [0.485, 0.456, 0.406]
     std_original: [0.229, 0.224, 0.225]
     mean_required: [0.5, 0.5, 0.5]
@@ -73,16 +73,9 @@ experiment_name: resnet50_imagenet_KD_Model
 ckpt_root_dir:

 multi_gpu:
-  _target_: super_gradients.training.sg_model.MultiGPUMode
+  _target_: super_gradients.training.sg_trainer.MultiGPUMode
   value: DDP

-sg_model:
-  _target_: super_gradients.KDModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  ckpt_root_dir: ${ckpt_root_dir}
-  multi_gpu: ${multi_gpu}
-
 architecture: kd_module
 student_architecture: resnet50
 teacher_architecture: beit_base_patch16_224
@@ -47,10 +47,4 @@ load_weights_only: True
 
 
 experiment_name: vit_base_imagenet1k

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  multi_gpu: AUTO
-
 architecture: vit_base
@@ -20,11 +20,5 @@ experiment_name: test
 
 
 model_checkpoints_location: local

-sg_model:
-  _target_: super_gradients.SgModel
-  experiment_name: ${experiment_name}
-  model_checkpoints_location: ${model_checkpoints_location}
-  multi_gpu: Off
-
 architecture: resnet18

@@ -1,11 +1,14 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
-from super_gradients.training.datasets import datasets_utils, DataAugmentation, TestDatasetInterface, SegmentationTestDatasetInterface,\
+from super_gradients.training.datasets import datasets_utils, DataAugmentation, TestDatasetInterface, SegmentationTestDatasetInterface, \
     DetectionTestDatasetInterface, ClassificationTestDatasetInterface
 from super_gradients.training.models import ARCHITECTURES
-from super_gradients.training.sg_model import SgModel, MultiGPUMode, StrictLoad
+from super_gradients.training.sg_trainer import Trainer
+from super_gradients.training.kd_trainer import KDTrainer
+from super_gradients.training.sg_model import SgModel
 from super_gradients.training.kd_model import KDModel
+from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType

-__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation', 'TestDatasetInterface', 'ARCHITECTURES', 'SgModel',
-           'KDModel', 'MultiGPUMode', 'TestDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
-           'ClassificationTestDatasetInterface', 'StrictLoad']
+__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation', 'TestDatasetInterface',
+           'ARCHITECTURES', 'Trainer', 'KDTrainer', 'MultiGPUMode', 'TestDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
+           'ClassificationTestDatasetInterface', 'StrictLoad', 'SgModel', 'EvaluationType', 'KDModel']
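A quick sanity check implied by the export list above (hypothetical snippet, not part of this PR): the new Trainer/KDTrainer names and the legacy SgModel/KDModel names are both importable from the package root, so downstream code can migrate incrementally.

# Both the new and the legacy entry points resolve from the package root, per __all__ above.
from super_gradients import Trainer, KDTrainer, MultiGPUMode, StrictLoad, EvaluationType
from super_gradients import SgModel, KDModel  # legacy names, kept for backward compatibility

print(Trainer, KDTrainer, SgModel, KDModel, MultiGPUMode, StrictLoad, EvaluationType)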
@@ -14,7 +14,7 @@ import torch.distributed as dist
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger

 from super_gradients.common.abstractions.abstract_logger import get_logger
-from deprecated import deprecated
+from deprecate import deprecated
 from matplotlib.patches import Rectangle
 from torchvision.datasets import ImageFolder
 from super_gradients.training.datasets.auto_augment import rand_augment_transform
@@ -75,7 +75,7 @@ def get_mean_and_std_torch(data_dir=None, dataloader=None, num_workers=4, Random
     return mean.view(-1).cpu().numpy().tolist(), std.view(-1).cpu().numpy().tolist()


-@deprecated(reason='Use get_mean_and_std_torch() instead. It is faster and more accurate')
+@deprecated(target=get_mean_and_std_torch, deprecated_in="2.1.0", remove_in="3.0.0")
 def get_mean_and_std(dataset):
     '''Compute the mean and std value of dataset.'''
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
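The decorator swap above moves from the `deprecated` package (which only emits a warning) to pyDeprecate's `deprecated`, which additionally names a target and version window. A rough, self-contained sketch of that behaviour with hypothetical functions (not from this PR): calling the old name warns about the deprecation and the call is routed to the target implementation.

# Illustrative sketch of the pyDeprecate-style decorator used above (hypothetical functions).
from deprecate import deprecated


def new_sum(a: int, b: int) -> int:
    return a + b


@deprecated(target=new_sum, deprecated_in="2.1.0", remove_in="3.0.0")
def old_sum(a: int, b: int) -> int:
    ...  # body is never executed; the wrapper forwards to new_sum


print(old_sum(1, 2))  # emits a deprecation warning, returns 3 via new_sum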
@@ -6,7 +6,7 @@ class KDModelException(Exception):
     """
     """
 
 
     def __init__(self, desc):
     def __init__(self, desc):
-        self.message = "KDModel: " + desc
+        self.message = "KDTrainer: " + desc
         super().__init__(self.message)
         super().__init__(self.message)
 
 
 
 
@@ -48,7 +48,7 @@ class InconsistentParamsException(KDModelException):
 
 
 
 
 class UnsupportedKDModelArgException(KDModelException):
 class UnsupportedKDModelArgException(KDModelException):
-    """Exception raised for unsupported args that might be supported for SgModel but not for KDModel.
+    """Exception raised for unsupported args that might be supported for Trainer but not for KDTrainer.
 
 
     Attributes:
     Attributes:
         message -- explanation of the error
         message -- explanation of the error
@@ -66,7 +66,7 @@ class TeacherKnowledgeException(KDModelException):
     """
     """
 
 
     def __init__(self):
     def __init__(self):
-        super().__init__("Expected: at least one of: teacher_pretrained_weights, teacher_checkpoint_path or load_kd_model_checkpoint=True")
+        super().__init__("Expected: at least one of: teacher_pretrained_weights, teacher_checkpoint_path or load_kd_trainer_checkpoint=True")
 
 
 
 
 class UndefinedNumClassesException(KDModelException):
 class UndefinedNumClassesException(KDModelException):
Discard
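For context on the reworded messages, a sketch of the case TeacherKnowledgeException guards against (experiment name, architectures and ckpt_root_dir below are illustrative; the actual check is in _validate_args in kd_trainer.py further down this diff): asking the KD trainer to build a KD module with no source of teacher knowledge fails.

from super_gradients import KDTrainer

kd_trainer = KDTrainer("kd_example", ckpt_root_dir="checkpoints")  # illustrative args
# No teacher_pretrained_weights, no teacher_checkpoint_path, no load_checkpoint, and the
# teacher is only a string, so _validate_args raises TeacherKnowledgeException here.
kd_trainer.build_model(student_architecture='resnet18',
                       teacher_architecture='resnet50',
                       student_arch_params={'num_classes': 10},
                       teacher_arch_params={'num_classes': 10},
                       checkpoint_params={})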
@@ -29,4 +29,4 @@ class IllegalDataloaderInitialization(Exception):
 
 
     def __init__(self):
         super().__init__(
-            "train_loader, valid_loader and class parameters are required when initializing SgModel with data loaders")
+            "train_loader, valid_loader and class parameters are required when initializing Trainer with data loaders")
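The reworded message reflects that a Trainer can be constructed directly from data loaders, provided train_loader, valid_loader and classes are passed together. A minimal sketch (the tensors, batch sizes and ckpt_root_dir are made up for illustration) of the initialization path this exception protects:

import torch
from torch.utils.data import DataLoader, TensorDataset

from super_gradients import Trainer

# Tiny made-up classification dataset, just to illustrate the dataloader-based init.
images, labels = torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=4)
valid_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

# Passing only train_loader (without valid_loader and classes) would raise
# IllegalDataloaderInitialization; passing all three together is accepted.
trainer = Trainer('dataloader_init_example',
                  model_checkpoints_location='local',
                  ckpt_root_dir='checkpoints',  # illustrative local path
                  train_loader=train_loader,
                  valid_loader=valid_loader,
                  classes=list(range(10)))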
@@ -1,275 +1,9 @@
-import torch.nn
+from deprecate import deprecated
 
 
-from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
-from super_gradients.training.models.kd_modules.kd_module import KDModule
-from super_gradients.training.sg_model import SgModel
-from typing import Union
-from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.training import utils as core_utils
-from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import get_param
-from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
-    load_checkpoint_to_model
-from super_gradients.training.exceptions.kd_model_exceptions import ArchitectureKwargsException, \
-    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
-    TeacherKnowledgeException, UndefinedNumClassesException
-from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
-from super_gradients.training.utils.ema import KDModelEMA
-logger = get_logger(__name__)
+from super_gradients.training import KDTrainer
 
 
 
 
-class KDModel(SgModel):
-    def __init__(self, *args, **kwargs):
-        super(KDModel, self).__init__(*args, **kwargs)
-        self.student_architecture = None
-        self.teacher_architecture = None
-        self.student_arch_params = None
-        self.teacher_arch_params = None
-
-    def build_model(self,
-                    # noqa: C901 - too complex
-                    architecture: Union[str, KDModule] = 'kd_module',
-                    arch_params={}, checkpoint_params={},
-                    *args, **kwargs):
-        """
-        :param architecture: (Union[str, KDModule]) Defines the network's architecture from models/KD_ARCHITECTURES
-         (default='kd_module')
-
-        :param arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd
-            architecture class (discarded when architecture is KDModule instance)
-
-        :param checkpoint_params: (dict) A dictionary like object with the following keys/values:
-
-              student_pretrained_weights:   String describing the dataset of the pretrained weights (for example
-              "imagenent") for the student network.
-
-              teacher_pretrained_weights:   String describing the dataset of the pretrained weights (for example
-              "imagenent") for the teacher network.
-
-              teacher_checkpoint_path:    Local path to the teacher's checkpoint. Note that when passing pretrained_weights
-                                   through teacher_arch_params these weights will be overridden by the
-                                   pretrained checkpoint. (default=None)
-
-              load_kd_model_checkpoint:   Whether to load an entire KDModule checkpoint (used to continue KD training)
-               (default=False)
-
-              kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from
-                (self.experiment_name if none is given) to resume KD training (default=None)
-
-              kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
-                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                                               load the checkpoint even if the load_checkpoint flag is not provided.
-                                               (deafult=None)
-
-        :keyword student_architecture: (Union[str, SgModule]) Defines the student's architecture from
-            models/ALL_ARCHITECTURES (when str), or directly defined the student network (when SgModule).
-
-        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher's architecture from
-            models/ALL_ARCHITECTURES (when str), or directly defined the teacher network (when SgModule).
-
-        :keyword student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student
-            net. (deafult={})
-
-        :keyword teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher
-            net. (deafult={})
-
-        :keyword run_teacher_on_eval: (bool)- whether to run self.teacher at eval mode regardless of self.train(mode)
-
-
-        """
-        kwargs.setdefault("student_architecture", None)
-        kwargs.setdefault("teacher_architecture", None)
-        kwargs.setdefault("student_arch_params", {})
-        kwargs.setdefault("teacher_arch_params", {})
-        kwargs.setdefault("run_teacher_on_eval", False)
-
-        self._validate_args(arch_params, architecture, checkpoint_params, **kwargs)
-
-        self.student_architecture = kwargs.get("student_architecture")
-        self.teacher_architecture = kwargs.get("teacher_architecture")
-        self.student_arch_params = kwargs.get("student_arch_params")
-        self.teacher_arch_params = kwargs.get("teacher_arch_params")
-
-        super(KDModel, self).build_model(architecture=architecture, arch_params=arch_params,
-                                         checkpoint_params=checkpoint_params, **kwargs)
-
-    def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
-        student_architecture = get_param(kwargs, "student_architecture")
-        teacher_architecture = get_param(kwargs, "teacher_architecture")
-        student_arch_params = get_param(kwargs, "student_arch_params")
-        teacher_arch_params = get_param(kwargs, "teacher_arch_params")
-
-        if get_param(checkpoint_params, 'pretrained_weights') is not None:
-            raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")
-
-        if not isinstance(architecture, KDModule):
-            if student_architecture is None or teacher_architecture is None:
-                raise ArchitectureKwargsException()
-            if architecture not in KD_ARCHITECTURES.keys():
-                raise UnsupportedKDArchitectureException(architecture)
-
-        # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED OR ARCH PARAMS FOR TEACHER AND STUDENT
-        self._validate_num_classes(student_arch_params, teacher_arch_params)
-
-        arch_params['num_classes'] = student_arch_params['num_classes']
-
-        # MAKE SURE TEACHER'S PRETRAINED NUM CLASSES EQUALS TO THE ONES BELONGING TO STUDENT AS WE CAN'T REPLACE
-        # THE TEACHER'S HEAD
-        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, 'teacher_pretrained_weights',
-                                                          default_val=None)
-        if teacher_pretrained_weights is not None:
-            teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
-            if teacher_pretrained_num_classes != teacher_arch_params['num_classes']:
-                raise InconsistentParamsException("Pretrained dataset number of classes", "teacher's arch params",
-                                                  "number of classes", "student's number of classes")
-
-        teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
-        load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")
-
-        # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
-        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(teacher_architecture, torch.nn.Module)):
-            raise TeacherKnowledgeException()
-
-    def _validate_num_classes(self, student_arch_params, teacher_arch_params):
-        """
-        Checks validity of num_classes for num_classes (i.e existence and consistency between subnets)
-
-        :param student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student
-        :param teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher
-
-        """
-        self._validate_subnet_num_classes(student_arch_params)
-        self._validate_subnet_num_classes(teacher_arch_params)
-        if teacher_arch_params['num_classes'] != student_arch_params['num_classes']:
-            raise InconsistentParamsException("num_classes", "student_arch_params", "num_classes",
-                                              "teacher_arch_params")
-
-    def _validate_subnet_num_classes(self, subnet_arch_params):
-        """
-        Derives num_classes in student_arch_params/teacher_arch_params from dataset interface or raises an error
-         when none is given
-
-        :param subnet_arch_params: Arch params for student/teacher
-
-        """
-
-        if 'num_classes' not in subnet_arch_params.keys():
-            if self.dataset_interface is None:
-                raise UndefinedNumClassesException()
-            else:
-                subnet_arch_params['num_classes'] = len(self.classes)
-
-    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict,
-                         checkpoint_params: dict, *args, **kwargs) -> tuple:
-        """
-        Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
-         and teacher networks, and the required module manipulation (i.e head replacement) for the teacher network.
-
-        :param architecture: String, KDModule or uninstantiated KDModule class describing the netowrks architecture.
-        :param arch_params: Architecture's parameters passed to networks c'tor.
-        :param checkpoint_params: checkpoint loading related parameters dictionary with 'pretrained_weights' key,
-            s.t it's value is a string describing the dataset of the pretrained weights (for example "imagenent").
-
-        :return: instantiated netowrk i.e KDModule, architecture_class (will be none when architecture is not str)
-        """
-
-        student_architecture = get_param(kwargs, "student_architecture")
-        teacher_architecture = get_param(kwargs, "teacher_architecture")
-        student_arch_params = get_param(kwargs, "student_arch_params")
-        teacher_arch_params = get_param(kwargs, "teacher_arch_params")
-        student_arch_params = core_utils.HpmStruct(**student_arch_params)
-        teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)
-        student_pretrained_weights = get_param(checkpoint_params, 'student_pretrained_weights')
-        teacher_pretrained_weights = get_param(checkpoint_params, 'teacher_pretrained_weights')
-
-        student = super()._instantiate_net(student_architecture, student_arch_params,
-                                           {"pretrained_weights": student_pretrained_weights})
-        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params,
-                                           {"pretrained_weights": teacher_pretrained_weights})
-
-        run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)
-
-        if isinstance(architecture, str):
-            architecture_cls = KD_ARCHITECTURES[architecture]
-            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
-                                   run_teacher_on_eval=run_teacher_on_eval)
-        elif isinstance(architecture, KDModule.__class__):
-            net = architecture(arch_params=arch_params, student=student, teacher=teacher,
-                               run_teacher_on_eval=run_teacher_on_eval)
-        else:
-            net = architecture
-
-        return net
-
-    def _load_checkpoint_to_model(self):
-        """
-        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for
-         the entire KD network following the same logic as in SgModel.
-        """
-        teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
-        teacher_net = self.net.module.teacher
-
-        if teacher_checkpoint_path is not None:
-
-            #  WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
-            teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
-            if teacher_pretrained_weights:
-                logger.warning(
-                    teacher_checkpoint_path + " checkpoint is "
-                                              "overriding " + teacher_pretrained_weights + " for teacher model")
-
-            # ALWAYS LOAD ITS EMA IF IT EXISTS
-            load_teachers_ema = 'ema_net' in read_ckpt_state_dict(teacher_checkpoint_path).keys()
-            load_checkpoint_to_model(ckpt_local_path=teacher_checkpoint_path,
-                                     load_backbone=False,
-                                     net=teacher_net,
-                                     strict='no_key_matching',
-                                     load_weights_only=True,
-                                     load_ema_as_net=load_teachers_ema)
-
-        super(KDModel, self)._load_checkpoint_to_model()
-
-    def _add_metrics_update_callback(self, phase):
-        """
-        Adds KDModelMetricsUpdateCallback to be fired at phase
-
-        :param phase: Phase for the metrics callback to be fired at
-        """
-        self.phase_callbacks.append(KDModelMetricsUpdateCallback(phase))
-
-    def _get_hyper_param_config(self):
-        """
-        Creates a training hyper param config for logging with additional KD related hyper params.
-        """
-        hyper_param_config = super()._get_hyper_param_config()
-        hyper_param_config.update({"student_architecture": self.student_architecture,
-                                   "teacher_architecture": self.teacher_architecture,
-                                   "student_arch_params": self.student_arch_params,
-                                   "teacher_arch_params": self.teacher_arch_params
-                                   })
-        return hyper_param_config
-
-    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
-        """Instantiate KD ema model for KDModule.
-
-        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
-        :param decay:           the maximum decay value. as the training process advances, the decay will climb towards
-                                this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
-        :param beta:            the exponent coefficient. The higher the beta, the sooner in the training the decay will
-                                saturate to its final value. beta=15 is ~40% of the training process.
-        :param exp_activation:
-        """
-        return KDModelEMA(self.net, decay, beta, exp_activation)
-
-    def _save_best_checkpoint(self, epoch, state):
-        """
-        Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
-        """
-        if self.ema:
-            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
-            state.pop("ema_net")
-        else:
-            best_net = core_utils.WrappedModel(self.net.module.student)
-
-        state["net"] = best_net.state_dict()
-        self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
+@deprecated(target=KDTrainer, deprecated_in='2.3.0', remove_in='3.0.0')
+class KDModel(KDTrainer):
+    def __init__(self, experiment_name: str, *args, **kwargs):
+        super().__init__(experiment_name, *args, **kwargs)
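With the shim above, existing code that uses KDModel keeps importing but is funneled through the renamed class. A small migration sketch (the experiment name is illustrative) of what callers should expect, with the version window taken from the decorator above:

# Prefer the new name going forward; the legacy alias still imports but is slated for
# removal (deprecated_in='2.3.0', remove_in='3.0.0' per the shim above).
from super_gradients import KDTrainer   # new name
from super_gradients import KDModel     # legacy name, kept only as a deprecation shim

print(KDModel.__name__, "->", KDTrainer.__name__)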
from super_gradients.training.trainer import Trainer


class KDTrainer(Trainer):
    """
    Class for running SuperGradient's recipes for KD Models.
    See train_from_kd_recipe example in the examples directory to demonstrate it's usage.
    """

    @classmethod
    def build_model(cls, cfg):
        cfg.sg_model.build_model(student_architecture=cfg.student_architecture,
                                 teacher_architecture=cfg.teacher_architecture,
                                 arch_params=cfg.arch_params, student_arch_params=cfg.student_arch_params,
                                 teacher_arch_params=cfg.teacher_arch_params,
                                 checkpoint_params=cfg.checkpoint_params, run_teacher_on_eval=cfg.run_teacher_on_eval)
# PACKAGE IMPORTS FOR EXTERNAL USAGE
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer

__all__ = ['KDTrainer']
import hydra
import torch.nn
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from super_gradients.common import MultiGPUMode
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_trainer import Trainer
from typing import Union, List, Any
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import get_param
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
    load_checkpoint_to_model
from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
    TeacherKnowledgeException, UndefinedNumClassesException
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
from super_gradients.training.utils.ema import KDModelEMA
from super_gradients.training.utils.sg_trainer_utils import parse_args

logger = get_logger(__name__)


class KDTrainer(Trainer):
    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None, train_loader: DataLoader = None,
                 valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint, ckpt_name, post_prediction_callback,
                         ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
        self.student_architecture = None
        self.teacher_architecture = None
        self.student_arch_params = None
        self.teacher_arch_params = None

    @classmethod
    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
        """
        Trains according to cfg recipe configuration.

        @param cfg: The parsed DictConfig from yaml recipe files
        @return: output of kd_trainer.train(...) (i.e results tuple)
        """
        # INSTANTIATE ALL OBJECTS IN CFG
        cfg = hydra.utils.instantiate(cfg)
        kwargs = parse_args(cfg, cls.__init__)
        trainer = KDTrainer(**kwargs)

        # CONNECT THE DATASET INTERFACE WITH DECI MODEL
        trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)

        # BUILD NETWORK
        trainer.build_model(student_architecture=cfg.student_architecture,
                            teacher_architecture=cfg.teacher_architecture,
                            arch_params=cfg.arch_params, student_arch_params=cfg.student_arch_params,
                            teacher_arch_params=cfg.teacher_arch_params,
                            checkpoint_params=cfg.checkpoint_params, run_teacher_on_eval=cfg.run_teacher_on_eval)

        # TRAIN
        trainer.train(training_params=cfg.training_hyperparams)

    def build_model(self,
                    # noqa: C901 - too complex
                    architecture: Union[str, KDModule] = 'kd_module',
                    arch_params={}, checkpoint_params={},
                    *args, **kwargs):
        """
        :param architecture: (Union[str, KDModule]) Defines the network's architecture from models/KD_ARCHITECTURES
         (default='kd_module')

        :param arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd
            architecture class (discarded when architecture is KDModule instance)

        :param checkpoint_params: (dict) A dictionary like object with the following keys/values:

              student_pretrained_weights:   String describing the dataset of the pretrained weights (for example
              "imagenent") for the student network.

              teacher_pretrained_weights:   String describing the dataset of the pretrained weights (for example
              "imagenent") for the teacher network.

              teacher_checkpoint_path:    Local path to the teacher's checkpoint. Note that when passing pretrained_weights
                                   through teacher_arch_params these weights will be overridden by the
                                   pretrained checkpoint. (default=None)

              load_kd_model_checkpoint:   Whether to load an entire KDModule checkpoint (used to continue KD training)
               (default=False)

              kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from
                (self.experiment_name if none is given) to resume KD training (default=None)

              kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
                                               load the checkpoint even if the load_checkpoint flag is not provided.
                                               (deafult=None)

        :keyword student_architecture: (Union[str, SgModule]) Defines the student's architecture from
            models/ALL_ARCHITECTURES (when str), or directly defined the student network (when SgModule).

        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher's architecture from
            models/ALL_ARCHITECTURES (when str), or directly defined the teacher network (when SgModule).

        :keyword student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student
            net. (deafult={})

        :keyword teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher
            net. (deafult={})

        :keyword run_teacher_on_eval: (bool)- whether to run self.teacher at eval mode regardless of self.train(mode)
        """
        kwargs.setdefault("student_architecture", None)
        kwargs.setdefault("teacher_architecture", None)
        kwargs.setdefault("student_arch_params", {})
        kwargs.setdefault("teacher_arch_params", {})
        kwargs.setdefault("run_teacher_on_eval", False)

        self._validate_args(arch_params, architecture, checkpoint_params, **kwargs)

        self.student_architecture = kwargs.get("student_architecture")
        self.teacher_architecture = kwargs.get("teacher_architecture")
        self.student_arch_params = kwargs.get("student_arch_params")
        self.teacher_arch_params = kwargs.get("teacher_arch_params")

        super(KDTrainer, self).build_model(architecture=architecture, arch_params=arch_params,
                                           checkpoint_params=checkpoint_params, **kwargs)

    def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        if get_param(checkpoint_params, 'pretrained_weights') is not None:
            raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")

        if not isinstance(architecture, KDModule):
            if student_architecture is None or teacher_architecture is None:
                raise ArchitectureKwargsException()
            if architecture not in KD_ARCHITECTURES.keys():
                raise UnsupportedKDArchitectureException(architecture)

        # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED OR ARCH PARAMS FOR TEACHER AND STUDENT
        self._validate_num_classes(student_arch_params, teacher_arch_params)

        arch_params['num_classes'] = student_arch_params['num_classes']

        # MAKE SURE TEACHER'S PRETRAINED NUM CLASSES EQUALS TO THE ONES BELONGING TO STUDENT AS WE CAN'T REPLACE
        # THE TEACHER'S HEAD
        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, 'teacher_pretrained_weights',
                                                          default_val=None)
        if teacher_pretrained_weights is not None:
            teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
            if teacher_pretrained_num_classes != teacher_arch_params['num_classes']:
                raise InconsistentParamsException("Pretrained dataset number of classes", "teacher's arch params",
                                                  "number of classes", "student's number of classes")

        teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
        load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")

        # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(teacher_architecture, torch.nn.Module)):
            raise TeacherKnowledgeException()

    def _validate_num_classes(self, student_arch_params, teacher_arch_params):
        """
        Checks validity of num_classes for num_classes (i.e existence and consistency between subnets)

        :param student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student
        :param teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher
        """
        self._validate_subnet_num_classes(student_arch_params)
        self._validate_subnet_num_classes(teacher_arch_params)
        if teacher_arch_params['num_classes'] != student_arch_params['num_classes']:
            raise InconsistentParamsException("num_classes", "student_arch_params", "num_classes",
                                              "teacher_arch_params")

    def _validate_subnet_num_classes(self, subnet_arch_params):
        """
        Derives num_classes in student_arch_params/teacher_arch_params from dataset interface or raises an error
         when none is given

        :param subnet_arch_params: Arch params for student/teacher
        """
        if 'num_classes' not in subnet_arch_params.keys():
            if self.dataset_interface is None:
                raise UndefinedNumClassesException()
            else:
                subnet_arch_params['num_classes'] = len(self.classes)

    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict,
                         checkpoint_params: dict, *args, **kwargs) -> tuple:
        """
        Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
         and teacher networks, and the required module manipulation (i.e head replacement) for the teacher network.

        :param architecture: String, KDModule or uninstantiated KDModule class describing the netowrks architecture.
        :param arch_params: Architecture's parameters passed to networks c'tor.
        :param checkpoint_params: checkpoint loading related parameters dictionary with 'pretrained_weights' key,
            s.t it's value is a string describing the dataset of the pretrained weights (for example "imagenent").

        :return: instantiated netowrk i.e KDModule, architecture_class (will be none when architecture is not str)
        """
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")
        student_arch_params = core_utils.HpmStruct(**student_arch_params)
        teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)
        student_pretrained_weights = get_param(checkpoint_params, 'student_pretrained_weights')
        teacher_pretrained_weights = get_param(checkpoint_params, 'teacher_pretrained_weights')

        student = super()._instantiate_net(student_architecture, student_arch_params,
                                           {"pretrained_weights": student_pretrained_weights})
        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params,
                                           {"pretrained_weights": teacher_pretrained_weights})

        run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)

        if isinstance(architecture, str):
            architecture_cls = KD_ARCHITECTURES[architecture]
            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
                                   run_teacher_on_eval=run_teacher_on_eval)
        elif isinstance(architecture, KDModule.__class__):
            net = architecture(arch_params=arch_params, student=student, teacher=teacher,
                               run_teacher_on_eval=run_teacher_on_eval)
        else:
            net = architecture

        return net

    def _load_checkpoint_to_model(self):
        """
        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for
         the entire KD network following the same logic as in Trainer.
        """
        teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
        teacher_net = self.net.module.teacher

        if teacher_checkpoint_path is not None:

            # WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
            teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
            if teacher_pretrained_weights:
                logger.warning(
                    teacher_checkpoint_path + " checkpoint is "
                                              "overriding " + teacher_pretrained_weights + " for teacher model")

            # ALWAYS LOAD ITS EMA IF IT EXISTS
            load_teachers_ema = 'ema_net' in read_ckpt_state_dict(teacher_checkpoint_path).keys()
            load_checkpoint_to_model(ckpt_local_path=teacher_checkpoint_path,
                                     load_backbone=False,
                                     net=teacher_net,
                                     strict='no_key_matching',
                                     load_weights_only=True,
                                     load_ema_as_net=load_teachers_ema)

        super(KDTrainer, self)._load_checkpoint_to_model()

    def _add_metrics_update_callback(self, phase):
        """
        Adds KDModelMetricsUpdateCallback to be fired at phase

        :param phase: Phase for the metrics callback to be fired at
        """
        self.phase_callbacks.append(KDModelMetricsUpdateCallback(phase))

    def _get_hyper_param_config(self):
        """
        Creates a training hyper param config for logging with additional KD related hyper params.
        """
        hyper_param_config = super()._get_hyper_param_config()
        hyper_param_config.update({"student_architecture": self.student_architecture,
                                   "teacher_architecture": self.teacher_architecture,
                                   "student_arch_params": self.student_arch_params,
                                   "teacher_arch_params": self.teacher_arch_params
                                   })
        return hyper_param_config

    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
        """Instantiate KD ema model for KDModule.

        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
        :param decay:           the maximum decay value. as the training process advances, the decay will climb towards
                                this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
        :param beta:            the exponent coefficient. The higher the beta, the sooner in the training the decay will
                                saturate to its final value. beta=15 is ~40% of the training process.
        :param exp_activation:
        """
        return KDModelEMA(self.net, decay, beta, exp_activation)

    def _save_best_checkpoint(self, epoch, state):
        """
        Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
        """
        if self.ema:
            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
            state.pop("ema_net")
        else:
            best_net = core_utils.WrappedModel(self.net.module.student)

        state["net"] = best_net.state_dict()
        self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
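Putting the new class to use: train_from_config above wires a Hydra recipe straight into the KD flow. A hedged sketch of how a recipe launcher might call it (the config path and recipe name below are illustrative, not taken from this PR):

import hydra
from omegaconf import DictConfig

from super_gradients.training.kd_trainer import KDTrainer


@hydra.main(config_path="recipes", config_name="imagenet_resnet50_kd")  # illustrative paths
def main(cfg: DictConfig) -> None:
    # Instantiates the recipe objects, connects the dataset interface, builds the
    # student/teacher pair and runs training, as implemented in train_from_config above.
    KDTrainer.train_from_config(cfg)


if __name__ == "__main__":
    main()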
@@ -480,7 +480,7 @@ class STDCSegmentationBase(SgModule):
     @property
     def backbone(self):
         """
-        For SgModel load_backbone compatibility.
+        For Trainer load_backbone compatibility.
         """
         return self.cp.backbone

@@ -1,5 +1,5 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
+from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
+from super_gradients.training.sg_model.sg_model import SgModel

-from super_gradients.training.sg_model.sg_model import SgModel, MultiGPUMode, StrictLoad
-
-__all__ = ['SgModel', 'MultiGPUMode', 'StrictLoad']
+__all__ = ['SgModel', 'MultiGPUMode', 'StrictLoad', 'EvaluationType']
@@ -1,1826 +1,9 @@
-import inspect
-import os
-import sys
-from copy import deepcopy
-from typing import Union, Tuple, Mapping, List, Any
+from deprecate import deprecated
 
 
-import numpy as np
-import pkg_resources
-import torch
-import torchvision.transforms as transforms
-from deprecated import deprecated
-from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.cuda.amp import GradScaler, autocast
-from torchmetrics import MetricCollection
-from tqdm import tqdm
-from piptools.scripts.sync import _get_installed_distributions
+from super_gradients.training import Trainer
 
 
-from super_gradients.common.factories.callbacks_factory import CallbacksFactory
-from super_gradients.training.models.all_architectures import ARCHITECTURES
-from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.environment import env_helpers
-from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.factories.datasets_factory import DatasetsFactory
-from super_gradients.common.factories.list_factory import ListFactory
-from super_gradients.common.factories.losses_factory import LossesFactory
-from super_gradients.common.factories.metrics_factory import MetricsFactory
-from super_gradients.common.sg_loggers import SG_LOGGERS
-from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
-from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
-from super_gradients.training import utils as core_utils
-from super_gradients.training.models import SgModule
-from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import sg_model_utils
-from super_gradients.training.utils.quantization_utils import QATCallback
-from super_gradients.training.utils.sg_model_utils import MonitoredValue
-from super_gradients.training import metrics
-from super_gradients.training.exceptions.sg_model_exceptions import UnsupportedOptimizerFormat, \
-    IllegalDataloaderInitialization
-from super_gradients.training.datasets import DatasetInterface
-from super_gradients.training.losses import LOSSES
-from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
-    get_logging_values, \
-    get_metrics_dict, get_train_loop_description_dict
-from super_gradients.training.params import TrainingParams
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
-from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
-    reduce_results_tuple_for_ddp, compute_precise_bn_stats
-from super_gradients.training.utils.ema import ModelEMA
-from super_gradients.training.utils.optimizer_utils import build_optimizer
-from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
-from super_gradients.training.metrics import Accuracy, Top5
-from super_gradients.training.utils import random_seed
-from super_gradients.training.utils.checkpoint_utils import get_ckpt_local_path, read_ckpt_state_dict, \
-    load_checkpoint_to_model, load_pretrained_weights
-from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
-from super_gradients.training.utils.callbacks import CallbackHandler, Phase, LR_SCHEDULERS_CLS_DICT, PhaseContext, \
-    MetricsUpdateCallback, LR_WARMUP_CLS_DICT, ContextSgMethods, LRCallbackBase
-from super_gradients.common.environment import environment_config
-from super_gradients.training.utils import HpmStruct
-from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 
 
-from super_gradients.common import StrictLoad, MultiGPUMode, EvaluationType
-logger = get_logger(__name__)
-
-
-class SgModel:
-    """
-    SuperGradient Model - Base Class for Sg Models
-
-    Methods
-    -------
-    train(max_epochs : int, initial_epoch : int, save_model : bool)
-        the main function used for the training, h.p. updating, logging etc.
-
-    predict(idx : int)
-        returns the predictions and label of the current inputs
-
-    test(epoch : int, idx : int, save : bool):
-        returns the test loss, accuracy and runtime
-    """
-
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local',
-                 overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
-                 classes: List[Any] = None):
-        """
-
-        :param experiment_name:                      Used for logging and loading purposes
-        :param device:                          If equal to 'cpu' runs on the CPU otherwise on GPU
-        :param multi_gpu:                       If True, runs on all available devices
-        :param model_checkpoints_location:      If set to 's3' saves the Checkpoints in AWS S3
-                                                otherwise saves the Checkpoints Locally
-        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing
-                                                checkpoint from cloud service, otherwise overwrites the local checkpoints file
-        :param ckpt_name:                       The Checkpoint to Load
-        :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will
-                                                reside. When none is give, it is assumed that
-                                                pkg_resources.resource_filename('checkpoints', "") exists and will be used.
-        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass "valid_loader"
-                                                and "classes" along with it
-        :param valid_loader:                    Validation set Dataloader
-        :param test_loader:                     Test set Dataloader
-        :param classes:                         List of class labels
-
-        """
-        # SET THE EMPTY PROPERTIES
-        self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None
-        self.device, self.multi_gpu = None, None
-        self.ema = None
-        self.ema_model = None
-        self.sg_logger = None
-        self.update_param_groups = None
-        self.post_prediction_callback = None
-        self.criterion = None
-        self.training_params = None
-        self.scaler = None
-        self.phase_callbacks = None
-        self.checkpoint_params = None
-        self.pre_prediction_callback = None
-
-        # SET THE DEFAULT PROPERTIES
-        self.half_precision = False
-        self.load_checkpoint = False
-        self.load_backbone = False
-        self.load_weights_only = False
-        self.ddp_silent_mode = False
-        self.source_ckpt_folder_name = None
-        self.model_weight_averaging = None
-        self.average_model_checkpoint_filename = 'average_model.pth'
-        self.start_epoch = 0
-        self.best_metric = np.inf
-        self.external_checkpoint_path = None
-        self.strict_load = StrictLoad.ON
-        self.load_ema_as_net = False
-        self.ckpt_best_name = 'ckpt_best.pth'
-        self.enable_qat = False
-        self.qat_params = {}
-        self._infinite_train_loader = False
-
-        # DETERMINE THE LOCATION OF THE LOSS AND ACCURACY IN THE RESULTS TUPLE OUTPUTED BY THE TEST
-        self.loss_idx_in_results_tuple, self.acc_idx_in_results_tuple = None, None
-
-        # METRICS
-        self.loss_logging_items_names = None
-        self.train_metrics = None
-        self.valid_metrics = None
-        self.greater_metric_to_watch_is_better = None
-
-        # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
-        self.experiment_name = experiment_name
-        self.ckpt_name = ckpt_name
-        self.overwrite_local_checkpoint = overwrite_local_checkpoint
-        self.model_checkpoints_location = model_checkpoints_location
-        self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
-
-        # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
-        if ckpt_root_dir:
-            self.checkpoints_dir_path = os.path.join(ckpt_root_dir, self.experiment_name)
-        elif pkg_resources.resource_exists("checkpoints", ""):
-            self.checkpoints_dir_path = pkg_resources.resource_filename('checkpoints', self.experiment_name)
-        else:
-            raise ValueError("Illegal checkpoints directory: pass ckpt_root_dir that exists, or add 'checkpoints' to"
-                             "resources.")
-
-        # INITIALIZE THE DEVICE FOR THE MODEL
-        self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
-
-        self.post_prediction_callback = post_prediction_callback
-        # SET THE DEFAULTS
-        # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
-
-        default_results_titles = ['Train Loss', 'Train Acc', 'Train Top5', 'Valid Loss', 'Valid Acc', 'Valid Top5']
-
-        self.results_titles = default_results_titles
-
-        self.loss_idx_in_results_tuple, self.acc_idx_in_results_tuple = 0, 1
-        default_train_metrics, default_valid_metrics = MetricCollection([Accuracy(), Top5()]), MetricCollection(
-            [Accuracy(), Top5()])
-
-        default_loss_logging_items_names = ["Loss"]
-
-        self.train_metrics, self.valid_metrics = default_train_metrics, default_valid_metrics
-        self.loss_logging_items_names = default_loss_logging_items_names
-
-        self.train_monitored_values = {}
-        self.valid_monitored_values = {}
-
-    def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
-        if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
-            raise IllegalDataloaderInitialization()
-
-        dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
-                          "val_batch_size": valid_loader.batch_size if valid_loader else None,
-                          "test_batch_size": test_loader.batch_size if test_loader else None,
-                          "dataset_dir": None,
-                          "s3_link": None}
-
-        if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if not all([isinstance(train_loader.sampler, DistributedSampler),
-                        isinstance(valid_loader.sampler, DistributedSampler),
-                        test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning("DDP training was selected but the dataloader samplers are not of type DistributedSampler")
-
-        self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
-
-    @resolve_param('dataset_interface', DatasetsFactory())
-    def connect_dataset_interface(self, dataset_interface: DatasetInterface, data_loader_num_workers: int = 8):
-        """
-        :param dataset_interface: DatasetInterface object
-        :param data_loader_num_workers: The number of threads to initialize the Data Loaders with
-            The dataset to be connected
-        """
-        if self.train_loader:
-            logger.warning("Overriding the dataloaders that SgModel was initialized with")
-        self.dataset_interface = dataset_interface
-        self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            self.dataset_interface.get_data_loaders(batch_size_factor=self.num_devices,
-                                                    num_workers=data_loader_num_workers,
-                                                    distributed_sampler=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
-
-        self.dataset_params = self.dataset_interface.get_dataset_params()
-
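# Hedged usage sketch, not part of the diff: the constructor route validated by
# _set_dataset_properties(), where train_loader, valid_loader and classes must be passed
# together. Synthetic tensors keep it self-contained; connect_dataset_interface() above is
# the alternative, interface-based route. Paths and sizes are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
from super_gradients.training.sg_trainer.sg_trainer import Trainer  # re-exported at package level by this PR

images, labels = torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16)
valid_loader = DataLoader(TensorDataset(images, labels), batch_size=16)

trainer = Trainer(experiment_name="loaders_demo",
                  ckpt_root_dir="/tmp/checkpoints",   # illustrative directory
                  train_loader=train_loader,
                  valid_loader=valid_loader,
                  classes=list(range(10)))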
-    # FIXME - we need to resolve flake8's 'function is too complex' for this function
-    def build_model(self,  # noqa: C901 - too complex
-                    architecture: Union[str, nn.Module],
-                    arch_params={}, checkpoint_params={}, *args, **kwargs):
-        """
-        :param architecture:               Defines the network's architecture from models/ALL_ARCHITECTURES
-        :param arch_params:                Architecture H.P. e.g.: block, num_blocks, num_classes, etc.
-        :param checkpoint_params:          Dictionary like object with the following key:values:
-
-            load_checkpoint:            Load a pre-trained checkpoint
-            strict_load:                See StrictLoad class documentation for details.
-            source_ckpt_folder_name:    folder name to load the checkpoint from (self.experiment_name if none is given)
-            load_weights_only:          loads only the weight from the checkpoint and zeroize the training params
-            load_backbone:              loads the provided checkpoint to self.net.backbone instead of self.net
-            external_checkpoint_path:   The path to the external checkpoint to be loaded. Can be absolute or relative
-                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                                               load the checkpoint even if the load_checkpoint flag is not provided.
-
-        """
-        if 'num_classes' not in arch_params.keys():
-            if self.classes is None and self.dataset_interface is None:
-                raise Exception('Error', 'Number of classes not defined in arch params and dataset is not defined')
-            else:
-                arch_params['num_classes'] = len(self.classes)
-
-        self.arch_params = core_utils.HpmStruct(**arch_params)
-        self.checkpoint_params = core_utils.HpmStruct(**checkpoint_params)
-
-        self.net = self._instantiate_net(architecture, self.arch_params, checkpoint_params, *args, **kwargs)
-
-        # SAVE THE ARCHITECTURE FOR NEURAL ARCHITECTURE SEARCH
-
-        self.architecture = architecture
-
-        self._net_to_device()
-
-        # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-        self.update_param_groups = hasattr(self.net.module, 'update_param_groups')
-
-        self._load_checkpoint_to_model()
-
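# Hedged sketch, not part of the diff: a build_model() call following the docstring above,
# continuing the trainer constructed in the previous sketch. "resnet18" is assumed to be one
# of the registered keys in models/ALL_ARCHITECTURES; any nn.Module instance may be passed instead.
from super_gradients.common.data_types.enum import StrictLoad

trainer.build_model(architecture="resnet18",
                    arch_params={"num_classes": 10},                  # skips the len(self.classes) fallback
                    checkpoint_params={"load_checkpoint": False,
                                       "strict_load": StrictLoad.ON})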
-    def _set_ckpt_loading_attributes(self):
-        """
-        Sets checkpoint loading related attributes according to self.checkpoint_params
-        """
-        self.checkpoint = {}
-        self.strict_load = core_utils.get_param(self.checkpoint_params, 'strict_load', default_val=StrictLoad.ON)
-        self.load_ema_as_net = core_utils.get_param(self.checkpoint_params, 'load_ema_as_net', default_val=False)
-        self.source_ckpt_folder_name = core_utils.get_param(self.checkpoint_params, 'source_ckpt_folder_name')
-        self.load_checkpoint = core_utils.get_param(self.checkpoint_params, 'load_checkpoint', default_val=False)
-        self.load_backbone = core_utils.get_param(self.checkpoint_params, 'load_backbone', default_val=False)
-        self.external_checkpoint_path = core_utils.get_param(self.checkpoint_params, 'external_checkpoint_path')
-        if self.load_checkpoint or self.external_checkpoint_path:
-            self.load_weights_only = core_utils.get_param(self.checkpoint_params, 'load_weights_only',
-                                                          default_val=False)
-        self.ckpt_name = core_utils.get_param(self.checkpoint_params, 'ckpt_name', default_val=self.ckpt_name)
-
-    def _net_to_device(self):
-        """
-        Manipulates self.net according to self.multi_gpu
-        """
-        self.net.to(self.device)
-
-        # FOR MULTI-GPU TRAINING (not distributed)
-        self.arch_params.sync_bn = core_utils.get_param(self.arch_params, 'sync_bn', default_val=False)
-        if self.multi_gpu == MultiGPUMode.DATA_PARALLEL:
-            self.net = torch.nn.DataParallel(self.net, device_ids=self.device_ids)
-        elif self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if self.arch_params.sync_bn:
-                if not self.ddp_silent_mode:
-                    logger.info('DDP - Using Sync Batch Norm... Training time will be affected accordingly')
-                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net).to(self.device)
-
-            local_rank = int(self.device.split(':')[1])
-            self.net = torch.nn.parallel.DistributedDataParallel(self.net,
-                                                                 device_ids=[local_rank],
-                                                                 output_device=local_rank,
-                                                                 find_unused_parameters=True)
-
-        else:
-            self.net = core_utils.WrappedModel(self.net)
-
-    def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
-        """
-        train_epoch - A single epoch training procedure
-            :param epoch:       The current epoch
-            :param silent_mode: No verbosity
-        """
-        # SET THE MODEL IN training STATE
-        self.net.train()
-        # THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
-        progress_bar_train_loader = tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True,
-                                         disable=silent_mode)
-        progress_bar_train_loader.set_description(f"Train epoch {epoch}")
-
-        # RESET/INIT THE METRIC LOGGERS
-        self._reset_metrics()
-
-        self.train_metrics.to(self.device)
-        loss_avg_meter = core_utils.utils.AverageMeter()
-
-        context = PhaseContext(epoch=epoch,
-                               optimizer=self.optimizer,
-                               metrics_compute_fn=self.train_metrics,
-                               loss_avg_meter=loss_avg_meter,
-                               criterion=self.criterion,
-                               device=self.device,
-                               lr_warmup_epochs=self.training_params.lr_warmup_epochs,
-                               sg_logger=self.sg_logger,
-                               train_loader=self.train_loader,
-                               context_methods=self._get_context_methods(Phase.TRAIN_BATCH_END),
-                               ddp_silent_mode=self.ddp_silent_mode)
-
-        for batch_idx, batch_items in enumerate(progress_bar_train_loader):
-            batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
-            inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
-
-            if self.pre_prediction_callback is not None:
-                inputs, targets = self.pre_prediction_callback(inputs, targets, batch_idx)
-            # AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
-            with autocast(enabled=self.training_params.mixed_precision):
-                # FORWARD PASS TO GET NETWORK'S PREDICTIONS
-                outputs = self.net(inputs)
-
-                # COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS
-                loss, loss_log_items = self._get_losses(outputs, targets)
-
-            context.update_context(batch_idx=batch_idx,
-                                   inputs=inputs,
-                                   preds=outputs,
-                                   target=targets,
-                                   loss_log_items=loss_log_items,
-                                   **additional_batch_items)
-
-            self.phase_callback_handler(Phase.TRAIN_BATCH_END, context)
-
-            # LOG LR THAT WILL BE USED IN CURRENT EPOCH AND AFTER FIRST WARMUP/LR_SCHEDULER UPDATE BEFORE WEIGHT UPDATE
-            if not self.ddp_silent_mode and batch_idx == 0:
-                self._write_lrs(epoch)
-
-            self._backward_step(loss, epoch, batch_idx, context)
-
-            # COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.
-            logging_values = loss_avg_meter.average + get_metrics_results_tuple(self.train_metrics)
-            gpu_memory_utilization = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
-
-            # RENDER METRICS PROGRESS
-            pbar_message_dict = get_train_loop_description_dict(logging_values,
-                                                                self.train_metrics,
-                                                                self.loss_logging_items_names,
-                                                                gpu_mem=gpu_memory_utilization)
-
-            progress_bar_train_loader.set_postfix(**pbar_message_dict)
-
-            # TODO: ITERATE BY MAX ITERS
-            # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
-            if self._infinite_train_loader and batch_idx == len(self.train_loader) - 1:
-                break
-
-        if not self.ddp_silent_mode:
-            self.sg_logger.upload()
-
-        self.train_monitored_values = sg_model_utils.update_monitored_values_dict(
-            monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict)
-
-        return logging_values
-
-    def _get_losses(self, outputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, tuple]:
-        # GET THE OUTPUT OF THE LOSS FUNCTION
-        loss = self.criterion(outputs, targets)
-        if isinstance(loss, tuple):
-            loss, loss_logging_items = loss
-            # IF ITS NOT A TUPLE THE LOGGING ITEMS CONTAIN ONLY THE LOSS FOR BACKPROP (USER DEFINED LOSS RETURNS SCALAR)
-        else:
-            loss_logging_items = loss.unsqueeze(0).detach()
-
-        if len(loss_logging_items) != len(self.loss_logging_items_names):
-            raise ValueError("Loss output length must match loss_logging_items_names. Got " + str(
-                len(loss_logging_items)) + ', and ' + str(len(self.loss_logging_items_names)))
-        # RETURN THE LOSS AND THE LOSS LOGGING ITEMS COMPUTED DURING THE LOSS FORWARD PASS
-        return loss, loss_logging_items
-
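# Hedged sketch, not part of the diff: the (loss, loss_items) contract enforced by
# _get_losses(). This toy loss is not shipped with the library; its three components would be
# mirrored by loss_logging_items_names=["Loss", "CE", "Aux"] in the training params.
import torch
from torch import nn


class WeightedSumLoss(nn.Module):
    def __init__(self, aux_weight: float = 0.4):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.aux_weight = aux_weight

    def forward(self, outputs, targets):
        ce = self.ce(outputs, targets)
        aux = outputs.pow(2).mean()                      # stand-in auxiliary penalty
        total = ce + self.aux_weight * aux
        # One logged value per name in loss_logging_items_names, detached so the
        # running logs do not retain the computational graph.
        loss_items = torch.stack([total, ce, aux]).detach()
        return total, loss_items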
-    def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
-        """
-        Run backprop on the loss and perform a step
-        :param loss: The value computed by the loss function
-        :param epoch: number of epoch the training is on
-        :param batch_idx: number of iteration inside the current epoch
-        :param context: current phase context
-        :return:
-        """
-        # SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
-        self.scaler.scale(loss).backward()
-
-        # APPLY GRADIENT CLIPPING IF REQUIRED
-        if self.training_params.clip_grad_norm:
-            torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)
-
-        # ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
-        integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
-
-        if integrated_batches_num % self.batch_accumulate == 0:
-            # SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-
-            self.optimizer.zero_grad()
-            if self.ema:
-                self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs))
-
-            # RUN PHASE CALLBACKS
-            self.phase_callback_handler(Phase.TRAIN_BATCH_STEP, context)
-
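# Hedged sketch, not part of the diff: which batches trigger an optimizer step under the
# accumulation rule above, assuming a loader of length 10 and batch_accumulate=4.
train_loader_len, batch_accumulate = 10, 4
for epoch in range(2):
    for batch_idx in range(train_loader_len):
        integrated_batches_num = batch_idx + train_loader_len * epoch + 1
        if integrated_batches_num % batch_accumulate == 0:
            print(f"epoch={epoch}, batch_idx={batch_idx}: optimizer step")
# Steps fire on global batches 4, 8, 12, 16 and 20 - every 4th forward pass, across epoch boundaries.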
-    def _save_checkpoint(self, optimizer=None, epoch: int = None, validation_results_tuple: tuple = None,
-                         context: PhaseContext = None):
-        """
-        Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training
-        params)
-        """
-        # WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return
-        if validation_results_tuple is None:
-            self.sg_logger.add_checkpoint(tag='ckpt_latest_weights_only.pth', state_dict={'net': self.net.state_dict()},
-                                          global_step=epoch)
-            return
-
-        # COMPUTE THE CURRENT metric
-        # IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
-        metric = validation_results_tuple[self.metric_idx_in_results_tuple] if isinstance(
-            self.metric_idx_in_results_tuple, int) else \
-            sum([validation_results_tuple[idx] for idx in self.metric_idx_in_results_tuple])
-
-        # BUILD THE state_dict
-        state = {'net': self.net.state_dict(), 'acc': metric, 'epoch': epoch}
-        if optimizer is not None:
-            state['optimizer_state_dict'] = optimizer.state_dict()
-
-        if self.scaler is not None:
-            state['scaler_state_dict'] = self.scaler.state_dict()
-
-        if self.ema:
-            state['ema_net'] = self.ema_model.ema.state_dict()
-        # SAVES CURRENT MODEL AS ckpt_latest
-        self.sg_logger.add_checkpoint(tag='ckpt_latest.pth', state_dict=state, global_step=epoch)
-
-        # SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
-        if epoch in self.training_params.save_ckpt_epoch_list:
-            self.sg_logger.add_checkpoint(tag=f'ckpt_epoch_{epoch}.pth', state_dict=state, global_step=epoch)
-
-        # OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
-        if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
-                metric < self.best_metric and not self.greater_metric_to_watch_is_better):
-            # STORE THE CURRENT metric AS BEST
-            self.best_metric = metric
-            self._save_best_checkpoint(epoch, state)
-
-            # RUN PHASE CALLBACKS
-            self.phase_callback_handler(Phase.VALIDATION_END_BEST_EPOCH, context)
-
-            if isinstance(metric, torch.Tensor):
-                metric = metric.item()
-            logger.info("Best checkpoint overridden: validation " + self.metric_to_watch + ": " + str(metric))
-
-        if self.training_params.average_best_models:
-            net_for_averaging = self.ema_model.ema if self.ema else self.net
-            averaged_model_sd = self.model_weight_averaging.get_average_model(net_for_averaging,
-                                                                              validation_results_tuple=validation_results_tuple)
-            self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename,
-                                          state_dict={'net': averaged_model_sd}, global_step=epoch)
-
-    def _save_best_checkpoint(self, epoch, state):
-        self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
-
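# Hedged sketch, not part of the diff: reading back a checkpoint written by _save_checkpoint().
# The path is illustrative; the keys ('net', 'acc', 'epoch', and optionally 'optimizer_state_dict',
# 'scaler_state_dict', 'ema_net') follow the state dict assembled above.
import torch

ckpt = torch.load("/tmp/checkpoints/loaders_demo/ckpt_best.pth", map_location="cpu")
print(ckpt["epoch"], ckpt["acc"])                  # bookkeeping stored next to the weights
weights = ckpt.get("ema_net", ckpt["net"])         # prefer the EMA weights when present
# rebuilt_model.load_state_dict(weights)           # 'rebuilt_model' is whatever nn.Module you reconstruct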
-    # FIXME - we need to resolve flake8's 'function is too complex' for this function
-    def train(self, training_params: dict = dict()):  # noqa: C901
-        """
-
-        train - Trains the Model
-
-        IMPORTANT NOTE: Additional batch parameters can be added as an optional third item (a dictionary) in the tuple
-          returned by the data loaders. The phase context will hold the additional items under attributes with
-          the same names as the keys in this dictionary, so they can be accessed through phase callbacks.
-
-
-            :param training_params:
-                - `max_epochs` : int
-
-                    Number of epochs to run training.
-
-                - `lr_updates` : list(int)
-
-                    List of fixed epoch numbers to perform learning rate updates when `lr_mode='step'`.
-
-                - `lr_decay_factor` : float
-
-                    Decay factor to apply to the learning rate at each update when `lr_mode='step'`.
-
-
-                -  `lr_mode` : str
-
-                    Learning rate scheduling policy, one of ['step','poly','cosine','function']. 'step' refers to
-                    constant updates at epoch numbers passed through `lr_updates`. 'cosine' refers to the Cosine Annealing
-                    policy as described in https://arxiv.org/abs/1608.03983. 'poly' refers to polynomial decay, i.e. at
-                    each iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)`.
-                    'function' refers to a user-defined learning rate scheduling function passed through
-                    `lr_schedule_function`.
-
-                - `lr_schedule_function` : Union[callable,None]
-
-                    Learning rate scheduling function to be used when `lr_mode` is 'function'.
-
-                - `lr_warmup_epochs` : int (default=0)
-
-                    Number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
-
-                - `cosine_final_lr_ratio` : float (default=0.01)
-                    Final learning rate ratio (only relevant when `lr_mode`='cosine'). The cosine starts from initial_lr and reaches
-                     initial_lr * cosine_final_lr_ratio in last epoch
-
-                - `initial_lr` : float
-
-                    Initial learning rate.
-
-                - `loss` : Union[nn.module, str]
-
-                    Loss function for training.
-                    One of SuperGradient's built in options:
-
-                              "cross_entropy": LabelSmoothingCrossEntropyLoss,
-                              "mse": MSELoss,
-                              "r_squared_loss": RSquaredLoss,
-                              "detection_loss": YoLoV3DetectionLoss,
-                              "shelfnet_ohem_loss": ShelfNetOHEMLoss,
-                              "shelfnet_se_loss": ShelfNetSemanticEncodingLoss,
-                              "ssd_loss": SSDLoss,
-
-
-                    or user defined nn.module loss function.
-
-                    IMPORTANT: forward(...) should return a (loss, loss_items) tuple where loss is the tensor used
-                    for backprop (i.e. what your original loss function returns), and loss_items should be a tensor of
-                    shape (n_items) holding values computed during the forward pass that should be logged over the
-                    entire epoch. For example, the loss itself should always be logged. Another example is a scenario
-                    where the computed loss is the sum of a few components we would like to log - these become entries
-                    in loss_items.
-
-                    When training, set the loss_logging_items_names parameter in train_params to a list of strings of
-                    length n_items whose i-th element is the name of the i-th entry in loss_items. Each item will then
-                    be logged, rendered on tensorboard and "watched" (i.e. used for saving model checkpoints
-                    according to it).
-
-                    Since running logs store loss_items in some internal state, it is recommended that loss_items be
-                    detached from their computational graph for memory efficiency.
-
-                - `optimizer` : Union[str, torch.optim.Optimizer]
-
-                    Optimization algorithm. One of ['Adam','SGD','RMSProp'] corresponding to the torch.optim
-                    optimizers implementations, or any object that implements torch.optim.Optimizer.
-
-                - `criterion_params` : dict
-
-                    Loss function parameters.
-
-                - `optimizer_params` : dict
-                    When `optimizer` is one of ['Adam','SGD','RMSProp'], it will be initialized with optimizer_params.
-
-                    (see https://pytorch.org/docs/stable/optim.html for the full list of
-                    parameters for each optimizer).
-
-                - `train_metrics_list` : list(torchmetrics.Metric)
-
-                    Metrics to log during training. For more information on torchmetrics see
-                    https://torchmetrics.rtfd.io/en/latest/.
-
-
-                - `valid_metrics_list` : list(torchmetrics.Metric)
-
-                    Metrics to log during validation/testing. For more information on torchmetrics see
-                    https://torchmetrics.rtfd.io/en/latest/.
-
-
-                - `loss_logging_items_names` : list(str)
-
-                    The list of names/titles for the outputs returned from the loss functions forward pass (reminder-
-                    the loss function should return the tuple (loss, loss_items)). These names will be used for
-                    logging their values.
-
-                - `metric_to_watch` : str (default="Accuracy")
-
-                    will be the metric which the model checkpoint will be saved according to, and can be set to any
-                    of the following:
-
-                        a metric name (str) of one of the metric objects from the valid_metrics_list
-
-                        a "metric_name" if some metric in valid_metrics_list has an attribute component_names which
-                        is a list referring to the names of each entry in the output metric (torch tensor of size n)
-
-                        one of "loss_logging_items_names" i.e which will correspond to an item returned during the
-                        loss function's forward pass.
-
-                    At the end of each epoch, if a new best metric_to_watch value is achieved, the model's checkpoint
-                    is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth
-
-                - `greater_metric_to_watch_is_better` : bool
-
-                    When choosing a model's checkpoint to be saved, the best achieved model is the one that maximizes the
-                     metric_to_watch when this parameter is set to True, and the one that minimizes it otherwise.
-
-                - `ema` : bool (default=False)
-
-                    Whether to use Model Exponential Moving Average (see
-                    https://github.com/rwightman/pytorch-image-models ema implementation)
-
-                - `batch_accumulate` : int (default=1)
-
-                    Number of batches to accumulate before every backward pass.
-
-                - `ema_params` : dict
-
-                    Parameters for the ema model.
-
-                - `zero_weight_decay_on_bias_and_bn` : bool (default=False)
-
-                    When set to True, bias and batch normalization parameters are excluded from weight decay (ignored
-                    when the passed optimizer has already been initialized).
-
-
-                - `load_opt_params` : bool (default=True)
-
-                    Whether to load the optimizers parameters as well when loading a model's checkpoint.
-
-                - `run_validation_freq` : int (default=1)
-
-                    The frequency at which validation is performed during training (i.e. validation is run every
-                     `run_validation_freq` epochs).
-
-                - `save_model` : bool (default=True)
-
-                    Whether to save the model checkpoints.
-
-                - `silent_mode` : bool
-
-                    Silences the printouts.
-
-                - `mixed_precision` : bool
-
-                    Whether to use mixed precision or not.
-
-                - `save_ckpt_epoch_list` : list(int) (default=[])
-
-                    List of fixed epoch indices the user wishes to save checkpoints in.
-
-                - `average_best_models` : bool (default=False)
-
-                    If set, a snapshot dictionary file and the average model will be saved / updated at every epoch
-                    and evaluated only when training is completed. The snapshot file will only be deleted upon
-                    completing the training. The snapshot dict will be managed on cpu.
-
-                - `precise_bn` : bool (default=False)
-
-                    Whether to use precise_bn calculation during the training.
-
-                - `precise_bn_batch_size` : int (default=None)
-
-                    The effective batch size on which to compute the batchnorm statistics. For example, when training
-                    on 8 gpus with a batch of 128 per gpu, a good rule of thumb is to use 8192
-                    (i.e. effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus = 128 * 8 * 8).
-                    If precise_bn_batch_size is not provided in the training_params, this heuristic is used.
-
-                - `seed` : int (default=42)
-
-                    Random seed to be set for torch, numpy, and random. When using DDP each process will have its seed
-                    set to seed + rank.
-
-
-                - `log_installed_packages` : bool (default=False)
-
-                    When set, the list of all installed packages (and their versions) will be written to the tensorboard
-                     and logfile (useful when trying to reproduce results).
-
-                - `dataset_statistics` : bool (default=False)
-
-                    Enable a statistic analysis of the dataset. If set to True the dataset will be analyzed and a report
-                    will be added to the tensorboard along with some sample images from the dataset. Currently only
-                    detection datasets are supported for analysis.
-
-                -  `save_full_train_log` : bool (default=False)
-
-                    When set, a full log (of all super_gradients modules, including uncaught exceptions from any other
-                     module) of the training will be saved in the checkpoint directory under full_train_log.log
-
-                -  `sg_logger` : Union[AbstractSGLogger, str] (default=base_sg_logger)
-
-                    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging
-                    and remote storage. By overriding the default base_sg_logger, you can change the storage location, support external monitoring and logging
-                    or support remote storage.
-
-                -   `sg_logger_params` : dict
-
-                    SGLogger parameters
-
-                -   `clip_grad_norm` : float
-
-                    Defines a maximal L2 norm of the gradients. Values which exceed the given value will be clipped
-
-                -   `lr_cooldown_epochs` : int (default=0)
-
-                    Number of epochs to cooldown the LR (i.e. the last epoch from the scheduler's viewpoint is max_epochs - cooldown).
-
-                -   `pre_prediction_callback` : Callable (default=None)
-
-                     When not None, this callback is applied to the images and targets, and its return values are used
-                      for the forward pass and further computations. Args for this callable should be in the order
-                      (inputs, targets, batch_idx), returning (modified_inputs, modified_targets).
-
-                -   `ckpt_best_name` : str (default='ckpt_best.pth')
-
-                    The best checkpoint (according to metric_to_watch) will be saved under this filename in the checkpoints directory.
-
-                -   `enable_qat`: bool (default=False)
-
-                    Adds a QATCallback to the phase callbacks, that triggers quantization aware training starting from
-                     qat_params["start_epoch"]
-
-                -   `qat_params`: dict-like object with the following key/values:
-
-                        start_epoch: int, first epoch to start QAT.
-
-                        quant_modules_calib_method: str, One of [percentile, mse, entropy, max]. Statistics method for amax
-                         computation of the quantized modules (default=percentile).
-
-                        per_channel_quant_modules: bool, whether quant modules should be per channel (default=False).
-
-                        calibrate: bool, whether to perform calibration (default=False).
-
-                        calibrated_model_path: str, path to a calibrated checkpoint (default=None).
-
-                        calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset. When None,
-                         context.train_loader will be used (default=None).
-
-                        num_calib_batches: int, number of batches to collect the statistics from.
-
-                        percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
-                         Ignored when other methods are used (default=99.99).
-
-
-        :return:
-        """
-        global logger
-
-        if self.net is None:
-            raise Exception('Model', 'No model found')
-        if self.dataset_interface is None and self.train_loader is None:
-            raise Exception('Data', 'No dataset found')
-
-        self.training_params = TrainingParams()
-        self.training_params.override(**training_params)
-
-        # SET RANDOM SEED
-        random_seed(is_ddp=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
-                    device=self.device, seed=self.training_params.seed)
-
-        silent_mode = self.training_params.silent_mode or self.ddp_silent_mode
-        # METRICS
-        self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
-        self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
-        self.loss_logging_items_names = self.training_params.loss_logging_items_names
-
-        self.results_titles = ["Train_" + t for t in
-                               self.loss_logging_items_names + get_metrics_titles(self.train_metrics)] + \
-                              ["Valid_" + t for t in
-                               self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)]
-
-        # Store the metric to follow (loss/accuracy) and initialize as the worst value
-        self.metric_to_watch = self.training_params.metric_to_watch
-        self.greater_metric_to_watch_is_better = self.training_params.greater_metric_to_watch_is_better
-        self.metric_idx_in_results_tuple = (
-            self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)).index(self.metric_to_watch)
-
-        # Instantiate the values to monitor (loss/metric)
-        for loss in self.loss_logging_items_names:
-            self.train_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
-            self.valid_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
-        self.valid_monitored_values[self.metric_to_watch] = MonitoredValue(name=self.metric_to_watch,
-                                                                           greater_is_better=True)
-
-        # Allowing loading instantiated loss or string
-        if isinstance(self.training_params.loss, str):
-            criterion_cls = LOSSES[self.training_params.loss]
-            self.criterion = criterion_cls(**self.training_params.criterion_params)
-
-        elif isinstance(self.training_params.loss, Mapping):
-            self.criterion = LossesFactory().get(self.training_params.loss)
-
-        elif isinstance(self.training_params.loss, nn.Module):
-            self.criterion = self.training_params.loss
-
-        self.criterion.to(self.device)
-
-        self.max_epochs = self.training_params.max_epochs
-
-        self.ema = self.training_params.ema
-
-        self.precise_bn = self.training_params.precise_bn
-        self.precise_bn_batch_size = self.training_params.precise_bn_batch_size
-
-        self.batch_accumulate = self.training_params.batch_accumulate
-        num_batches = len(self.train_loader)
-
-        if self.ema:
-            ema_params = self.training_params.ema_params
-            logger.info(f'Using EMA with params {ema_params}')
-            self.ema_model = self._instantiate_ema_model(**ema_params)
-            self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
-            if self.load_checkpoint:
-                if 'ema_net' in self.checkpoint.keys():
-                    self.ema_model.ema.load_state_dict(self.checkpoint['ema_net'])
-                else:
-                    self.ema = False
-                    logger.warning(
-                        "[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")
-
-        self.run_validation_freq = self.training_params.run_validation_freq
-        validation_results_tuple = (0, 0)
-        inf_time = 0
-        timer = core_utils.Timer(self.device)
-
-        # IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
-        self.lr_mode = self.training_params.lr_mode
-        load_opt_params = self.training_params.load_opt_params
-
-        self.phase_callbacks = self.training_params.phase_callbacks or []
-        self.phase_callbacks = ListFactory(CallbacksFactory()).get(self.phase_callbacks)
-
-        if self.lr_mode is not None:
-            sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[self.lr_mode]
-            self.phase_callbacks.append(sg_lr_callback_cls(train_loader_len=len(self.train_loader),
-                                                           net=self.net,
-                                                           training_params=self.training_params,
-                                                           update_param_groups=self.update_param_groups,
-                                                           **self.training_params.to_dict()))
-        if self.training_params.lr_warmup_epochs > 0:
-            warmup_mode = self.training_params.warmup_mode
-            if isinstance(warmup_mode, str):
-                warmup_callback_cls = LR_WARMUP_CLS_DICT[warmup_mode]
-            elif isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
-                warmup_callback_cls = warmup_mode
-            else:
-                raise RuntimeError('warmup_mode has to be either a name of a mode (str) or a subclass of LRCallbackBase')
-            self.phase_callbacks.append(warmup_callback_cls(train_loader_len=len(self.train_loader),
-                                                            net=self.net,
-                                                            training_params=self.training_params,
-                                                            update_param_groups=self.update_param_groups,
-                                                            **self.training_params.to_dict()))
-
-        self._add_metrics_update_callback(Phase.TRAIN_BATCH_END)
-        self._add_metrics_update_callback(Phase.VALIDATION_BATCH_END)
-
-        # ADD CALLBACK FOR QAT
-        self.enable_qat = core_utils.get_param(self.training_params, "enable_qat", False)
-        if self.enable_qat:
-            self.qat_params = core_utils.get_param(self.training_params, "qat_params")
-            if self.qat_params is None:
-                raise ValueError("Must pass QAT params when enable_qat=True")
-            self.phase_callbacks.append(QATCallback(**self.qat_params))
-
-        self.phase_callback_handler = CallbackHandler(callbacks=self.phase_callbacks)
-
-        if not self.ddp_silent_mode:
-            self._initialize_sg_logger_objects()
-
-            if self.training_params.dataset_statistics:
-                dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
-                dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes,
-                                                  title="Train-set", anchors=self.net.module.arch_params.anchors)
-                dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes,
-                                                  title="val-set")
-            # AVERAGE BEST 10 MODELS PARAMS
-            if self.training_params.average_best_models:
-                self.model_weight_averaging = ModelWeightAveraging(self.checkpoints_dir_path,
-                                                                   greater_is_better=self.greater_metric_to_watch_is_better,
-                                                                   source_ckpt_folder_name=self.source_ckpt_folder_name,
-                                                                   metric_to_watch=self.metric_to_watch,
-                                                                   metric_idx=self.metric_idx_in_results_tuple,
-                                                                   load_checkpoint=self.load_checkpoint,
-                                                                   model_checkpoints_location=self.model_checkpoints_location)
-        if self.training_params.save_full_train_log and not self.ddp_silent_mode:
-            logger = get_logger(__name__,
-                                training_log_path=self.sg_logger.log_file_path.replace('.txt', 'full_train_log.log'))
-            sg_model_utils.log_uncaught_exceptions(logger)
-
-        if not self.load_checkpoint or self.load_weights_only:
-            # WHEN STARTING TRAINING FROM SCRATCH, DO NOT LOAD OPTIMIZER PARAMS (EVEN IF LOADING BACKBONE)
-            self.start_epoch = 0
-            self._reset_best_metric()
-            load_opt_params = False
-
-        if isinstance(self.training_params.optimizer, str) or \
-                (inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)):
-            self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr,
-                                             training_params=self.training_params)
-        elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
-            self.optimizer = self.training_params.optimizer
-        else:
-            raise UnsupportedOptimizerFormat()
-
-        # VERIFY GRADIENT CLIPPING VALUE
-        if self.training_params.clip_grad_norm is not None and self.training_params.clip_grad_norm <= 0:
-            raise TypeError('Params', 'Invalid clip_grad_norm')
-
-        if self.load_checkpoint and load_opt_params:
-            self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
-
-        self.pre_prediction_callback = CallbacksFactory().get(self.training_params.pre_prediction_callback)
-
-        self._initialize_mixed_precision(self.training_params.mixed_precision)
-
-        self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler, InfiniteSampler)) or\
-                                      (hasattr(self.train_loader, "batch_sampler") and isinstance(self.train_loader.batch_sampler.sampler, InfiniteSampler))
-
-        self.ckpt_best_name = self.training_params.ckpt_best_name
-
-        context = PhaseContext(optimizer=self.optimizer,
-                               net=self.net,
-                               experiment_name=self.experiment_name,
-                               ckpt_dir=self.checkpoints_dir_path,
-                               criterion=self.criterion,
-                               lr_warmup_epochs=self.training_params.lr_warmup_epochs,
-                               sg_logger=self.sg_logger,
-                               train_loader=self.train_loader,
-                               valid_loader=self.valid_loader,
-                               training_params=self.training_params,
-                               ddp_silent_mode=self.ddp_silent_mode,
-                               checkpoint_params=self.checkpoint_params,
-                               architecture=self.architecture,
-                               arch_params=self.arch_params,
-                               metric_idx_in_results_tuple=self.metric_idx_in_results_tuple,
-                               metric_to_watch=self.metric_to_watch,
-                               device=self.device,
-                               context_methods=self._get_context_methods(Phase.PRE_TRAINING)
-                               )
-
-        self.phase_callback_handler(Phase.PRE_TRAINING, context)
-
-        try:
-            # HEADERS OF THE TRAINING PROGRESS
-            if not silent_mode:
-                logger.info(
-                    f'Started training for {self.max_epochs - self.start_epoch} epochs ({self.start_epoch}/'f'{self.max_epochs - 1})\n')
-            for epoch in range(self.start_epoch, self.max_epochs):
-                if context.stop_training:
-                    logger.info("Request to stop training has been received, stopping training")
-                    break
-
-                # Phase.TRAIN_EPOCH_START
-                # RUN PHASE CALLBACKS
-                context.update_context(epoch=epoch)
-                self.phase_callback_handler(Phase.TRAIN_EPOCH_START, context)
-
-                # IN DDP- SET_EPOCH WILL CAUSE EVERY PROCESS TO BE EXPOSED TO THE ENTIRE DATASET BY SHUFFLING WITH A
-                # DIFFERENT SEED EACH EPOCH START
-                if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and hasattr(self.train_loader, "sampler")\
-                        and hasattr(self.train_loader.sampler, "set_epoch"):
-                    self.train_loader.sampler.set_epoch(epoch)
-
-                train_metrics_tuple = self._train_epoch(epoch=epoch, silent_mode=silent_mode)
-
-                # Phase.TRAIN_EPOCH_END
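The diff of sg_model.py is truncated at this point by the review view. Before moving on to the renamed modules, here is a hedged end-to-end sketch of the train() call documented above; the parameter names follow the docstring, while the concrete values and the earlier-sketched trainer are placeholders.

from super_gradients.training.metrics import Accuracy, Top5

training_params = {
    "max_epochs": 20,
    "initial_lr": 0.1,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.01,
    "loss": "cross_entropy",
    "optimizer": "SGD",
    "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4},
    "train_metrics_list": [Accuracy(), Top5()],
    "valid_metrics_list": [Accuracy(), Top5()],
    "loss_logging_items_names": ["Loss"],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}
trainer.train(training_params=training_params)   # trainer: as constructed in the earlier sketches

Any parameter omitted from the dict falls back to the defaults in TrainingParams, which train() overrides with the passed values.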
  1. # PACKAGE IMPORTS FOR EXTERNAL USAGE
  2. from super_gradients.training.sg_trainer.sg_trainer import Trainer, MultiGPUMode, StrictLoad
  3. __all__ = ['Trainer', 'MultiGPUMode', 'StrictLoad']
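This three-line re-export is the user-facing side of the rename: code that previously imported SgModel now imports Trainer. A minimal before/after sketch (the legacy import is assumed from the pre-rename package layout, and the top-level import assumes the re-export above reaches the super_gradients package root):

# Before this PR (assumed legacy import):
# from super_gradients import SgModel
# sg_model = SgModel(experiment_name="my_experiment")

# After this PR:
from super_gradients import Trainer
trainer = Trainer(experiment_name="my_experiment")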
  1. import inspect
  2. import os
  3. import sys
  4. from copy import deepcopy
  5. from typing import Union, Tuple, Mapping, List, Any
  6. import hydra
  7. import numpy as np
  8. import pkg_resources
  9. import torch
  10. from omegaconf import DictConfig
  11. from torch import nn
  12. from torch.utils.data import DataLoader, DistributedSampler
  13. from torch.cuda.amp import GradScaler, autocast
  14. from torchmetrics import MetricCollection
  15. from tqdm import tqdm
  16. from piptools.scripts.sync import _get_installed_distributions
  17. from super_gradients.common.factories.callbacks_factory import CallbacksFactory
  18. from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
  19. from super_gradients.training.models.all_architectures import ARCHITECTURES
  20. from super_gradients.common.decorators.factory_decorator import resolve_param
  21. from super_gradients.common.environment import env_helpers
  22. from super_gradients.common.abstractions.abstract_logger import get_logger
  23. from super_gradients.common.factories.datasets_factory import DatasetsFactory
  24. from super_gradients.common.factories.list_factory import ListFactory
  25. from super_gradients.common.factories.losses_factory import LossesFactory
  26. from super_gradients.common.factories.metrics_factory import MetricsFactory
  27. from super_gradients.common.sg_loggers import SG_LOGGERS
  28. from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
  29. from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
  30. from super_gradients.training import utils as core_utils
  31. from super_gradients.training.models import SgModule
  32. from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
  33. from super_gradients.training.utils import sg_trainer_utils
  34. from super_gradients.training.utils.quantization_utils import QATCallback
  35. from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
  36. from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
  37. IllegalDataloaderInitialization
  38. from super_gradients.training.datasets import DatasetInterface
  39. from super_gradients.training.losses import LOSSES
  40. from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
  41. get_logging_values, \
  42. get_metrics_dict, get_train_loop_description_dict
  43. from super_gradients.training.params import TrainingParams
  44. from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
  45. from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
  46. reduce_results_tuple_for_ddp, compute_precise_bn_stats
  47. from super_gradients.training.utils.ema import ModelEMA
  48. from super_gradients.training.utils.optimizer_utils import build_optimizer
  49. from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
  50. from super_gradients.training.metrics import Accuracy, Top5
  51. from super_gradients.training.utils import random_seed
  52. from super_gradients.training.utils.checkpoint_utils import get_ckpt_local_path, read_ckpt_state_dict, \
  53. load_checkpoint_to_model, load_pretrained_weights
  54. from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
  55. from super_gradients.training.utils.callbacks import CallbackHandler, Phase, LR_SCHEDULERS_CLS_DICT, PhaseContext, \
  56. MetricsUpdateCallback, LR_WARMUP_CLS_DICT, ContextSgMethods, LRCallbackBase
  57. from super_gradients.common.environment import environment_config
  58. from super_gradients.training.utils import HpmStruct
  59. from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
  60. logger = get_logger(__name__)
  61. class Trainer:
  62. """
  63. SuperGradient Model - Base Class for Sg Models
  64. Methods
  65. -------
  66. train(max_epochs : int, initial_epoch : int, save_model : bool)
  67. the main function used for the training, h.p. updating, logging etc.
  68. predict(idx : int)
  69. returns the predictions and label of the current inputs
  70. test(epoch : int, idx : int, save : bool):
  71. returns the test loss, accuracy and runtime
  72. """
  73. def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
  74. model_checkpoints_location: str = 'local',
  75. overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
  76. post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
  77. train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
  78. classes: List[Any] = None):
  79. """
  80. :param experiment_name: Used for logging and loading purposes
81. :param device: If equal to 'cpu', runs on the CPU; otherwise runs on the GPU
82. :param multi_gpu: A MultiGPUMode value (or its string name) determining whether and how to run on multiple GPUs
83. :param model_checkpoints_location: If set to 's3', saves the checkpoints in AWS S3;
84. otherwise saves the checkpoints locally
85. :param overwrite_local_checkpoint: If set to False, keeps the current local checkpoint when importing a
86. checkpoint from a cloud service; otherwise overwrites the local checkpoint file
87. :param ckpt_name: The checkpoint filename to load
88. :param ckpt_root_dir: Local root directory path where all experiment logging directories will
89. reside. When none is given, it is assumed that
90. pkg_resources.resource_filename('checkpoints', "") exists and will be used.
91. :param train_loader: Training set Dataloader. When passed instead of using a DatasetInterface,
92. "valid_loader" and "classes" must be passed along with it
  93. :param valid_loader: Validation set Dataloader
  94. :param test_loader: Test set Dataloader
  95. :param classes: List of class labels
  96. """
  97. # SET THE EMPTY PROPERTIES
  98. self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None
  99. self.device, self.multi_gpu = None, None
  100. self.ema = None
  101. self.ema_model = None
  102. self.sg_logger = None
  103. self.update_param_groups = None
  104. self.post_prediction_callback = None
  105. self.criterion = None
  106. self.training_params = None
  107. self.scaler = None
  108. self.phase_callbacks = None
  109. self.checkpoint_params = None
  110. self.pre_prediction_callback = None
  111. # SET THE DEFAULT PROPERTIES
  112. self.half_precision = False
  113. self.load_checkpoint = False
  114. self.load_backbone = False
  115. self.load_weights_only = False
  116. self.ddp_silent_mode = False
  117. self.source_ckpt_folder_name = None
  118. self.model_weight_averaging = None
  119. self.average_model_checkpoint_filename = 'average_model.pth'
  120. self.start_epoch = 0
  121. self.best_metric = np.inf
  122. self.external_checkpoint_path = None
  123. self.strict_load = StrictLoad.ON
  124. self.load_ema_as_net = False
  125. self.ckpt_best_name = 'ckpt_best.pth'
  126. self.enable_qat = False
  127. self.qat_params = {}
  128. self._infinite_train_loader = False
129. # DETERMINE THE LOCATION OF THE LOSS AND ACCURACY IN THE RESULTS TUPLE OUTPUT BY THE TEST
  130. self.loss_idx_in_results_tuple, self.acc_idx_in_results_tuple = None, None
  131. # METRICS
  132. self.loss_logging_items_names = None
  133. self.train_metrics = None
  134. self.valid_metrics = None
  135. self.greater_metric_to_watch_is_better = None
  136. # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
  137. self.experiment_name = experiment_name
  138. self.ckpt_name = ckpt_name
  139. self.overwrite_local_checkpoint = overwrite_local_checkpoint
  140. self.model_checkpoints_location = model_checkpoints_location
  141. self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
  142. # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
  143. if ckpt_root_dir:
  144. self.checkpoints_dir_path = os.path.join(ckpt_root_dir, self.experiment_name)
  145. elif pkg_resources.resource_exists("checkpoints", ""):
  146. self.checkpoints_dir_path = pkg_resources.resource_filename('checkpoints', self.experiment_name)
  147. else:
  148. raise ValueError("Illegal checkpoints directory: pass ckpt_root_dir that exists, or add 'checkpoints' to"
  149. "resources.")
  150. # INITIALIZE THE DEVICE FOR THE MODEL
  151. self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
  152. self.post_prediction_callback = post_prediction_callback
  153. # SET THE DEFAULTS
  154. # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
  155. default_results_titles = ['Train Loss', 'Train Acc', 'Train Top5', 'Valid Loss', 'Valid Acc', 'Valid Top5']
  156. self.results_titles = default_results_titles
  157. self.loss_idx_in_results_tuple, self.acc_idx_in_results_tuple = 0, 1
  158. default_train_metrics, default_valid_metrics = MetricCollection([Accuracy(), Top5()]), MetricCollection(
  159. [Accuracy(), Top5()])
  160. default_loss_logging_items_names = ["Loss"]
  161. self.train_metrics, self.valid_metrics = default_train_metrics, default_valid_metrics
  162. self.loss_logging_items_names = default_loss_logging_items_names
  163. self.train_monitored_values = {}
  164. self.valid_monitored_values = {}
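# --- Usage sketch (illustrative; not part of this file) ---
# Constructing a Trainer directly from dataloaders instead of a DatasetInterface, as documented in
# __init__ above. This would live in a user script; the CIFAR-10 datasets, the batch sizes, and the
# package-level `Trainer` import are assumptions for the example.
#
#     from torch.utils.data import DataLoader
#     from torchvision import datasets, transforms
#     from super_gradients import Trainer
#
#     transform = transforms.Compose([transforms.ToTensor()])
#     train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
#     valid_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
#     trainer = Trainer(experiment_name="cifar10_example",
#                       ckpt_root_dir="./checkpoints",
#                       train_loader=DataLoader(train_set, batch_size=256, shuffle=True),
#                       valid_loader=DataLoader(valid_set, batch_size=512),
#                       classes=train_set.classes)  # train_loader, valid_loader and classes go together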
  165. @classmethod
  166. def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
  167. """
  168. Trains according to cfg recipe configuration.
  169. @param cfg: The parsed DictConfig from yaml recipe files or a dictionary
170. @return: None (the method runs trainer.train(...) internally)
  171. """
  172. # INSTANTIATE ALL OBJECTS IN CFG
  173. cfg = hydra.utils.instantiate(cfg)
  174. kwargs = parse_args(cfg, cls.__init__)
  175. trainer = Trainer(**kwargs)
176. # CONNECT THE DATASET INTERFACE WITH THE TRAINER
  177. trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
  178. # BUILD NETWORK
  179. trainer.build_model(cfg.architecture, arch_params=cfg.arch_params, checkpoint_params=cfg.checkpoint_params)
  180. # TRAIN
  181. trainer.train(training_params=cfg.training_hyperparams)
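# --- Usage sketch (illustrative; not part of this file) ---
# Driving the recipe flow above from a hydra-decorated script, mirroring the train_from_recipe
# examples in this PR. The config_path/config_name values are placeholders for an actual recipe.
#
#     import hydra
#     from omegaconf import DictConfig
#     from super_gradients import Trainer
#
#     @hydra.main(config_path="recipes", config_name="cifar10_resnet")
#     def main(cfg: DictConfig) -> None:
#         Trainer.train_from_config(cfg)
#
#     if __name__ == "__main__":
#         main()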
  182. def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
  183. if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
  184. raise IllegalDataloaderInitialization()
  185. dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
  186. "val_batch_size": valid_loader.batch_size if valid_loader else None,
  187. "test_batch_size": test_loader.batch_size if test_loader else None,
  188. "dataset_dir": None,
  189. "s3_link": None}
  190. if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  191. if not all([isinstance(train_loader.sampler, DistributedSampler),
  192. isinstance(valid_loader.sampler, DistributedSampler),
  193. test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
  194. logger.warning("DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
  195. self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
  196. HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
  197. @resolve_param('dataset_interface', DatasetsFactory())
  198. def connect_dataset_interface(self, dataset_interface: DatasetInterface, data_loader_num_workers: int = 8):
  199. """
200. :param dataset_interface: DatasetInterface object holding the dataset to be connected
201. :param data_loader_num_workers: The number of worker threads to initialize the Data Loaders with
  203. """
  204. if self.train_loader:
  205. logger.warning("Overriding the dataloaders that Trainer was initialized with")
  206. self.dataset_interface = dataset_interface
  207. self.train_loader, self.valid_loader, self.test_loader, self.classes = \
  208. self.dataset_interface.get_data_loaders(batch_size_factor=self.num_devices,
  209. num_workers=data_loader_num_workers,
  210. distributed_sampler=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
  211. self.dataset_params = self.dataset_interface.get_dataset_params()
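# --- Usage sketch (illustrative; not part of this file) ---
# Wiring a dataset interface into an existing Trainer instance. "MyClassificationDatasetInterface"
# is a hypothetical DatasetInterface subclass used only for illustration; any DatasetInterface that
# implements get_data_loaders() and get_dataset_params() would work.
#
#     dataset_interface = MyClassificationDatasetInterface(dataset_params={"batch_size": 64})
#     trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=8)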
  212. # FIXME - we need to resolve flake8's 'function is too complex' for this function
  213. def build_model(self, # noqa: C901 - too complex
  214. architecture: Union[str, nn.Module],
  215. arch_params={}, checkpoint_params={}, *args, **kwargs):
  216. """
  217. :param architecture: Defines the network's architecture from models/ALL_ARCHITECTURES
  218. :param arch_params: Architecture H.P. e.g.: block, num_blocks, num_classes, etc.
219. :param checkpoint_params: Dictionary-like object with the following key/values:
  220. load_checkpoint: Load a pre-trained checkpoint
  221. strict_load: See StrictLoad class documentation for details.
  222. source_ckpt_folder_name: folder name to load the checkpoint from (self.experiment_name if none is given)
223. load_weights_only: loads only the weights from the checkpoint and zeroizes the training params
  224. load_backbone: loads the provided checkpoint to self.net.backbone instead of self.net
  225. external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
  226. (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
  227. load the checkpoint even if the load_checkpoint flag is not provided.
  228. """
  229. if 'num_classes' not in arch_params.keys():
  230. if self.classes is None and self.dataset_interface is None:
  231. raise Exception('Error', 'Number of classes not defined in arch params and dataset is not defined')
  232. else:
  233. arch_params['num_classes'] = len(self.classes)
  234. self.arch_params = core_utils.HpmStruct(**arch_params)
  235. self.checkpoint_params = core_utils.HpmStruct(**checkpoint_params)
  236. self.net = self._instantiate_net(architecture, self.arch_params, checkpoint_params, *args, **kwargs)
  237. # SAVE THE ARCHITECTURE FOR NEURAL ARCHITECTURE SEARCH
  238. self.architecture = architecture
  239. self._net_to_device()
  240. # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
  241. self.update_param_groups = hasattr(self.net.module, 'update_param_groups')
  242. self._load_checkpoint_to_model()
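# --- Usage sketch (illustrative; not part of this file) ---
# Building a network by name and loading an external checkpoint. The architecture string, the
# checkpoint path and the parameter values are placeholders; the keys mirror the checkpoint_params
# documented in build_model above.
#
#     trainer.build_model("resnet18",
#                         arch_params={"num_classes": 10},
#                         checkpoint_params={"load_checkpoint": True,
#                                            "external_checkpoint_path": "/path/to/ckpt_latest.pth",
#                                            "load_weights_only": True})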
  243. def _set_ckpt_loading_attributes(self):
  244. """
  245. Sets checkpoint loading related attributes according to self.checkpoint_params
  246. """
  247. self.checkpoint = {}
  248. self.strict_load = core_utils.get_param(self.checkpoint_params, 'strict_load', default_val=StrictLoad.ON)
  249. self.load_ema_as_net = core_utils.get_param(self.checkpoint_params, 'load_ema_as_net', default_val=False)
  250. self.source_ckpt_folder_name = core_utils.get_param(self.checkpoint_params, 'source_ckpt_folder_name')
  251. self.load_checkpoint = core_utils.get_param(self.checkpoint_params, 'load_checkpoint', default_val=False)
  252. self.load_backbone = core_utils.get_param(self.checkpoint_params, 'load_backbone', default_val=False)
  253. self.external_checkpoint_path = core_utils.get_param(self.checkpoint_params, 'external_checkpoint_path')
  254. if self.load_checkpoint or self.external_checkpoint_path:
  255. self.load_weights_only = core_utils.get_param(self.checkpoint_params, 'load_weights_only',
  256. default_val=False)
  257. self.ckpt_name = core_utils.get_param(self.checkpoint_params, 'ckpt_name', default_val=self.ckpt_name)
  258. def _net_to_device(self):
  259. """
  260. Manipulates self.net according to self.multi_gpu
  261. """
  262. self.net.to(self.device)
  263. # FOR MULTI-GPU TRAINING (not distributed)
  264. self.arch_params.sync_bn = core_utils.get_param(self.arch_params, 'sync_bn', default_val=False)
  265. if self.multi_gpu == MultiGPUMode.DATA_PARALLEL:
  266. self.net = torch.nn.DataParallel(self.net, device_ids=self.device_ids)
  267. elif self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  268. if self.arch_params.sync_bn:
  269. if not self.ddp_silent_mode:
  270. logger.info('DDP - Using Sync Batch Norm... Training time will be affected accordingly')
  271. self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net).to(self.device)
  272. local_rank = int(self.device.split(':')[1])
  273. self.net = torch.nn.parallel.DistributedDataParallel(self.net,
  274. device_ids=[local_rank],
  275. output_device=local_rank,
  276. find_unused_parameters=True)
  277. else:
  278. self.net = core_utils.WrappedModel(self.net)
  279. def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
  280. """
  281. train_epoch - A single epoch training procedure
  283. :param epoch: The current epoch
  284. :param silent_mode: No verbosity
  285. """
  286. # SET THE MODEL IN training STATE
  287. self.net.train()
  288. # THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
  289. progress_bar_train_loader = tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True,
  290. disable=silent_mode)
  291. progress_bar_train_loader.set_description(f"Train epoch {epoch}")
  292. # RESET/INIT THE METRIC LOGGERS
  293. self._reset_metrics()
  294. self.train_metrics.to(self.device)
  295. loss_avg_meter = core_utils.utils.AverageMeter()
  296. context = PhaseContext(epoch=epoch,
  297. optimizer=self.optimizer,
  298. metrics_compute_fn=self.train_metrics,
  299. loss_avg_meter=loss_avg_meter,
  300. criterion=self.criterion,
  301. device=self.device,
  302. lr_warmup_epochs=self.training_params.lr_warmup_epochs,
  303. sg_logger=self.sg_logger,
  304. train_loader=self.train_loader,
  305. context_methods=self._get_context_methods(Phase.TRAIN_BATCH_END),
  306. ddp_silent_mode=self.ddp_silent_mode)
  307. for batch_idx, batch_items in enumerate(progress_bar_train_loader):
  308. batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
  309. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
  310. if self.pre_prediction_callback is not None:
  311. inputs, targets = self.pre_prediction_callback(inputs, targets, batch_idx)
  312. # AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
  313. with autocast(enabled=self.training_params.mixed_precision):
  314. # FORWARD PASS TO GET NETWORK'S PREDICTIONS
  315. outputs = self.net(inputs)
  316. # COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS
  317. loss, loss_log_items = self._get_losses(outputs, targets)
  318. context.update_context(batch_idx=batch_idx,
  319. inputs=inputs,
  320. preds=outputs,
  321. target=targets,
  322. loss_log_items=loss_log_items,
  323. **additional_batch_items)
  324. self.phase_callback_handler(Phase.TRAIN_BATCH_END, context)
  325. # LOG LR THAT WILL BE USED IN CURRENT EPOCH AND AFTER FIRST WARMUP/LR_SCHEDULER UPDATE BEFORE WEIGHT UPDATE
  326. if not self.ddp_silent_mode and batch_idx == 0:
  327. self._write_lrs(epoch)
  328. self._backward_step(loss, epoch, batch_idx, context)
  329. # COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.
  330. logging_values = loss_avg_meter.average + get_metrics_results_tuple(self.train_metrics)
  331. gpu_memory_utilization = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
  332. # RENDER METRICS PROGRESS
  333. pbar_message_dict = get_train_loop_description_dict(logging_values,
  334. self.train_metrics,
  335. self.loss_logging_items_names,
  336. gpu_mem=gpu_memory_utilization)
  337. progress_bar_train_loader.set_postfix(**pbar_message_dict)
  338. # TODO: ITERATE BY MAX ITERS
  339. # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
  340. if self._infinite_train_loader and batch_idx == len(self.train_loader) - 1:
  341. break
  342. if not self.ddp_silent_mode:
  343. self.sg_logger.upload()
  344. self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
  345. monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict)
  346. return logging_values
  347. def _get_losses(self, outputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, tuple]:
  348. # GET THE OUTPUT OF THE LOSS FUNCTION
  349. loss = self.criterion(outputs, targets)
  350. if isinstance(loss, tuple):
  351. loss, loss_logging_items = loss
352. # IF IT'S NOT A TUPLE, THE LOGGING ITEMS CONTAIN ONLY THE LOSS FOR BACKPROP (USER-DEFINED LOSS RETURNS A SCALAR)
  353. else:
  354. loss_logging_items = loss.unsqueeze(0).detach()
  355. if len(loss_logging_items) != len(self.loss_logging_items_names):
  356. raise ValueError("Loss output length must match loss_logging_items_names. Got " + str(
  357. len(loss_logging_items)) + ', and ' + str(len(self.loss_logging_items_names)))
358. # RETURN THE LOSS AND THE LOSS LOGGING ITEMS COMPUTED DURING THE LOSS FORWARD PASS
  359. return loss, loss_logging_items
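# --- Usage sketch (illustrative; not part of this file) ---
# A user-defined loss following the (loss, loss_items) convention consumed by _get_losses above.
# The component names and weights are placeholders; they must match
# training_params["loss_logging_items_names"], e.g. ["CE", "L1", "Loss"].
#
#     import torch
#     import torch.nn as nn
#
#     class CompositeLoss(nn.Module):
#         def __init__(self, l1_weight: float = 0.1):
#             super().__init__()
#             self.ce = nn.CrossEntropyLoss()
#             self.l1_weight = l1_weight
#
#         def forward(self, outputs, targets):
#             ce = self.ce(outputs, targets)
#             l1 = self.l1_weight * outputs.abs().mean()
#             loss = ce + l1
#             # detach loss_items so the running logs do not hold on to the computational graph
#             return loss, torch.stack([ce, l1, loss]).detach()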
  360. def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
  361. """
  362. Run backprop on the loss and perform a step
  363. :param loss: The value computed by the loss function
365. :param epoch: the epoch the training is on
366. :param batch_idx: the iteration number inside the current epoch
  367. :param context: current phase context
  368. :return:
  369. """
  370. # SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
  371. self.scaler.scale(loss).backward()
  372. # APPLY GRADIENT CLIPPING IF REQUIRED
  373. if self.training_params.clip_grad_norm:
  374. torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)
  375. # ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
  376. integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
  377. if integrated_batches_num % self.batch_accumulate == 0:
  378. # SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
  379. self.scaler.step(self.optimizer)
  380. self.scaler.update()
  381. self.optimizer.zero_grad()
  382. if self.ema:
  383. self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs))
  384. # RUN PHASE CALLBACKS
  385. self.phase_callback_handler(Phase.TRAIN_BATCH_STEP, context)
  386. def _save_checkpoint(self, optimizer=None, epoch: int = None, validation_results_tuple: tuple = None,
  387. context: PhaseContext = None):
  388. """
  389. Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training
  390. params)
  391. """
  392. # WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return
  393. if validation_results_tuple is None:
  394. self.sg_logger.add_checkpoint(tag='ckpt_latest_weights_only.pth', state_dict={'net': self.net.state_dict()},
  395. global_step=epoch)
  396. return
  397. # COMPUTE THE CURRENT metric
  398. # IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
  399. metric = validation_results_tuple[self.metric_idx_in_results_tuple] if isinstance(
  400. self.metric_idx_in_results_tuple, int) else \
  401. sum([validation_results_tuple[idx] for idx in self.metric_idx_in_results_tuple])
  402. # BUILD THE state_dict
  403. state = {'net': self.net.state_dict(), 'acc': metric, 'epoch': epoch}
  404. if optimizer is not None:
  405. state['optimizer_state_dict'] = optimizer.state_dict()
  406. if self.scaler is not None:
  407. state['scaler_state_dict'] = self.scaler.state_dict()
  408. if self.ema:
  409. state['ema_net'] = self.ema_model.ema.state_dict()
  410. # SAVES CURRENT MODEL AS ckpt_latest
  411. self.sg_logger.add_checkpoint(tag='ckpt_latest.pth', state_dict=state, global_step=epoch)
  412. # SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
  413. if epoch in self.training_params.save_ckpt_epoch_list:
  414. self.sg_logger.add_checkpoint(tag=f'ckpt_epoch_{epoch}.pth', state_dict=state, global_step=epoch)
  415. # OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
  416. if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
  417. metric < self.best_metric and not self.greater_metric_to_watch_is_better):
  418. # STORE THE CURRENT metric AS BEST
  419. self.best_metric = metric
  420. self._save_best_checkpoint(epoch, state)
  421. # RUN PHASE CALLBACKS
  422. self.phase_callback_handler(Phase.VALIDATION_END_BEST_EPOCH, context)
  423. if isinstance(metric, torch.Tensor):
  424. metric = metric.item()
  425. logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
  426. if self.training_params.average_best_models:
  427. net_for_averaging = self.ema_model.ema if self.ema else self.net
  428. averaged_model_sd = self.model_weight_averaging.get_average_model(net_for_averaging,
  429. validation_results_tuple=validation_results_tuple)
  430. self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename,
  431. state_dict={'net': averaged_model_sd}, global_step=epoch)
  432. def _save_best_checkpoint(self, epoch, state):
  433. self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
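# --- Usage sketch (illustrative; not part of this file) ---
# A custom phase callback reacting to the VALIDATION_END_BEST_EPOCH event fired from
# _save_checkpoint above. PhaseCallback/PhaseContext are assumed to be importable from
# super_gradients.training.utils.callbacks; the instance would be passed via
# training_params["phase_callbacks"].
#
#     from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
#
#     class BestEpochPrinter(PhaseCallback):
#         def __init__(self):
#             super().__init__(phase=Phase.VALIDATION_END_BEST_EPOCH)
#
#         def __call__(self, context: PhaseContext):
#             print(f"New best model saved at epoch {context.epoch}")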
  434. # FIXME - we need to resolve flake8's 'function is too complex' for this function
  435. def train(self, training_params: dict = dict()): # noqa: C901
  436. """
  437. train - Trains the Model
438. IMPORTANT NOTE: Additional batch parameters can be added as an optional third item (a dictionary) in the tuple
439. returned by the data loaders. The phase context will hold the additional items under attributes with
440. the same names as the keys in this dictionary. Such items can then be accessed through phase callbacks.
  441. :param training_params:
  442. - `max_epochs` : int
  443. Number of epochs to run training.
  444. - `lr_updates` : list(int)
  445. List of fixed epoch numbers to perform learning rate updates when `lr_mode='step'`.
  446. - `lr_decay_factor` : float
  447. Decay factor to apply to the learning rate at each update when `lr_mode='step'`.
  448. - `lr_mode` : str
449. Learning rate scheduling policy, one of ['step','poly','cosine','function']. 'step' refers to
450. constant updates at epoch numbers passed through `lr_updates`. 'cosine' refers to the Cosine Annealing
451. policy as described in https://arxiv.org/abs/1608.03983. 'poly' refers to polynomial decrease, i.e.
452. in each epoch iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)),
453. 0.9)`. 'function' refers to a user-defined learning rate scheduling function that is passed through
454. `lr_schedule_function`.
  455. - `lr_schedule_function` : Union[callable,None]
  456. Learning rate scheduling function to be used when `lr_mode` is 'function'.
  457. - `lr_warmup_epochs` : int (default=0)
  458. Number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
  459. - `cosine_final_lr_ratio` : float (default=0.01)
  460. Final learning rate ratio (only relevant when `lr_mode`='cosine'). The cosine starts from initial_lr and reaches
461. initial_lr * cosine_final_lr_ratio in the last epoch
462. - `initial_lr` : float
  463. Initial learning rate.
  464. - `loss` : Union[nn.module, str]
  465. Loss function for training.
466. One of SuperGradients' built-in options:
  467. "cross_entropy": LabelSmoothingCrossEntropyLoss,
  468. "mse": MSELoss,
  469. "r_squared_loss": RSquaredLoss,
  470. "detection_loss": YoLoV3DetectionLoss,
  471. "shelfnet_ohem_loss": ShelfNetOHEMLoss,
  472. "shelfnet_se_loss": ShelfNetSemanticEncodingLoss,
  473. "ssd_loss": SSDLoss,
  474. or user defined nn.module loss function.
475. IMPORTANT: forward(...) should return a (loss, loss_items) tuple where loss is the tensor used
476. for backprop (i.e. what your original loss function returns), and loss_items should be a tensor of
477. shape (n_items) holding values computed during the forward pass that we wish to log over the
478. entire epoch. For example, the loss itself should always be logged. Another example is a scenario
479. where the computed loss is the sum of a few components that we would like to log; these components
480. should be entries in loss_items.
481. When training, set the loss_logging_items_names parameter in train_params to be a list of
482. strings of length n_items whose ith element is the name of the ith entry in loss_items. Then
483. each item will be logged, rendered on TensorBoard and "watched" (i.e. model checkpoints can be saved
484. according to it).
  485. Since running logs will save the loss_items in some internal state, it is recommended that
  486. loss_items are detached from their computational graph for memory efficiency.
  487. - `optimizer` : Union[str, torch.optim.Optimizer]
488. Optimization algorithm. One of ['Adam','SGD','RMSProp'], corresponding to the torch.optim
489. optimizer implementations, or any torch.optim.Optimizer instance or subclass.
  490. - `criterion_params` : dict
  491. Loss function parameters.
  492. - `optimizer_params` : dict
  493. When `optimizer` is one of ['Adam','SGD','RMSProp'], it will be initialized with optimizer_params.
  494. (see https://pytorch.org/docs/stable/optim.html for the full list of
  495. parameters for each optimizer).
  496. - `train_metrics_list` : list(torchmetrics.Metric)
  497. Metrics to log during training. For more information on torchmetrics see
  498. https://torchmetrics.rtfd.io/en/latest/.
  499. - `valid_metrics_list` : list(torchmetrics.Metric)
  500. Metrics to log during validation/testing. For more information on torchmetrics see
  501. https://torchmetrics.rtfd.io/en/latest/.
  502. - `loss_logging_items_names` : list(str)
503. The list of names/titles for the outputs returned from the loss function's forward pass (reminder:
  504. the loss function should return the tuple (loss, loss_items)). These names will be used for
  505. logging their values.
  506. - `metric_to_watch` : str (default="Accuracy")
507. The metric according to which the model checkpoint will be saved. Can be set to any
  508. of the following:
  509. a metric name (str) of one of the metric objects from the valid_metrics_list
  510. a "metric_name" if some metric in valid_metrics_list has an attribute component_names which
  511. is a list referring to the names of each entry in the output metric (torch tensor of size n)
  512. one of "loss_logging_items_names" i.e which will correspond to an item returned during the
  513. loss function's forward pass.
  514. At the end of each epoch, if a new best metric_to_watch value is achieved, the models checkpoint
  515. is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth
  516. - `greater_metric_to_watch_is_better` : bool
  517. When choosing a model's checkpoint to be saved, the best achieved model is the one that maximizes the
518. metric_to_watch when this parameter is set to True, and the one that minimizes it otherwise.
  519. - `ema` : bool (default=False)
  520. Whether to use Model Exponential Moving Average (see
  521. https://github.com/rwightman/pytorch-image-models ema implementation)
  522. - `batch_accumulate` : int (default=1)
  523. Number of batches to accumulate before every backward pass.
  524. - `ema_params` : dict
  525. Parameters for the ema model.
  526. - `zero_weight_decay_on_bias_and_bn` : bool (default=False)
527. Whether to exclude bias terms and batch-normalization parameters from weight decay (ignored when the passed
528. optimizer has already been initialized).
  529. - `load_opt_params` : bool (default=True)
  530. Whether to load the optimizers parameters as well when loading a model's checkpoint.
  531. - `run_validation_freq` : int (default=1)
532. The frequency at which validation is performed during training (i.e. validation is run every
533. `run_validation_freq` epochs).
  534. - `save_model` : bool (default=True)
  535. Whether to save the model checkpoints.
  536. - `silent_mode` : bool
537. Silences the printouts.
  538. - `mixed_precision` : bool
  539. Whether to use mixed precision or not.
  540. - `save_ckpt_epoch_list` : list(int) (default=[])
  541. List of fixed epoch indices the user wishes to save checkpoints in.
  542. - `average_best_models` : bool (default=False)
  543. If set, a snapshot dictionary file and the average model will be saved / updated at every epoch
  544. and evaluated only when training is completed. The snapshot file will only be deleted upon
  545. completing the training. The snapshot dict will be managed on cpu.
  546. - `precise_bn` : bool (default=False)
  547. Whether to use precise_bn calculation during the training.
  548. - `precise_bn_batch_size` : int (default=None)
  549. The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
550. on 8 GPUs with a batch of 128 on each GPU, a good rule of thumb would be to give it 8192
551. (i.e. effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
  552. If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be taken.
  553. - `seed` : int (default=42)
554. Random seed to be set for torch, numpy, and random. When using DDP each process will have its seed
  555. set to seed + rank.
  556. - `log_installed_packages` : bool (default=False)
  557. When set, the list of all installed packages (and their versions) will be written to the tensorboard
  558. and logfile (useful when trying to reproduce results).
  559. - `dataset_statistics` : bool (default=False)
  560. Enable a statistic analysis of the dataset. If set to True the dataset will be analyzed and a report
  561. will be added to the tensorboard along with some sample images from the dataset. Currently only
  562. detection datasets are supported for analysis.
  563. - `save_full_train_log` : bool (default=False)
  564. When set, a full log (of all super_gradients modules, including uncaught exceptions from any other
  565. module) of the training will be saved in the checkpoint directory under full_train_log.log
566. - `sg_logger` : Union[AbstractSGLogger, str] (default=base_sg_logger)
  567. Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging
  568. and remote storage. By overriding the default base_sg_logger, you can change the storage location, support external monitoring and logging
  569. or support remote storage.
  570. - `sg_logger_params` : dict
  571. SGLogger parameters
  572. - `clip_grad_norm` : float
  573. Defines a maximal L2 norm of the gradients. Values which exceed the given value will be clipped
  574. - `lr_cooldown_epochs` : int (default=0)
575. Number of epochs to cool down the LR (i.e. the last epoch from the scheduler's point of view is max_epochs - lr_cooldown_epochs).
  576. - `pre_prediction_callback` : Callable (default=None)
577. When not None, this callback will be applied to the images and targets and will return them to be used
578. for the forward pass and further computations. Args for this callable should be in the order
579. (inputs, targets, batch_idx), returning (modified_inputs, modified_targets).
  580. - `ckpt_best_name` : str (default='ckpt_best.pth')
  581. The best checkpoint (according to metric_to_watch) will be saved under this filename in the checkpoints directory.
  582. - `enable_qat`: bool (default=False)
583. Adds a QATCallback to the phase callbacks, which triggers quantization-aware training starting from
  584. qat_params["start_epoch"]
  585. - `qat_params`: dict-like object with the following key/values:
  586. start_epoch: int, first epoch to start QAT.
  587. quant_modules_calib_method: str, One of [percentile, mse, entropy, max]. Statistics method for amax
  588. computation of the quantized modules (default=percentile).
  589. per_channel_quant_modules: bool, whether quant modules should be per channel (default=False).
590. calibrate: bool, whether to perform calibration (default=False).
  591. calibrated_model_path: str, path to a calibrated checkpoint (default=None).
  592. calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset. When None,
  593. context.train_loader will be used (default=None).
  594. num_calib_batches: int, number of batches to collect the statistics from.
595. percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
  596. Discarded when other methods are used (Default=99.99).
  597. :return:
  598. """
  599. global logger
  600. if self.net is None:
  601. raise Exception('Model', 'No model found')
  602. if self.dataset_interface is None and self.train_loader is None:
  603. raise Exception('Data', 'No dataset found')
  604. self.training_params = TrainingParams()
  605. self.training_params.override(**training_params)
  606. # SET RANDOM SEED
  607. random_seed(is_ddp=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
  608. device=self.device, seed=self.training_params.seed)
  609. silent_mode = self.training_params.silent_mode or self.ddp_silent_mode
  610. # METRICS
  611. self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
  612. self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
  613. self.loss_logging_items_names = self.training_params.loss_logging_items_names
  614. self.results_titles = ["Train_" + t for t in
  615. self.loss_logging_items_names + get_metrics_titles(self.train_metrics)] + \
  616. ["Valid_" + t for t in
  617. self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)]
618. # Store the metric to follow (loss/accuracy) and initialize it as the worst value
  619. self.metric_to_watch = self.training_params.metric_to_watch
  620. self.greater_metric_to_watch_is_better = self.training_params.greater_metric_to_watch_is_better
  621. self.metric_idx_in_results_tuple = (self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)).index(self.metric_to_watch)
  622. # Instantiate the values to monitor (loss/metric)
  623. for loss in self.loss_logging_items_names:
  624. self.train_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
  625. self.valid_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
  626. self.valid_monitored_values[self.metric_to_watch] = MonitoredValue(name=self.metric_to_watch,
  627. greater_is_better=True)
  628. # Allowing loading instantiated loss or string
  629. if isinstance(self.training_params.loss, str):
  630. criterion_cls = LOSSES[self.training_params.loss]
  631. self.criterion = criterion_cls(**self.training_params.criterion_params)
  632. elif isinstance(self.training_params.loss, Mapping):
  633. self.criterion = LossesFactory().get(self.training_params.loss)
  634. elif isinstance(self.training_params.loss, nn.Module):
  635. self.criterion = self.training_params.loss
  636. self.criterion.to(self.device)
  637. self.max_epochs = self.training_params.max_epochs
  638. self.ema = self.training_params.ema
  639. self.precise_bn = self.training_params.precise_bn
  640. self.precise_bn_batch_size = self.training_params.precise_bn_batch_size
  641. self.batch_accumulate = self.training_params.batch_accumulate
  642. num_batches = len(self.train_loader)
  643. if self.ema:
  644. ema_params = self.training_params.ema_params
  645. logger.info(f'Using EMA with params {ema_params}')
  646. self.ema_model = self._instantiate_ema_model(**ema_params)
  647. self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
  648. if self.load_checkpoint:
  649. if 'ema_net' in self.checkpoint.keys():
  650. self.ema_model.ema.load_state_dict(self.checkpoint['ema_net'])
  651. else:
  652. self.ema = False
  653. logger.warning(
  654. "[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")
  655. self.run_validation_freq = self.training_params.run_validation_freq
  656. validation_results_tuple = (0, 0)
  657. inf_time = 0
  658. timer = core_utils.Timer(self.device)
  659. # IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
  660. self.lr_mode = self.training_params.lr_mode
  661. load_opt_params = self.training_params.load_opt_params
  662. self.phase_callbacks = self.training_params.phase_callbacks or []
  663. self.phase_callbacks = ListFactory(CallbacksFactory()).get(self.phase_callbacks)
  664. if self.lr_mode is not None:
  665. sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[self.lr_mode]
  666. self.phase_callbacks.append(sg_lr_callback_cls(train_loader_len=len(self.train_loader),
  667. net=self.net,
  668. training_params=self.training_params,
  669. update_param_groups=self.update_param_groups,
  670. **self.training_params.to_dict()))
  671. if self.training_params.lr_warmup_epochs > 0:
  672. warmup_mode = self.training_params.warmup_mode
  673. if isinstance(warmup_mode, str):
  674. warmup_callback_cls = LR_WARMUP_CLS_DICT[warmup_mode]
  675. elif isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
  676. warmup_callback_cls = warmup_mode
  677. else:
  678. raise RuntimeError('warmup_mode has to be either a name of a mode (str) or a subclass of PhaseCallback')
  679. self.phase_callbacks.append(warmup_callback_cls(train_loader_len=len(self.train_loader),
  680. net=self.net,
  681. training_params=self.training_params,
  682. update_param_groups=self.update_param_groups,
  683. **self.training_params.to_dict()))
  684. self._add_metrics_update_callback(Phase.TRAIN_BATCH_END)
  685. self._add_metrics_update_callback(Phase.VALIDATION_BATCH_END)
  686. # ADD CALLBACK FOR QAT
  687. self.enable_qat = core_utils.get_param(self.training_params, "enable_qat", False)
  688. if self.enable_qat:
  689. self.qat_params = core_utils.get_param(self.training_params, "qat_params")
  690. if self.qat_params is None:
  691. raise ValueError("Must pass QAT params when enable_qat=True")
  692. self.phase_callbacks.append(QATCallback(**self.qat_params))
  693. self.phase_callback_handler = CallbackHandler(callbacks=self.phase_callbacks)
  694. if not self.ddp_silent_mode:
  695. self._initialize_sg_logger_objects()
  696. if self.training_params.dataset_statistics:
  697. dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
  698. dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes,
  699. title="Train-set", anchors=self.net.module.arch_params.anchors)
  700. dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes,
  701. title="val-set")
  702. # AVERAGE BEST 10 MODELS PARAMS
  703. if self.training_params.average_best_models:
  704. self.model_weight_averaging = ModelWeightAveraging(self.checkpoints_dir_path,
  705. greater_is_better=self.greater_metric_to_watch_is_better,
  706. source_ckpt_folder_name=self.source_ckpt_folder_name,
  707. metric_to_watch=self.metric_to_watch,
  708. metric_idx=self.metric_idx_in_results_tuple,
  709. load_checkpoint=self.load_checkpoint,
  710. model_checkpoints_location=self.model_checkpoints_location)
  711. if self.training_params.save_full_train_log and not self.ddp_silent_mode:
  712. logger = get_logger(__name__,
  713. training_log_path=self.sg_logger.log_file_path.replace('.txt', 'full_train_log.log'))
  714. sg_trainer_utils.log_uncaught_exceptions(logger)
  715. if not self.load_checkpoint or self.load_weights_only:
  716. # WHEN STARTING TRAINING FROM SCRATCH, DO NOT LOAD OPTIMIZER PARAMS (EVEN IF LOADING BACKBONE)
  717. self.start_epoch = 0
  718. self._reset_best_metric()
  719. load_opt_params = False
  720. if isinstance(self.training_params.optimizer, str) or \
  721. (inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)):
  722. self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr,
  723. training_params=self.training_params)
  724. elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
  725. self.optimizer = self.training_params.optimizer
  726. else:
  727. raise UnsupportedOptimizerFormat()
  728. # VERIFY GRADIENT CLIPPING VALUE
  729. if self.training_params.clip_grad_norm is not None and self.training_params.clip_grad_norm <= 0:
  730. raise TypeError('Params', 'Invalid clip_grad_norm')
  731. if self.load_checkpoint and load_opt_params:
  732. self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
  733. self.pre_prediction_callback = CallbacksFactory().get(self.training_params.pre_prediction_callback)
  734. self._initialize_mixed_precision(self.training_params.mixed_precision)
  735. self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler, InfiniteSampler)) or \
  736. (hasattr(self.train_loader, "batch_sampler") and isinstance(self.train_loader.batch_sampler.sampler, InfiniteSampler))
  737. self.ckpt_best_name = self.training_params.ckpt_best_name
  738. context = PhaseContext(optimizer=self.optimizer,
  739. net=self.net,
  740. experiment_name=self.experiment_name,
  741. ckpt_dir=self.checkpoints_dir_path,
  742. criterion=self.criterion,
  743. lr_warmup_epochs=self.training_params.lr_warmup_epochs,
  744. sg_logger=self.sg_logger,
  745. train_loader=self.train_loader,
  746. valid_loader=self.valid_loader,
  747. training_params=self.training_params,
  748. ddp_silent_mode=self.ddp_silent_mode,
  749. checkpoint_params=self.checkpoint_params,
  750. architecture=self.architecture,
  751. arch_params=self.arch_params,
  752. metric_idx_in_results_tuple=self.metric_idx_in_results_tuple,
  753. metric_to_watch=self.metric_to_watch,
  754. device=self.device,
  755. context_methods=self._get_context_methods(Phase.PRE_TRAINING)
  756. )
  757. self.phase_callback_handler(Phase.PRE_TRAINING, context)
  758. try:
  759. # HEADERS OF THE TRAINING PROGRESS
  760. if not silent_mode:
  761. logger.info(
  762. f'Started training for {self.max_epochs - self.start_epoch} epochs ({self.start_epoch}/'f'{self.max_epochs - 1})\n')
  763. for epoch in range(self.start_epoch, self.max_epochs):
  764. if context.stop_training:
  765. logger.info("Request to stop training has been received, stopping training")
  766. break
  767. # Phase.TRAIN_EPOCH_START
  768. # RUN PHASE CALLBACKS
  769. context.update_context(epoch=epoch)
  770. self.phase_callback_handler(Phase.TRAIN_EPOCH_START, context)
  771. # IN DDP- SET_EPOCH WILL CAUSE EVERY PROCESS TO BE EXPOSED TO THE ENTIRE DATASET BY SHUFFLING WITH A
  772. # DIFFERENT SEED EACH EPOCH START
  773. if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and hasattr(self.train_loader, "sampler") \
  774. and hasattr(self.train_loader.sampler, "set_epoch"):
  775. self.train_loader.sampler.set_epoch(epoch)
  776. train_metrics_tuple = self._train_epoch(epoch=epoch, silent_mode=silent_mode)
  777. # Phase.TRAIN_EPOCH_END
  778. # RUN PHASE CALLBACKS
  779. train_metrics_dict = get_metrics_dict(train_metrics_tuple, self.train_metrics,
  780. self.loss_logging_items_names)
  781. context.update_context(metrics_dict=train_metrics_dict)
  782. self.phase_callback_handler(Phase.TRAIN_EPOCH_END, context)
  783. # CALCULATE PRECISE BATCHNORM STATS
  784. if self.precise_bn:
  785. compute_precise_bn_stats(model=self.net, loader=self.train_loader,
  786. precise_bn_batch_size=self.precise_bn_batch_size,
  787. num_gpus=self.num_devices)
  788. if self.ema:
  789. compute_precise_bn_stats(model=self.ema_model.ema, loader=self.train_loader,
  790. precise_bn_batch_size=self.precise_bn_batch_size,
  791. num_gpus=self.num_devices)
  792. # model switch - we replace self.net.module with the ema model for the testing and saving part
  793. # and then switch it back before the next training epoch
  794. if self.ema:
  795. self.ema_model.update_attr(self.net)
  796. keep_model = self.net
  797. self.net = self.ema_model.ema
  798. # RUN TEST ON VALIDATION SET EVERY self.run_validation_freq EPOCHS
  799. if (epoch + 1) % self.run_validation_freq == 0:
  800. timer.start()
  801. validation_results_tuple = self._validate_epoch(epoch=epoch, silent_mode=silent_mode)
  802. inf_time = timer.stop()
  803. # Phase.VALIDATION_EPOCH_END
  804. # RUN PHASE CALLBACKS
  805. valid_metrics_dict = get_metrics_dict(validation_results_tuple, self.valid_metrics,
  806. self.loss_logging_items_names)
  807. context.update_context(metrics_dict=valid_metrics_dict)
  808. self.phase_callback_handler(Phase.VALIDATION_EPOCH_END, context)
  809. if self.ema:
  810. self.net = keep_model
  811. if not self.ddp_silent_mode:
  812. # SAVING AND LOGGING OCCURS ONLY IN THE MAIN PROCESS (IN CASES THERE ARE SEVERAL PROCESSES - DDP)
  813. self._write_to_disk_operations(train_metrics_tuple, validation_results_tuple, inf_time, epoch,
  814. context)
  815. # Evaluating the average model and removing snapshot averaging file if training is completed
  816. if self.training_params.average_best_models:
  817. self._validate_final_average_model(cleanup_snapshots_pkl_file=True)
  818. except KeyboardInterrupt:
  819. logger.info(
  820. '\n[MODEL TRAINING EXECUTION HAS BEEN INTERRUPTED]... Please wait until SOFT-TERMINATION process '
  821. 'finishes and saves all of the Model Checkpoints and log files before terminating...')
  822. logger.info('For HARD Termination - Stop the process again')
  823. finally:
  824. if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  825. # CLEAN UP THE MULTI-GPU PROCESS GROUP WHEN DONE
  826. if torch.distributed.is_initialized():
  827. torch.distributed.destroy_process_group()
  828. # PHASE.TRAIN_END
  829. self.phase_callback_handler(Phase.POST_TRAINING, context)
  830. if not self.ddp_silent_mode:
  831. if self.model_checkpoints_location != 'local':
  832. logger.info('[CLEANUP] - Saving Checkpoint files')
  833. self.sg_logger.upload()
  834. self.sg_logger.close()
  835. def _reset_best_metric(self):
  836. self.best_metric = -1 * np.inf if self.greater_metric_to_watch_is_better else np.inf
  837. def _reset_metrics(self):
  838. for metric in ("train_metrics", "valid_metrics", "test_metrics"):
  839. if hasattr(self, metric) and getattr(self, metric) is not None:
  840. getattr(self, metric).reset()
  841. @resolve_param('train_metrics_list', ListFactory(MetricsFactory()))
  842. def _set_train_metrics(self, train_metrics_list):
  843. self.train_metrics = MetricCollection(train_metrics_list)
  844. @resolve_param('valid_metrics_list', ListFactory(MetricsFactory()))
  845. def _set_valid_metrics(self, valid_metrics_list):
  846. self.valid_metrics = MetricCollection(valid_metrics_list)
  847. def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
  848. # SCALER IS ALWAYS INITIALIZED BUT IS DISABLED IF MIXED PRECISION WAS NOT SET
  849. self.scaler = GradScaler(enabled=mixed_precision_enabled)
  850. if mixed_precision_enabled:
  851. assert self.device.startswith('cuda'), "mixed precision is not available for CPU"
  852. if self.multi_gpu == MultiGPUMode.DATA_PARALLEL:
  853. # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
  854. # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
  855. # WE HAVE TO REGISTER THE WRAPPER BEFORE EVERY FORWARD CALL
  856. def hook(module, _):
  857. module.forward = MultiGPUModeAutocastWrapper(module.forward)
  858. self.net.module.register_forward_pre_hook(hook=hook)
  859. if self.load_checkpoint:
  860. scaler_state_dict = core_utils.get_param(self.checkpoint, 'scaler_state_dict')
  861. if scaler_state_dict is None:
  862. logger.warning(
863. 'Mixed Precision - scaler state_dict not found in loaded model. This may cause issues '
  864. 'with loss scaling')
  865. else:
  866. self.scaler.load_state_dict(scaler_state_dict)
  867. def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
  868. """
  869. Testing the averaged model by loading the last saved average checkpoint and running test.
870. Will be loaded to each of the DDP processes
871. :param cleanup_snapshots_pkl_file: a flag for deleting the 10-best snapshots dictionary
  872. """
  873. logger.info('RUNNING ADDITIONAL TEST ON THE AVERAGED MODEL...')
  874. keep_state_dict = deepcopy(self.net.state_dict())
  875. # SETTING STATE DICT TO THE AVERAGE MODEL FOR EVALUATION
  876. average_model_ckpt_path = os.path.join(self.checkpoints_dir_path, self.average_model_checkpoint_filename)
  877. average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)['net']
  878. self.net.load_state_dict(average_model_sd)
  879. # testing the averaged model and save instead of best model if needed
  880. averaged_model_results_tuple = self._validate_epoch(epoch=self.max_epochs)
  881. # Reverting the current model
  882. self.net.load_state_dict(keep_state_dict)
  883. if not self.ddp_silent_mode:
  884. # Adding values to sg_logger
  885. # looping over last titles which corresponds to validation (and average model) metrics.
  886. all_titles = self.results_titles[-1 * len(averaged_model_results_tuple):]
  887. result_dict = {all_titles[i]: averaged_model_results_tuple[i] for i in
  888. range(len(averaged_model_results_tuple))}
  889. self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=self.max_epochs)
  890. average_model_tb_titles = ['Averaged Model ' + x for x in
  891. self.results_titles[-1 * len(averaged_model_results_tuple):]]
  892. write_struct = ''
  893. for ind, title in enumerate(average_model_tb_titles):
  894. write_struct += '%s: %.3f \n ' % (title, averaged_model_results_tuple[ind])
  895. self.sg_logger.add_scalar(title, averaged_model_results_tuple[ind], global_step=self.max_epochs)
  896. self.sg_logger.add_text("Averaged_Model_Performance", write_struct, self.max_epochs)
  897. if cleanup_snapshots_pkl_file:
  898. self.model_weight_averaging.cleanup()
  899. @property
  900. def get_arch_params(self):
  901. return self.arch_params.to_dict()
  902. @property
  903. def get_structure(self):
  904. return self.net.module.structure
  905. @property
  906. def get_architecture(self):
  907. return self.architecture
  908. def set_experiment_name(self, experiment_name):
  909. self.experiment_name = experiment_name
  910. def _re_build_model(self, arch_params={}):
  911. """
  912. arch_params : dict
  913. Architecture H.P. e.g.: block, num_blocks, num_classes, etc.
  914. :return:
  915. """
  916. if 'num_classes' not in arch_params.keys():
  917. if self.dataset_interface is None:
  918. raise Exception('Error', 'Number of classes not defined in arch params and dataset is not defined')
  919. else:
  920. arch_params['num_classes'] = len(self.classes)
  921. self.arch_params = core_utils.HpmStruct(**arch_params)
  922. self.classes = self.arch_params.num_classes
  923. self.net = self._instantiate_net(self.architecture, self.arch_params, self.checkpoint_params)
  924. # save the architecture for neural architecture search
  925. if hasattr(self.net, 'structure'):
  926. self.architecture = self.net.structure
  927. self.net.to(self.device)
  928. if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  929. logger.warning("Warning: distributed training is not supported in re_build_model()")
  930. self.net = torch.nn.DataParallel(self.net,
  931. device_ids=self.device_ids) if self.multi_gpu else core_utils.WrappedModel(
  932. self.net)
  933. @property
  934. def get_module(self):
  935. return self.net
  936. def set_module(self, module):
  937. self.net = module
  938. def _initialize_device(self, requested_device: str, requested_multi_gpu: Union[MultiGPUMode, str]):
  939. """
  940. _initialize_device - Initializes the device for the model - Default is CUDA
  941. :param requested_device: Device to initialize ('cuda' / 'cpu')
942. :param requested_multi_gpu: Requested MultiGPUMode (or its string name)
  943. """
  944. if isinstance(requested_multi_gpu, str):
  945. requested_multi_gpu = MultiGPUMode(requested_multi_gpu)
  946. # SELECT CUDA DEVICE
  947. if requested_device == 'cuda':
  948. if torch.cuda.is_available():
  949. self.device = 'cuda' # TODO - we may want to set the device number as well i.e. 'cuda:1'
  950. else:
  951. raise RuntimeError('CUDA DEVICE NOT FOUND... EXITING')
  952. # SELECT CPU DEVICE
  953. elif requested_device == 'cpu':
  954. self.device = 'cpu'
  955. self.multi_gpu = False
  956. else:
  957. # SELECT CUDA DEVICE BY DEFAULT IF AVAILABLE
  958. self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
959. # DEFAULT IS SET TO 1 - IT IS CHANGED IF MULTI-GPU IS USED
  960. self.num_devices = 1
  961. # IN CASE OF MULTIPLE GPUS UPDATE THE LEARNING AND DATA PARAMETERS
  962. # FIXME - CREATE A DISCUSSION ON THESE PARAMETERS - WE MIGHT WANT TO CHANGE THE WAY WE USE THE LR AND
  963. if requested_multi_gpu != MultiGPUMode.OFF:
  964. if 'cuda' in self.device:
  965. # COLLECT THE AVAILABLE GPU AND COUNT THE AVAILABLE GPUS AMOUNT
  966. self.device_ids = list(range(torch.cuda.device_count()))
  967. self.num_devices = len(self.device_ids)
  968. if self.num_devices == 1:
  969. self.multi_gpu = MultiGPUMode.OFF
  970. if requested_multi_gpu != MultiGPUMode.AUTO:
  971. # if AUTO mode was set - do not log a warning
  972. logger.warning(
  973. '\n[WARNING] - Tried running on multiple GPU but only a single GPU is available\n')
  974. else:
  975. if requested_multi_gpu == MultiGPUMode.AUTO:
  976. if env_helpers.is_distributed():
  977. requested_multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
  978. else:
  979. requested_multi_gpu = MultiGPUMode.DATA_PARALLEL
  980. self.multi_gpu = requested_multi_gpu
  981. if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  982. self._initialize_ddp()
  983. else:
  984. # MULTIPLE GPUS CAN BE ACTIVE ONLY IF A GPU IS AVAILABLE
  985. self.multi_gpu = MultiGPUMode.OFF
  986. logger.warning('\n[WARNING] - Tried running on multiple GPU but none are available => running on CPU\n')
  987. def _initialize_ddp(self):
  988. """
  989. Initializes Distributed Data Parallel
  990. Usage:
  991. python -m torch.distributed.launch --nproc_per_node=n YOUR_TRAINING_SCRIPT.py
  992. where n is the number of GPUs required, e.g., n=8
  993. Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
994. Whatever learning rate and schedule you specify will be applied to each GPU individually.
  995. Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
  996. batch you specify times the number of GPUs. In the literature there are several "best practices" to set
  997. learning rates and schedules for large batch sizes.
  998. """
  999. logger.info("Distributed training starting...")
  1000. local_rank = environment_config.DDP_LOCAL_RANK
  1001. if not torch.distributed.is_initialized():
  1002. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  1003. if local_rank > 0:
  1004. f = open(os.devnull, 'w')
  1005. sys.stdout = f # silent all printing for non master process
  1006. torch.cuda.set_device(local_rank)
  1007. self.device = 'cuda:%d' % local_rank
  1008. # MAKE ALL HIGHER-RANK GPUS SILENT (DISTRIBUTED MODE)
  1009. self.ddp_silent_mode = local_rank > 0
  1010. if torch.distributed.get_rank() == 0:
  1011. logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
  1012. def _switch_device(self, new_device):
  1013. self.device = new_device
  1014. self.net.to(self.device)
  1015. # FIXME - we need to resolve flake8's 'function is too complex' for this function
  1016. def _load_checkpoint_to_model(self): # noqa: C901 - too complex
  1017. """
  1018. Copies the source checkpoint to a local folder and loads the checkpoint's data to the model using the
  1019. attributes:
  1020. strict: See StrictLoad class documentation for details.
  1021. load_backbone: loads the provided checkpoint to self.net.backbone instead of self.net
  1022. source_ckpt_folder_name: The folder where the checkpoint is saved. By default uses the self.experiment_name
  1023. NOTE: 'acc', 'epoch', 'optimizer_state_dict' and the logs are NOT loaded if self.zeroize_prev_train_params
  1024. is True
  1025. """
  1026. self._set_ckpt_loading_attributes()
  1027. if self.load_checkpoint or self.external_checkpoint_path:
  1028. # GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
  1029. ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
  1030. experiment_name=self.experiment_name,
  1031. ckpt_name=self.ckpt_name,
  1032. model_checkpoints_location=self.model_checkpoints_location,
  1033. external_checkpoint_path=self.external_checkpoint_path,
  1034. overwrite_local_checkpoint=self.overwrite_local_checkpoint,
  1035. load_weights_only=self.load_weights_only)
  1036. # LOAD CHECKPOINT TO MODEL
  1037. self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
  1038. load_backbone=self.load_backbone,
  1039. net=self.net,
  1040. strict=self.strict_load.value if isinstance(self.strict_load,
  1041. StrictLoad) else self.strict_load,
  1042. load_weights_only=self.load_weights_only,
  1043. load_ema_as_net=self.load_ema_as_net)
  1044. if 'ema_net' in self.checkpoint.keys():
  1045. logger.warning("[WARNING] Main network has been loaded from checkpoint but EMA network exists as "
  1046. "well. It "
  1047. " will only be loaded during validation when training with ema=True. ")
  1048. # UPDATE TRAINING PARAMS IF THEY EXIST & WE ARE NOT LOADING AN EXTERNAL MODEL's WEIGHTS
  1049. self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
  1050. self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0
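For reference, the attributes this method reads are usually populated through build_model's checkpoint_params, as the updated unit tests later in this PR do. A minimal sketch; the experiment name and architecture are placeholders, a checkpoint from a previous run is assumed to exist, and a dataset interface may need to be connected first as in the tests:

```python
from super_gradients import Trainer

trainer = Trainer("my_experiment", model_checkpoints_location='local')
# Resuming: restores the weights plus 'acc', 'epoch' and the optimizer state.
trainer.build_model("resnet18_cifar",
                    checkpoint_params={"load_checkpoint": True,
                                       "load_weights_only": False,
                                       "source_ckpt_folder_name": "my_experiment"})
```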
  1051. def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None, post_prediction_callback=None,
  1052. test_metrics_list=None,
  1053. loss_logging_items_names=None, test_phase_callbacks=None):
  1054. """Run commands that are common to all SgModels"""
  1055. # SET THE MODEL IN evaluation STATE
  1056. self.net.eval()
  1057. # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
  1058. self.test_loader = test_loader or self.test_loader
  1059. self.criterion = loss or self.criterion
  1060. self.post_prediction_callback = post_prediction_callback or self.post_prediction_callback
  1061. self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
  1062. self.phase_callbacks = test_phase_callbacks or self.phase_callbacks
  1063. if self.phase_callbacks is None:
  1064. self.phase_callbacks = []
  1065. if test_metrics_list:
  1066. self.test_metrics = MetricCollection(test_metrics_list)
  1067. self._add_metrics_update_callback(Phase.TEST_BATCH_END)
  1068. self.phase_callback_handler = CallbackHandler(self.phase_callbacks)
  1069. # WHEN TESTING WITHOUT A LOSS FUNCTION- CREATE EPOCH HEADERS FOR PRINTS
  1070. if self.criterion is None:
  1071. self.loss_logging_items_names = []
  1072. if self.test_metrics is None:
  1073. raise ValueError("Metrics are required to perform test. Pass them through test_metrics_list arg when "
  1074. "calling test or through training_params when calling train(...)")
  1075. if self.test_loader is None:
  1076. raise ValueError("Test dataloader is required to perform test. Make sure to either pass it through "
1077. "test_loader arg or by calling connect_dataset_interface on a DatasetInterface instance "
1078. "with a non-empty testset attribute.")
  1079. # RESET METRIC RUNNERS
  1080. self._reset_metrics()
  1081. self.test_metrics.to(self.device)
  1082. def _add_metrics_update_callback(self, phase: Phase):
  1083. """
  1084. Adds MetricsUpdateCallback to be fired at phase
  1085. :param phase: Phase for the metrics callback to be fired at
  1086. """
  1087. self.phase_callbacks.append(MetricsUpdateCallback(phase))
  1088. def _initialize_sg_logger_objects(self):
  1089. """Initialize object that collect, write to disk, monitor and store remotely all training outputs"""
  1090. sg_logger = core_utils.get_param(self.training_params, 'sg_logger')
  1091. # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
  1092. general_sg_logger_params = {'experiment_name': self.experiment_name,
  1093. 'storage_location': self.model_checkpoints_location,
  1094. 'resumed': self.load_checkpoint,
  1095. 'training_params': self.training_params,
  1096. 'checkpoints_dir_path': self.checkpoints_dir_path}
  1097. if sg_logger is None:
  1098. raise RuntimeError('sg_logger must be defined in training params (see default_training_params)')
  1099. if isinstance(sg_logger, AbstractSGLogger):
  1100. self.sg_logger = sg_logger
  1101. elif isinstance(sg_logger, str):
  1102. sg_logger_params = core_utils.get_param(self.training_params, 'sg_logger_params', {})
  1103. if issubclass(SG_LOGGERS[sg_logger], BaseSGLogger):
  1104. sg_logger_params = {**sg_logger_params, **general_sg_logger_params}
  1105. if sg_logger not in SG_LOGGERS:
  1106. raise RuntimeError('sg_logger not defined in SG_LOGGERS')
  1107. self.sg_logger = SG_LOGGERS[sg_logger](**sg_logger_params)
  1108. else:
  1109. raise RuntimeError('sg_logger can be either an sg_logger name (str) or an instance of AbstractSGLogger')
  1110. if not isinstance(self.sg_logger, BaseSGLogger):
  1111. logger.warning("WARNING! Using a user-defined sg_logger: files will not be automatically written to disk!\n"
  1112. "Please make sure the provided sg_logger writes to disk or compose your sg_logger to BaseSGLogger")
  1113. # IN CASE SG_LOGGER UPDATED THE DIR PATH
  1114. self.checkpoints_dir_path = self.sg_logger.local_dir()
  1115. hyper_param_config = self._get_hyper_param_config()
  1116. self.sg_logger.add_config("hyper_params", hyper_param_config)
  1117. self.sg_logger.flush()
  1118. def _get_hyper_param_config(self):
  1119. """
  1120. Creates a training hyper param config for logging.
  1121. """
  1122. additional_log_items = {'initial_LR': self.training_params.initial_lr,
  1123. 'num_devices': self.num_devices,
  1124. 'multi_gpu': str(self.multi_gpu),
  1125. 'device_type': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}
  1126. # ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
  1127. if self.training_params.log_installed_packages:
  1128. pkg_list = list(map(lambda pkg: str(pkg), _get_installed_distributions()))
  1129. additional_log_items['installed_packages'] = pkg_list
  1130. hyper_param_config = {"arch_params": self.arch_params.__dict__,
  1131. "checkpoint_params": self.checkpoint_params.__dict__,
  1132. "training_hyperparams": self.training_params.__dict__,
  1133. "dataset_params": self.dataset_params.__dict__,
  1134. "additional_log_items": additional_log_items}
  1135. return hyper_param_config
  1136. def _write_to_disk_operations(self, train_metrics: tuple, validation_results: tuple, inf_time: float, epoch: int,
  1137. context: PhaseContext):
  1138. """Run the various logging operations, e.g.: log file, Tensorboard, save checkpoint etc."""
  1139. # STORE VALUES IN A TENSORBOARD FILE
  1140. train_results = list(train_metrics) + list(validation_results) + [inf_time]
  1141. all_titles = self.results_titles + ['Inference Time']
  1142. result_dict = {all_titles[i]: train_results[i] for i in range(len(train_results))}
  1143. self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=epoch)
  1144. # SAVE THE CHECKPOINT
  1145. if self.training_params.save_model:
  1146. self._save_checkpoint(self.optimizer, epoch + 1, validation_results, context)
  1147. def _write_lrs(self, epoch):
  1148. lrs = [self.optimizer.param_groups[i]['lr'] for i in range(len(self.optimizer.param_groups))]
  1149. lr_titles = ['LR/Param_group_' + str(i) for i in range(len(self.optimizer.param_groups))] if len(
  1150. self.optimizer.param_groups) > 1 else ['LR']
  1151. lr_dict = {lr_titles[i]: lrs[i] for i in range(len(lrs))}
  1152. self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
  1153. def test(self, # noqa: C901
  1154. test_loader: torch.utils.data.DataLoader = None,
  1155. loss: torch.nn.modules.loss._Loss = None,
  1156. silent_mode: bool = False,
  1157. test_metrics_list=None,
  1158. loss_logging_items_names=None, metrics_progress_verbose=False, test_phase_callbacks=None,
  1159. use_ema_net=True) -> tuple:
  1160. """
  1161. Evaluates the model on given dataloader and metrics.
  1162. :param test_loader: dataloader to perform test on.
  1163. :param test_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation.
  1164. :param silent_mode: (bool) controls verbosity
  1165. :param metrics_progress_verbose: (bool) controls the verbosity of metrics progress (default=False). Slows down the program.
1166. :param use_ema_net: (bool) whether to perform test on self.ema_model.ema (when self.ema_model.ema exists,
  1167. otherwise self.net will be tested) (default=True)
  1168. :return: results tuple (tuple) containing the loss items and metric values.
  1169. All of the above args will override Trainer's corresponding attribute when not equal to None. Then evaluation
1170. is run on self.test_loader with self.test_metrics.
  1171. """
1172. # IN CASE TRAINING WAS PERFORMED BEFORE TEST - MAKE SURE TO TEST THE EMA MODEL (UNLESS SPECIFIED OTHERWISE BY
  1173. # use_ema_net)
  1174. if use_ema_net and self.ema_model is not None:
  1175. keep_model = self.net
  1176. self.net = self.ema_model.ema
  1177. self._prep_for_test(test_loader=test_loader,
  1178. loss=loss,
  1179. test_metrics_list=test_metrics_list,
  1180. loss_logging_items_names=loss_logging_items_names,
  1181. test_phase_callbacks=test_phase_callbacks,
  1182. )
  1183. test_results = self.evaluate(data_loader=self.test_loader,
  1184. metrics=self.test_metrics,
  1185. evaluation_type=EvaluationType.TEST,
  1186. silent_mode=silent_mode,
  1187. metrics_progress_verbose=metrics_progress_verbose)
  1188. # SWITCH BACK BETWEEN NETS SO AN ADDITIONAL TRAINING CAN BE DONE AFTER TEST
  1189. if use_ema_net and self.ema_model is not None:
  1190. self.net = keep_model
  1191. return test_results
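A short usage sketch of this public test() entry point; the dataloader and metric choices below are illustrative and assume `trainer` already holds a built (or trained) network:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from super_gradients.training.metrics import Accuracy, Top5

# Dummy classification data for illustration only.
dummy_set = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,)))
results = trainer.test(test_loader=DataLoader(dummy_set, batch_size=4),
                       test_metrics_list=[Accuracy(), Top5()])
# `results` is the tuple of loss items and metric values described in the docstring.
```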
  1192. def _validate_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
  1193. """
  1194. Runs evaluation on self.valid_loader, with self.valid_metrics.
  1195. :param epoch: (int) epoch idx
  1196. :param silent_mode: (bool) controls verbosity
  1197. :return: results tuple (tuple) containing the loss items and metric values.
  1198. """
  1199. self.net.eval()
  1200. self._reset_metrics()
  1201. self.valid_metrics.to(self.device)
  1202. return self.evaluate(data_loader=self.valid_loader, metrics=self.valid_metrics,
  1203. evaluation_type=EvaluationType.VALIDATION, epoch=epoch, silent_mode=silent_mode)
  1204. def evaluate(self, data_loader: torch.utils.data.DataLoader, metrics: MetricCollection,
  1205. evaluation_type: EvaluationType, epoch: int = None, silent_mode: bool = False,
  1206. metrics_progress_verbose: bool = False):
  1207. """
  1208. Evaluates the model on given dataloader and metrics.
1209. :param data_loader: dataloader to perform evaluation on
  1210. :param metrics: (MetricCollection) metrics for evaluation
  1211. :param evaluation_type: (EvaluationType) controls which phase callbacks will be used (for example, on batch end,
  1212. when evaluation_type=EvaluationType.VALIDATION the Phase.VALIDATION_BATCH_END callbacks will be triggered)
  1213. :param epoch: (int) epoch idx
  1214. :param silent_mode: (bool) controls verbosity
  1215. :param metrics_progress_verbose: (bool) controls the verbosity of metrics progress (default=False).
  1216. Slows down the program significantly.
  1217. :return: results tuple (tuple) containing the loss items and metric values.
  1218. """
  1219. # THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
  1220. progress_bar_data_loader = tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True,
  1221. disable=silent_mode)
  1222. loss_avg_meter = core_utils.utils.AverageMeter()
  1223. logging_values = None
  1224. loss_tuple = None
  1225. lr_warmup_epochs = self.training_params.lr_warmup_epochs if self.training_params else None
  1226. context = PhaseContext(epoch=epoch,
  1227. metrics_compute_fn=metrics,
  1228. loss_avg_meter=loss_avg_meter,
  1229. criterion=self.criterion,
  1230. device=self.device,
  1231. lr_warmup_epochs=lr_warmup_epochs,
  1232. sg_logger=self.sg_logger,
  1233. context_methods=self._get_context_methods(Phase.VALIDATION_BATCH_END))
  1234. if not silent_mode:
  1235. # PRINT TITLES
  1236. pbar_start_msg = f"Validation epoch {epoch}" if evaluation_type == EvaluationType.VALIDATION else "Test"
  1237. progress_bar_data_loader.set_description(pbar_start_msg)
  1238. with torch.no_grad():
  1239. for batch_idx, batch_items in enumerate(progress_bar_data_loader):
  1240. batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
  1241. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
  1242. output = self.net(inputs)
  1243. if self.criterion is not None:
  1244. # STORE THE loss_items ONLY, THE 1ST RETURNED VALUE IS THE loss FOR BACKPROP DURING TRAINING
  1245. loss_tuple = self._get_losses(output, targets)[1].cpu()
  1246. context.update_context(batch_idx=batch_idx,
  1247. inputs=inputs,
  1248. preds=output,
  1249. target=targets,
  1250. loss_log_items=loss_tuple,
  1251. **additional_batch_items)
  1252. # TRIGGER PHASE CALLBACKS CORRESPONDING TO THE EVALUATION TYPE
  1253. if evaluation_type == EvaluationType.VALIDATION:
  1254. self.phase_callback_handler(Phase.VALIDATION_BATCH_END, context)
  1255. else:
  1256. self.phase_callback_handler(Phase.TEST_BATCH_END, context)
  1257. # COMPUTE METRICS IF PROGRESS VERBOSITY IS SET
  1258. if metrics_progress_verbose and not silent_mode:
  1259. # COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.
  1260. logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion)
  1261. pbar_message_dict = get_train_loop_description_dict(logging_values,
  1262. metrics,
  1263. self.loss_logging_items_names)
  1264. progress_bar_data_loader.set_postfix(**pbar_message_dict)
  1265. # NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET
  1266. if not metrics_progress_verbose:
  1267. # COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.
  1268. logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion)
  1269. pbar_message_dict = get_train_loop_description_dict(logging_values,
  1270. metrics,
  1271. self.loss_logging_items_names)
  1272. progress_bar_data_loader.set_postfix(**pbar_message_dict)
  1273. # TODO: SUPPORT PRINTING AP PER CLASS- SINCE THE METRICS ARE NOT HARD CODED ANYMORE (as done in
  1274. # calc_batch_prediction_accuracy_per_class in metric_utils.py), THIS IS ONLY RELEVANT WHEN CHOOSING
1275. DETECTIONMETRICS, WHICH ALREADY RETURN THE METRICS VALUES THEMSELVES AND NOT THE ITEMS REQUIRED FOR SUCH
  1276. # COMPUTATION. ALSO REMOVE THE BELOW LINES BY IMPLEMENTING CRITERION AS A TORCHMETRIC.
  1277. if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
  1278. logging_values = reduce_results_tuple_for_ddp(logging_values, next(self.net.parameters()).device)
  1279. pbar_message_dict = get_train_loop_description_dict(logging_values,
  1280. metrics,
  1281. self.loss_logging_items_names)
  1282. self.valid_monitored_values = sg_trainer_utils.update_monitored_values_dict(
  1283. monitored_values_dict=self.valid_monitored_values, new_values_dict=pbar_message_dict)
  1284. if not silent_mode and evaluation_type == EvaluationType.VALIDATION:
  1285. progress_bar_data_loader.write("===========================================================")
  1286. sg_trainer_utils.display_epoch_summary(epoch=context.epoch, n_digits=4,
  1287. train_monitored_values=self.train_monitored_values,
  1288. valid_monitored_values=self.valid_monitored_values)
  1289. progress_bar_data_loader.write("===========================================================")
  1290. return logging_values
  1291. def _instantiate_net(self, architecture: Union[torch.nn.Module, SgModule.__class__, str], arch_params: dict,
  1292. checkpoint_params: dict, *args, **kwargs) -> tuple:
  1293. """
  1294. Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
  1295. module manipulation (i.e head replacement).
1296. :param architecture: String, torch.nn.Module or uninstantiated SgModule class describing the network's architecture.
1297. :param arch_params: Architecture's parameters passed to the network's constructor.
1298. :param checkpoint_params: checkpoint loading related parameters dictionary with 'pretrained_weights' key,
1299. s.t. its value is a string describing the dataset of the pretrained weights (for example "imagenet").
1300. :return: instantiated network, i.e. torch.nn.Module, architecture_class (will be None when architecture is not str)
  1301. """
  1302. pretrained_weights = core_utils.get_param(checkpoint_params, 'pretrained_weights', default_val=None)
  1303. if pretrained_weights is not None:
  1304. num_classes_new_head = arch_params.num_classes
  1305. arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
  1306. if isinstance(architecture, str):
  1307. architecture_cls = ARCHITECTURES[architecture]
  1308. net = architecture_cls(arch_params=arch_params)
  1309. elif isinstance(architecture, SgModule.__class__):
  1310. net = architecture(arch_params)
  1311. else:
  1312. net = architecture
  1313. if pretrained_weights:
  1314. load_pretrained_weights(net, architecture, pretrained_weights)
  1315. if num_classes_new_head != arch_params.num_classes:
  1316. net.replace_head(new_num_classes=num_classes_new_head)
  1317. arch_params.num_classes = num_classes_new_head
  1318. return net
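This head-replacement path is what transfer learning with pretrained weights relies on. A hedged sketch; the architecture, the class count, and the "imagenet" pretrained-weights key are illustrative choices, not values fixed by this file:

```python
from super_gradients import Trainer

trainer = Trainer("transfer_learning_example", model_checkpoints_location='local')
# Weights are first loaded against the pretrained head, then replace_head()
# swaps in a fresh head for the new number of classes, as described above.
trainer.build_model("resnet50",
                    arch_params={"num_classes": 10},
                    checkpoint_params={"pretrained_weights": "imagenet"})
```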
  1319. def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
  1320. """Instantiate ema model for standard SgModule.
  1321. :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
  1322. until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
  1323. :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
  1324. its final value. beta=15 is ~40% of the training process.
  1325. """
  1326. return ModelEMA(self.net, decay, beta, exp_activation)
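To make the decay/beta interplay described in the docstring concrete, here is a small standalone sketch. The exact ramp ModelEMA uses is not shown in this diff, so the exponential form below is an assumption for illustration only:

```python
import math

def ema_decay(step: int, total_steps: int, max_decay: float = 0.9999, beta: float = 15) -> float:
    # Decay ramps from 0 towards max_decay; a larger beta saturates earlier in training.
    return max_decay * (1 - math.exp(-step / total_steps * beta))

total = 10_000
for step in (0, 1_000, 4_000, 10_000):
    d = ema_decay(step, total)
    # Per-parameter update: ema = ema * d + training_model_param * (1 - d)
    print(step, round(d, 4))
```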
  1327. @property
  1328. def get_net(self):
  1329. """
  1330. Getter for network.
  1331. :return: torch.nn.Module, self.net
  1332. """
  1333. return self.net
  1334. def set_net(self, net: torch.nn.Module):
  1335. """
  1336. Setter for network.
  1337. :param net: torch.nn.Module, value to set net
  1338. :return:
  1339. """
  1340. self.net = net
  1341. def set_ckpt_best_name(self, ckpt_best_name):
  1342. """
  1343. Setter for best checkpoint filename.
  1344. :param ckpt_best_name: str, value to set ckpt_best_name
  1345. """
  1346. self.ckpt_best_name = ckpt_best_name
  1347. def set_ema(self, val: bool):
  1348. """
  1349. Setter for self.ema
  1350. :param val: bool, value to set ema
  1351. """
  1352. self.ema = val
  1353. def _get_context_methods(self, phase: Phase) -> ContextSgMethods:
  1354. """
  1355. Returns ContextSgMethods holding the methods that should be accessible through phase callbacks to the user at
  1356. the specific phase
  1357. :param phase: Phase, controls what methods should be returned.
  1358. :return: ContextSgMethods holding methods from self.
  1359. """
  1360. if phase in [Phase.PRE_TRAINING, Phase.TRAIN_EPOCH_START, Phase.TRAIN_EPOCH_END, Phase.VALIDATION_EPOCH_END,
  1361. Phase.VALIDATION_EPOCH_END, Phase.POST_TRAINING, Phase.VALIDATION_END_BEST_EPOCH]:
  1362. context_methods = ContextSgMethods(get_net=self.get_net,
  1363. set_net=self.set_net,
  1364. set_ckpt_best_name=self.set_ckpt_best_name,
  1365. reset_best_metric=self._reset_best_metric,
  1366. build_model=self.build_model,
  1367. validate_epoch=self._validate_epoch,
  1368. set_ema=self.set_ema)
  1369. else:
  1370. context_methods = ContextSgMethods()
  1371. return context_methods
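These context methods are what user phase callbacks receive. A hedged sketch of a callback that renames the best checkpoint at the end of training; the PhaseCallback constructor signature, the import path, and the context_methods attribute are assumed from how they are wired above, not verified against the rest of the library:

```python
from super_gradients.training.utils.callbacks import Phase, PhaseCallback  # module path assumed

class RenameBestCkptCallback(PhaseCallback):
    def __init__(self):
        super().__init__(phase=Phase.POST_TRAINING)

    def __call__(self, context):
        # The ContextSgMethods built in _get_context_methods is exposed on the context.
        context.context_methods.set_ckpt_best_name('finetune_ckpt_best.pth')
```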
  1. from omegaconf import DictConfig
  2. import hydra
  3. class Trainer:
  4. """
5. Class for running SuperGradients' recipes.
6. See the train_from_recipe example in the examples directory for a demonstration of its usage.
  7. """
  8. @classmethod
  9. def train(cls, cfg: DictConfig) -> None:
  10. """
  11. Trains according to cfg recipe configuration.
  12. @param cfg: The parsed DictConfig from yaml recipe files
  13. @return: output of sg_model.train(...) (i.e results tuple)
  14. """
  15. # INSTANTIATE ALL OBJECTS IN CFG
  16. cfg = hydra.utils.instantiate(cfg)
  17. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
  18. cfg.sg_model.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
  19. # BUILD NETWORK
  20. cls.build_model(cfg)
  21. # TRAIN
  22. cfg.sg_model.train(training_params=cfg.training_hyperparams)
  23. @classmethod
  24. def build_model(cls, cfg):
  25. cfg.sg_model.build_model(cfg.architecture, arch_params=cfg.arch_params, checkpoint_params=cfg.checkpoint_params)
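A hedged sketch of how this class is typically invoked from a Hydra entry point, along the lines of the train_from_recipe example the docstring mentions. The config path, recipe name, and import path of this recipe-runner Trainer are placeholders/assumptions:

```python
import hydra
import super_gradients
from omegaconf import DictConfig
from super_gradients import Trainer  # import path assumed; adjust to where this recipe-runner lives

@hydra.main(config_path="recipes", config_name="cifar10_resnet")  # placeholder recipe
def main(cfg: DictConfig) -> None:
    Trainer.train(cfg)

if __name__ == "__main__":
    super_gradients.init_trainer()
    main()
```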
@@ -44,7 +44,7 @@ class Phase(Enum):

 class ContextSgMethods:
     """
-    Class for delegating SgModel's methods, so that only the relevant ones are ("phase wise") are accessible.
+    Class for delegating Trainer's methods, so that only the relevant ones are ("phase wise") are accessible.
     """
     def __init__(self, **methods):
         for attr, attr_val in methods.items():
@@ -22,7 +22,7 @@ def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt
         - external_checkpoint_path when external_checkpoint_path != None

     @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
-    @param experiment_name: experiment name attr in sg_model
+    @param experiment_name: experiment name attr in trainer
     @param ckpt_name: checkpoint filename
     @param model_checkpoints_location: S3, local ot URL
     @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
@@ -72,8 +72,8 @@ def scaled_all_reduce(tensors: torch.Tensor, num_gpus: int):
 @torch.no_grad()
 def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
     '''
-    :param model:                   The model being trained (ie: SgModel.net)
-    :param loader:                  Training dataloader (ie: SgModel.train_loader)
+    :param model:                   The model being trained (ie: Trainer.net)
+    :param loader:                  Training dataloader (ie: Trainer.train_loader)
     :param precise_bn_batch_size:   The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
                                     on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
                                     (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
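The rule of thumb in that docstring is plain arithmetic; spelled out for the example it gives (a per-GPU batch of 128 on 8 GPUs):

```python
batch_per_gpu, num_gpus = 128, 8
effective_batch_size = batch_per_gpu * num_gpus            # 1024
precise_bn_batch_size = effective_batch_size * num_gpus    # 8192, the value suggested above
```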
@@ -88,7 +88,7 @@ def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.
     :param method:              str, One of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules
                                 (Default=percentile).
     :param num_calib_batches:   int, number of batches to collect the statistics from.
-    :param percentile:          float, percentile value to use when SgModel,quant_modules_calib_method='percentile'. Discarded when other methods are used
+    :param percentile:          float, percentile value to use when Trainer,quant_modules_calib_method='percentile'. Discarded when other methods are used
                                 (Default=99.99).

     """
@@ -201,7 +201,7 @@ class QATCallback(PhaseCallback):
         1. loads the best checkpoint then performs calibration.
         2. loads an external calibrated model (makes sense when start_epoch=0).

-    Additionally, resets SgModel's best_metric and sets ckpt_best_name to 'qat_ckpt_best.pth' so best QAT checkpoints
+    Additionally, resets Trainer's best_metric and sets ckpt_best_name to 'qat_ckpt_best.pth' so best QAT checkpoints
     will be saved separately.

    If performing calibration- the calibrated model is evaluated, and the metric_to_watch is logged under
@@ -225,7 +225,7 @@ class QATCallback(PhaseCallback):

         num_calib_batches: int, number of batches to collect the statistics from.

-        percentile: float, percentile value to use when SgModel,quant_modules_calib_method='percentile'.
+        percentile: float, percentile value to use when Trainer,quant_modules_calib_method='percentile'.
         Discarded when other methods are used (Default=99.99).

@@ -5,9 +5,12 @@ import time
 from dataclasses import dataclass
 from multiprocessing import Process
 from pathlib import Path
-from typing import Tuple, Union, Dict
+from typing import Tuple, Union, Dict, List, Sequence
 import random

+import inspect
+
+from super_gradients.common.abstractions.abstract_logger import get_logger
 from treelib import Tree
 from termcolor import colored
 import torch
@@ -16,11 +19,12 @@ from torch.utils.tensorboard import SummaryWriter
 from super_gradients.training.exceptions.dataset_exceptions import UnsupportedBatchItemsFormat


-# TODO: These utils should move to sg_model package as internal (private) helper functions
+# TODO: These utils should move to sg_trainer package as internal (private) helper functions

 IS_BETTER_COLOR = {True: "green", False: "red"}
 IS_GREATER_SYMBOLS = {True: "↗", False: "↘"}

+logger = get_logger(__name__)

 @dataclass
 class MonitoredValue:
@@ -329,3 +333,18 @@ def log_uncaught_exceptions(logger):
         logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))

     sys.excepthook = handle_exception
+
+
+def parse_args(cfg, arg_names: Union[List[str], callable]) -> dict:
+    """
+    parse args from a config.
+    unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature
+    """
+    if not isinstance(arg_names, Sequence):
+        arg_names = list(inspect.signature(arg_names).parameters.keys())
+
+    kwargs_dict = {}
+    for arg_name in arg_names:
+        if hasattr(cfg, arg_name) and getattr(cfg, arg_name) is not None:
+            kwargs_dict[arg_name] = getattr(cfg, arg_name)
+    return kwargs_dict
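A small usage sketch of the parse_args helper added here. The config object and target function are made up for illustration, and the helper is assumed to be in scope (or imported from wherever this utils module lives):

```python
from types import SimpleNamespace

# assumes parse_args from the diff above is in scope
def build_optimizer(lr=0.1, momentum=0.9, weight_decay=None):
    return {"lr": lr, "momentum": momentum, "weight_decay": weight_decay}

cfg = SimpleNamespace(lr=0.01, weight_decay=1e-4, unrelated_key=42)

# Only parameters that both appear in the function's signature and are set (non-None) in cfg are kept.
kwargs = parse_args(cfg, build_optimizer)   # {'lr': 0.01, 'weight_decay': 0.0001}
optimizer_cfg = build_optimizer(**kwargs)
```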
@@ -14,7 +14,7 @@ from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInit
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.lr_warmup_test import LRWarmupTest
 from tests.unit_tests.kd_ema_test import KDEMATest
-from tests.unit_tests.kd_model_test import KDModelTest
+from tests.unit_tests.kd_trainer_test import KDTrainerTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.iou_loss_test import IoULossTest
 from tests.unit_tests.update_param_groups_unit_test import UpdateParamGroupsTest
@@ -64,7 +64,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DiceLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestViT))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDEMATest))
-        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDTrainerTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLOX))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LRCooldownTest))
@@ -2,18 +2,18 @@ import unittest

 import super_gradients

-from super_gradients import SgModel
+from super_gradients import Trainer
 from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface


 class TestCifar10Trainer(unittest.TestCase):
     def test_train_cifar10(self):
         super_gradients.init_trainer()
-        model = SgModel("test", model_checkpoints_location='local')
+        trainer = Trainer("test", model_checkpoints_location='local')
         cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
-        model.connect_dataset_interface(cifar_10_dataset_interface)
-        model.build_model("resnet18_cifar", arch_params={'num_classes': 10})
-        model.train(training_params={"max_epochs": 1})
+        trainer.connect_dataset_interface(cifar_10_dataset_interface)
+        trainer.build_model("resnet18_cifar", arch_params={'num_classes': 10})
+        trainer.train(training_params={"max_epochs": 1})


 if __name__ == '__main__':
@@ -5,7 +5,7 @@ import numpy as np
 from PIL import Image
 import tensorflow.keras as keras
 from super_gradients.training import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface, \
     ImageNetDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
@@ -123,12 +123,12 @@ class TestExternalDatasetInterface(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True}

         arch_params = {'num_classes': 1000}
-        model = SgModel("test", model_checkpoints_location='local',
-                        multi_gpu=MultiGPUMode.OFF)
-        model.connect_dataset_interface(dataset_interface=self.test_external_dataset_interface,
-                                        data_loader_num_workers=8)
-        model.build_model("resnet50", arch_params)
-        model.train(training_params=train_params)
+        trainer = Trainer("test", model_checkpoints_location='local',
+                          multi_gpu=MultiGPUMode.OFF)
+        trainer.connect_dataset_interface(dataset_interface=self.test_external_dataset_interface,
+                                          data_loader_num_workers=8)
+        trainer.build_model("resnet50", arch_params)
+        trainer.train(training_params=train_params)


 if __name__ == '__main__':
@@ -4,7 +4,7 @@ import unittest
 import super_gradients
 import torch
 import os
-from super_gradients import SgModel, ClassificationTestDatasetInterface
+from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5


@@ -35,73 +35,65 @@ class TestTrainer(unittest.TestCase):

     @staticmethod
     def get_classification_trainer(name=''):
-        model = SgModel(name, model_checkpoints_location='local')
+        trainer = Trainer(name, model_checkpoints_location='local')
         dataset_params = {"batch_size": 4}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
-        model.build_model("resnet18_cifar")
-        return model
+        trainer.connect_dataset_interface(dataset)
+        trainer.build_model("resnet18_cifar")
+        return trainer

     def test_train(self):
-        model = self.get_classification_trainer(self.folder_names[0])
-        model.train(training_params=self.training_params)
+        trainer = self.get_classification_trainer(self.folder_names[0])
+        trainer.train(training_params=self.training_params)

     def test_save_load(self):
-        model = self.get_classification_trainer(self.folder_names[1])
-        model.train(training_params=self.training_params)
-        model.build_model("resnet18_cifar", checkpoint_params={'load_checkpoint': True})
+        trainer = self.get_classification_trainer(self.folder_names[1])
+        trainer.train(training_params=self.training_params)
+        trainer.build_model("resnet18_cifar", checkpoint_params={'load_checkpoint': True})

     def test_load_only_weights_from_ckpt(self):
         # Create a checkpoint with 100% accuracy
-        model = self.get_classification_trainer(self.folder_names[2])
+        trainer = self.get_classification_trainer(self.folder_names[2])
         params = self.training_params.copy()

         params['max_epochs'] = 3
-        model.train(training_params=params)
+        trainer.train(training_params=params)
         # Build a model that continues the training
-        model = self.get_classification_trainer(self.folder_names[3])
-        model.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": False,
-                                                               "source_ckpt_folder_name": self.folder_names[2]}
-                          )
-        self.assertTrue(model.best_metric > -1)
-        self.assertTrue(model.start_epoch != 0)
+        trainer = self.get_classification_trainer(self.folder_names[3])
+        trainer.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": False,
+                                                                 "source_ckpt_folder_name": self.folder_names[2]}
+                            )
+        self.assertTrue(trainer.best_metric > -1)
+        self.assertTrue(trainer.start_epoch != 0)
         # start_epoch is not initialized, adding to max_epochs
         self.training_params['max_epochs'] += 3
-        model.train(training_params=self.training_params)
+        trainer.train(training_params=self.training_params)
         # Build a model that loads the weights and starts from scratch
-        model = self.get_classification_trainer(self.folder_names[4])
-        model.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": True,
-                                                               "source_ckpt_folder_name": self.folder_names[2]}
-                          )
-        self.assertTrue(model.best_metric == -1)
-        self.assertTrue(model.start_epoch == 0)
+        trainer = self.get_classification_trainer(self.folder_names[4])
+        trainer.build_model('resnet18_cifar', checkpoint_params={"load_checkpoint": True, "load_weights_only": True,
+                                                                 "source_ckpt_folder_name": self.folder_names[2]}
+                            )
+        self.assertTrue(trainer.best_metric == -1)
+        self.assertTrue(trainer.start_epoch == 0)
         self.training_params['max_epochs'] += 3
-        model.train(training_params=self.training_params)
+        trainer.train(training_params=self.training_params)

     def test_checkpoint_content(self):
         """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
-        model = self.get_classification_trainer(self.folder_names[5])
+        trainer = self.get_classification_trainer(self.folder_names[5])
         params = self.training_params.copy()
         params["save_ckpt_epoch_list"] = [1]
-        model.train(training_params=params)
+        trainer.train(training_params=params)
         ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
-        ckpt_paths = [os.path.join(model.checkpoints_dir_path, suf) for suf in ckpt_filename]
+        ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
         for ckpt_path in ckpt_paths:
             ckpt = torch.load(ckpt_path)
             self.assertListEqual(['net', 'acc', 'epoch', 'optimizer_state_dict', 'scaler_state_dict'],
                                  list(ckpt.keys()))
-        model._save_checkpoint()
-        weights_only = torch.load(os.path.join(model.checkpoints_dir_path, 'ckpt_latest_weights_only.pth'))
+        trainer._save_checkpoint()
+        weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, 'ckpt_latest_weights_only.pth'))
         self.assertListEqual(['net'], list(weights_only.keys()))

-    def test_predict(self):
-        model = self.get_classification_trainer(self.folder_names[6])
-        inputs = torch.randn((5, 3, 32, 32))
-        targets = torch.randint(0, 5, (5, 1))
-        model.predict(inputs=inputs, targets=targets)
-        model.predict(inputs=inputs, targets=targets, half=True)
-        model.predict(inputs=inputs, targets=targets, half=False, verbose=True)
-

 if __name__ == '__main__':
     unittest.main()
@@ -3,7 +3,7 @@ from enum import Enum
 import re

 from super_gradients import (
-    SgModel,
+    Trainer,
     ClassificationTestDatasetInterface,
     SegmentationTestDatasetInterface,
 )
@@ -68,13 +68,13 @@ class ConversionCallbackTest(unittest.TestCase):
                 "phase_callbacks": phase_callbacks,
             }

-            model = SgModel(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
+            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
             dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})

-            model.connect_dataset_interface(dataset, data_loader_num_workers=0)
-            model.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
+            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
+            trainer.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             try:
-                model.train(train_params)
+                trainer.train(train_params)
             except Exception as e:
                 self.fail(f"Model training didn't succeed due to {e}")
             else:
@@ -103,9 +103,9 @@ class ConversionCallbackTest(unittest.TestCase):
         for architecture in SEMANTIC_SEGMENTATION:
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
             dataset = SegmentationTestDatasetInterface(dataset_params={"batch_size": 10})
-            model = SgModel(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
-            model.connect_dataset_interface(dataset, data_loader_num_workers=0)
-            model.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
+            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
+            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
+            trainer.build_model(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})

             phase_callbacks = [
                 ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
@@ -129,7 +129,7 @@ class ConversionCallbackTest(unittest.TestCase):
             train_params.update(custom_config)

             try:
-                model.train(train_params)
+                trainer.train(train_params)
             except Exception as e:
                 self.fail(f"Model training didn't succeed for {architecture} due to {e}")
             else:
@@ -1,5 +1,5 @@
 import unittest
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
     ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.models import ResNet18
@@ -10,12 +10,12 @@ from deci_lab_client.models import Metric, QuantizationLevel, ModelMetadata, Opt

 class DeciLabUploadTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.model = SgModel("deci_lab_export_test_model", model_checkpoints_location='local')
+        self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
         dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
-        self.model.connect_dataset_interface(dataset)
+        self.trainer.connect_dataset_interface(dataset)
         net = ResNet18(num_classes=5, arch_params={})
         self.optimizer = SGD(params=net.parameters(), lr=0.1)
-        self.model.build_model(net)
+        self.trainer.build_model(net)

     def test_train_with_deci_lab_integration(self):
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
@@ -50,7 +50,7 @@ class DeciLabUploadTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True,
                         "phase_callbacks": [model_conversion_callback, deci_lab_callback]}

-        self.model.train(train_params)
+        self.trainer.train(train_params)

         # CLEANUP

@@ -1,6 +1,6 @@
 from super_gradients import ClassificationTestDatasetInterface
 from super_gradients.training import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 import unittest

@@ -23,11 +23,11 @@ class CallWrapper:
 class EMAIntegrationTest(unittest.TestCase):

     def _init_model(self) -> None:
-        self.model = SgModel("resnet18_cifar_ema_test", model_checkpoints_location='local',
-                             device='cpu', multi_gpu=MultiGPUMode.OFF)
+        self.trainer = Trainer("resnet18_cifar_ema_test", model_checkpoints_location='local',
+                               device='cpu', multi_gpu=MultiGPUMode.OFF)
         dataset_interface = ClassificationTestDatasetInterface({"batch_size": 32})
-        self.model.connect_dataset_interface(dataset_interface, 8)
-        self.model.build_model("resnet18_cifar")
+        self.trainer.connect_dataset_interface(dataset_interface, 8)
+        self.trainer.build_model("resnet18_cifar")

     @classmethod
     def tearDownClass(cls) -> None:
@@ -57,17 +57,17 @@ class EMAIntegrationTest(unittest.TestCase):
                            "greater_metric_to_watch_is_better": True}

         def before_test():
-            self.assertEqual(self.model.net, self.model.ema_model.ema)
+            self.assertEqual(self.trainer.net, self.trainer.ema_model.ema)

         def before_train_epoch():
-            self.assertNotEqual(self.model.net, self.model.ema_model.ema)
+            self.assertNotEqual(self.trainer.net, self.trainer.ema_model.ema)

-        self.model.test = CallWrapper(self.model.test, check_before=before_test)
-        self.model._train_epoch = CallWrapper(self.model._train_epoch, check_before=before_train_epoch)
+        self.trainer.test = CallWrapper(self.trainer.test, check_before=before_test)
+        self.trainer._train_epoch = CallWrapper(self.trainer._train_epoch, check_before=before_train_epoch)

-        self.model.train(training_params=training_params)
+        self.trainer.train(training_params=training_params)

-        self.assertIsNotNone(self.model.ema_model)
+        self.assertIsNotNone(self.trainer.ema_model)


 if __name__ == '__main__':
@@ -1,7 +1,7 @@
 import shutil
 import unittest
 import os
-from super_gradients import SgModel, ClassificationTestDatasetInterface
+from super_gradients import Trainer, ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5


@@ -26,37 +26,37 @@ class LRTest(unittest.TestCase):

     @staticmethod
     def get_trainer(name=''):
-        model = SgModel(name, model_checkpoints_location='local')
+        trainer = Trainer(name, model_checkpoints_location='local')
         dataset_params = {"batch_size": 4}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
-        model.build_model("resnet18_cifar")
-        return model
+        trainer.connect_dataset_interface(dataset)
+        trainer.build_model("resnet18_cifar")
+        return trainer

     def test_function_lr(self):
-        model = self.get_trainer(self.folder_name)
+        trainer = self.get_trainer(self.folder_name)

         def test_lr_function(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwargs):
             return initial_lr * (1 - ((epoch * iters_per_epoch + iter) / (max_epoch * iters_per_epoch)))

         # test if we are able that lr_function supports functions with this structure
         training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
-        model.train(training_params=training_params)
+        trainer.train(training_params=training_params)

         # test that we assert lr_function is callable
         training_params = {**self.training_params, "lr_mode": "function"}
         with self.assertRaises(AssertionError):
-            model.train(training_params=training_params)
+            trainer.train(training_params=training_params)

     def test_cosine_lr(self):
-        model = self.get_trainer(self.folder_name)
+        trainer = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
-        model.train(training_params=training_params)
+        trainer.train(training_params=training_params)

     def test_step_lr(self):
-        model = self.get_trainer(self.folder_name)
+        trainer = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
-        model.train(training_params=training_params)
+        trainer.train(training_params=training_params)


 if __name__ == '__main__':
@@ -1,7 +1,7 @@
 import unittest
 import super_gradients
 from super_gradients.training import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface, \
     ClassificationTestDatasetInterface, CityscapesDatasetInterface, SegmentationTestDatasetInterface, \
     CoCoSegmentationDatasetInterface, DetectionTestDatasetInterface
@@ -324,7 +324,7 @@ class PretrainedModelsTest(unittest.TestCase):
         }

     def test_pretrained_resnet50_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet50', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -334,7 +334,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)

     def test_transfer_learning_resnet50_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -342,7 +342,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_resnet34_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet34', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -352,7 +352,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)

     def test_transfer_learning_resnet34_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -360,7 +360,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_resnet18_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet18', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -370,7 +370,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)

     def test_transfer_learning_resnet18_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -378,7 +378,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_regnetY800_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -388,7 +388,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)

     def test_transfer_learning_regnetY800_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -396,7 +396,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_regnetY600_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -406,7 +406,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)

     def test_transfer_learning_regnetY600_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -414,7 +414,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_regnetY400_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -424,7 +424,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)

     def test_transfer_learning_regnetY400_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -432,7 +432,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_regnetY200_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -442,7 +442,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)

     def test_transfer_learning_regnetY200_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -450,7 +450,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
@@ -460,7 +460,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)

     def test_transfer_learning_repvgg_a0_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
@@ -468,7 +468,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_regseg48_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.build_model("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
@@ -479,7 +479,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)

     def test_transfer_learning_regseg48_cityscapes(self):
-        trainer = SgModel('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
@@ -487,7 +487,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.regseg_transfer_segmentation_train_params)

     def test_pretrained_ddrnet23_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.build_model("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
@@ -498,7 +498,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)

     def test_pretrained_ddrnet23_slim_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
         trainer.build_model("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
@@ -509,7 +509,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)

     def test_transfer_learning_ddrnet23_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
@@ -517,7 +517,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.ddrnet_transfer_segmentation_train_params)

     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
@@ -525,7 +525,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.ddrnet_transfer_segmentation_train_params)

     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
-        trainer = SgModel('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
+        trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("shelfnet34_lw",
@@ -536,7 +536,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)

     def test_pretrained_efficientnet_b0_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -546,7 +546,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)

     def test_transfer_learning_efficientnet_b0_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
         trainer.build_model("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -554,7 +554,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
-        trainer = SgModel('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
         trainer.build_model("ssd_lite_mobilenet_v2",
@@ -568,7 +568,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)

     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
-        trainer = SgModel('coco_ssd_lite_mobilenet_v2_transfer_learning',
+        trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
                           model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_detection_dataset,
                                           data_loader_num_workers=8)
@@ -580,7 +580,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_detection_train_params['ssd_lite_mobilenet_v2'])

     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = SgModel('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
         trainer.build_model("ssd_mobilenet_v1",
@@ -595,7 +595,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)

     def test_pretrained_yolox_s_coco(self):
-        trainer = SgModel('yolox_s', model_checkpoints_location='local',
+        trainer = Trainer('yolox_s', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.build_model("yolox_s",
@@ -607,7 +607,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)

     def test_pretrained_yolox_m_coco(self):
-        trainer = SgModel('yolox_m', model_checkpoints_location='local',
+        trainer = Trainer('yolox_m', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.build_model("yolox_m",
@@ -619,7 +619,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)

     def test_pretrained_yolox_l_coco(self):
-        trainer = SgModel('yolox_l', model_checkpoints_location='local',
+        trainer = Trainer('yolox_l', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.build_model("yolox_l",
@@ -631,7 +631,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)

     def test_pretrained_yolox_n_coco(self):
-        trainer = SgModel('yolox_n', model_checkpoints_location='local',
+        trainer = Trainer('yolox_n', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.build_model("yolox_n",
@@ -643,7 +643,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)

     def test_pretrained_yolox_t_coco(self):
-        trainer = SgModel('yolox_t', model_checkpoints_location='local',
+        trainer = Trainer('yolox_t', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
         trainer.build_model("yolox_t",
@@ -655,7 +655,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_t"], delta=0.001)

     def test_transfer_learning_yolox_n_coco(self):
-        trainer = SgModel('test_transfer_learning_yolox_n_coco',
+        trainer = Trainer('test_transfer_learning_yolox_n_coco',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_detection_dataset, data_loader_num_workers=8)
@@ -663,7 +663,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_detection_train_params["yolox"])

     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
+        trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
@@ -672,7 +672,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_mobilenet_v3_large_imagenet(self):
-        trainer = SgModel('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -682,7 +682,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)

     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
+        trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
@@ -691,7 +691,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_mobilenet_v3_small_imagenet(self):
-        trainer = SgModel('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -701,7 +701,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)

     def test_transfer_learning_mobilenet_v2_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_mobilenet_v2_transfer_learning',
+        trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
@@ -710,7 +710,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_mobilenet_v2_imagenet(self):
-        trainer = SgModel('imagenet_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
         trainer.build_model("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -720,7 +720,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)

     def test_pretrained_stdc1_seg50_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
         trainer.build_model("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -731,7 +731,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)

     def test_transfer_learning_stdc1_seg50_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -739,7 +739,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.stdc_transfer_segmentation_train_params)

     def test_pretrained_stdc1_seg75_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
         trainer.build_model("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -750,7 +750,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)

     def test_transfer_learning_stdc1_seg75_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -758,7 +758,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.stdc_transfer_segmentation_train_params)

     def test_pretrained_stdc2_seg50_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
         trainer.build_model("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -769,7 +769,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)

     def test_transfer_learning_stdc2_seg50_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -777,7 +777,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.stdc_transfer_segmentation_train_params)

     def test_pretrained_stdc2_seg75_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
         trainer.build_model("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -788,7 +788,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)

     def test_transfer_learning_stdc2_seg75_cityscapes(self):
-        trainer = SgModel('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
         trainer.build_model("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
@@ -796,7 +796,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.stdc_transfer_segmentation_train_params)

     def test_transfer_learning_vit_base_imagenet21k(self):
-        trainer = SgModel('imagenet21k_pretrained_vit_base',
+        trainer = Trainer('imagenet21k_pretrained_vit_base',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
@@ -805,7 +805,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_transfer_learning_vit_large_imagenet21k(self):
-        trainer = SgModel('imagenet21k_pretrained_vit_large',
+        trainer = Trainer('imagenet21k_pretrained_vit_large',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
@@ -814,7 +814,7 @@ class PretrainedModelsTest(unittest.TestCase):
         trainer.train(training_params=self.transfer_classification_train_params)

     def test_pretrained_vit_base_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_vit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.build_model("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -824,7 +824,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)

     def test_pretrained_vit_large_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_vit_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.build_model("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -834,7 +834,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)

     def test_pretrained_beit_base_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_beit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_beit_base', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
         trainer.build_model("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -844,7 +844,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)

     def test_transfer_learning_beit_base_imagenet(self):
-        trainer = SgModel('test_transfer_learning_beit_base_imagenet',
+        trainer = Trainer('test_transfer_learning_beit_base_imagenet',
                           model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
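Every test in PretrainedModelsTest follows the same recipe, which after the rename reads roughly like this (a sketch assembled only from calls visible in this diff; the experiment name is invented, the tiny test dataset and batch size are borrowed from the QAT test below, and the evaluation step is summarized in the trailing comment rather than spelled out):

from super_gradients.training import Trainer, MultiGPUMode
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface

# Hypothetical experiment name; each test pins checkpoints locally and disables multi-GPU.
trainer = Trainer('pretrained_resnet18_sanity_check',
                  model_checkpoints_location='local',
                  multi_gpu=MultiGPUMode.OFF)

# Attach a dataset interface, then build the architecture with its pretrained weights.
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}),
                                  data_loader_num_workers=8)
trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})

# The "pretrained" tests then evaluate and compare the result against a stored reference
# metric with assertAlmostEqual(..., delta=0.001); the "transfer_learning" tests instead
# call trainer.train(training_params=...) on the corresponding transfer dataset.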
@@ -1,7 +1,7 @@
 import unittest

 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
-from super_gradients.training import SgModel, MultiGPUMode
+from super_gradients.training import Trainer, MultiGPUMode
 from super_gradients.training.metrics.classification_metrics import Accuracy
 import os
 from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
@@ -11,12 +11,12 @@ class QATIntegrationTest(unittest.TestCase):
     def _get_trainer(self, experiment_name):
         dataset_params = {"batch_size": 10}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model = SgModel(experiment_name,
-                        model_checkpoints_location='local',
-                        multi_gpu=MultiGPUMode.OFF)
-        model.connect_dataset_interface(dataset)
-        model.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
-        return model
+        trainer = Trainer(experiment_name,
+                          model_checkpoints_location='local',
+                          multi_gpu=MultiGPUMode.OFF)
+        trainer.connect_dataset_interface(dataset)
+        trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
+        return trainer

     def _get_train_params(self, qat_params):
         train_params = {"max_epochs": 2,
@@ -38,7 +38,7 @@ class QATIntegrationTest(unittest.TestCase):
         return train_params
         return train_params
 
 
     def test_qat_from_start(self):
     def test_qat_from_start(self):
-        model = self._get_trainer("test_qat_from_start")
+        trainer = self._get_trainer("test_qat_from_start")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -47,10 +47,10 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        model.train(training_params=train_params)
+        trainer.train(training_params=train_params)
 
 
     def test_qat_transition(self):
     def test_qat_transition(self):
-        model = self._get_trainer("test_qat_transition")
+        trainer = self._get_trainer("test_qat_transition")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 1,
             "start_epoch": 1,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -59,10 +59,10 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        model.train(training_params=train_params)
+        trainer.train(training_params=train_params)
 
 
     def test_qat_from_calibrated_ckpt(self):
     def test_qat_from_calibrated_ckpt(self):
-        model = self._get_trainer("generate_calibrated_model")
+        trainer = self._get_trainer("generate_calibrated_model")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -71,11 +71,11 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        model.train(training_params=train_params)
+        trainer.train(training_params=train_params)
 
 
-        calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
+        calibrated_model_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
 
 
-        model = self._get_trainer("test_qat_from_calibrated_ckpt")
+        trainer = self._get_trainer("test_qat_from_calibrated_ckpt")
         train_params = self._get_train_params(qat_params={
         train_params = self._get_train_params(qat_params={
             "start_epoch": 0,
             "start_epoch": 0,
             "quant_modules_calib_method": "percentile",
             "quant_modules_calib_method": "percentile",
@@ -85,7 +85,7 @@ class QATIntegrationTest(unittest.TestCase):
             "percentile": 99.99
             "percentile": 99.99
         })
         })
 
 
-        model.train(training_params=train_params)
+        trainer.train(training_params=train_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
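For reference, a minimal sketch of the renamed classification training flow used throughout these tests. The experiment name and the training-parameter values are illustrative; only classes, arguments and keys that appear in this diff are used.

```python
from super_gradients.training import Trainer, MultiGPUMode
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy

# Previously: trainer = SgModel("qat_example", model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
trainer = Trainer("qat_example", model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}))
trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})

# Illustrative training params; the QAT tests above build theirs via _get_train_params (not shown here).
train_params = {"max_epochs": 2, "lr_mode": "step", "lr_updates": [1], "lr_decay_factor": 0.1,
                "initial_lr": 0.01, "loss": "cross_entropy", "optimizer": "SGD",
                "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True}
trainer.train(training_params=train_params)
```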
@@ -1,6 +1,6 @@
 import torch
 from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
-from super_gradients.training.sg_model import SgModel
+from super_gradients.training.sg_trainer import Trainer
 from torchvision.models import resnet18
 import numpy as np

@@ -25,13 +25,13 @@ class TestDatasetInterface(DatasetInterface):
 # ------------------ Loading The Model From Model.py----------------
 arch_params = {'num_classes': 1000}
 model = resnet18()
-sg_classification_model = SgModel('Client_model_training',
-                                  model_checkpoints_location='local', device='cpu')
+trainer = Trainer('Client_model_training',
+                  model_checkpoints_location='local', device='cpu')
 # if a torch.nn.Module is provided when building the model, the model will be integrated into deci model class
-sg_classification_model.build_model(model, arch_params=arch_params)
+trainer.build_model(model, arch_params=arch_params)
 # ------------------ Loading The Dataset From Dataset.py----------------
 dataset = TestDatasetInterface()
-sg_classification_model.connect_dataset_interface(dataset)
+trainer.connect_dataset_interface(dataset)
 # ------------------ Loading The Loss From Loss.py -----------------
 loss = 'cross_entropy'
 # ------------------ Training -----------------
@@ -40,4 +40,4 @@ train_params = {"max_epochs": 100,
                 "lr_updates": [30, 60, 90, 100],
                 "lr_decay_factor": 0.1,
                 "initial_lr": 0.025, "loss": loss}
-sg_classification_model.train(train_params)
+trainer.train(train_params)
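A condensed sketch of the same external-model flow after the rename: any `torch.nn.Module` can be passed to `Trainer.build_model`, as the example above does. The experiment name and dataset interface are illustrative; the training params mirror the values in the file.

```python
from torchvision.models import resnet18
from super_gradients.training.sg_trainer import Trainer  # was super_gradients.training.sg_model.SgModel
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface

trainer = Trainer('external_module_example', model_checkpoints_location='local', device='cpu')
trainer.build_model(resnet18(), arch_params={'num_classes': 1000})  # nn.Module is wrapped internally
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}))

train_params = {"max_epochs": 100,
                "lr_updates": [30, 60, 90, 100],
                "lr_decay_factor": 0.1,
                "initial_lr": 0.025, "loss": "cross_entropy"}
trainer.train(train_params)
```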
@@ -6,7 +6,7 @@ from super_gradients.training.transforms.transforms import DetectionPaddedRescal
     DetectionHSV
 from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
 from super_gradients.training.utils.detection_utils import DetectionCollateFN
-from super_gradients.training.utils import sg_model_utils
+from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training import utils as core_utils


@@ -116,7 +116,7 @@ class TestDatasetInterface(unittest.TestCase):
             batch_items = next(iter(loader))
             batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)

-            inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
+            inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
             self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))

     def test_pascal_voc(self):
@@ -131,7 +131,7 @@ class TestDatasetInterface(unittest.TestCase):
             batch_items = next(iter(loader))
             batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)

-            inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
+            inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
             self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))

@@ -3,7 +3,7 @@ import unittest
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics

-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
     DetectionTargetsFormat
@@ -53,11 +53,11 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                                                                 "area_thr": 0
                                                                 })

-        model = SgModel('dataset_statistics_visual_test',
-                        model_checkpoints_location='local',
-                        post_prediction_callback=YoloPostPredictionCallback())
-        model.connect_dataset_interface(dataset, data_loader_num_workers=8)
-        model.build_model("yolox_s")
+        trainer = Trainer('dataset_statistics_visual_test',
+                          model_checkpoints_location='local',
+                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
+        trainer.build_model("yolox_s")

         training_params = {"max_epochs": 1,  # we dont really need the actual training to run
                            "lr_mode": "cosine",
@@ -74,7 +74,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                            "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                            "metric_to_watch": "mAP@0.50:0.95",
                            }
-        model.train(training_params=training_params)
+        trainer.train(training_params=training_params)


 if __name__ == '__main__':
@@ -1,7 +1,7 @@
 import os
 import unittest

-from super_gradients.training import SgModel, utils as core_utils
+from super_gradients.training import Trainer, utils as core_utils
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
@@ -46,27 +46,27 @@ class TestDetectionUtils(unittest.TestCase):
                                                                 })

         # Create Yolo model
-        model = SgModel('visualization_test',
-                        model_checkpoints_location='local',
-                        post_prediction_callback=YoloPostPredictionCallback())
-        model.connect_dataset_interface(dataset, data_loader_num_workers=8)
-        model.build_model("yolox_n", checkpoint_params={"pretrained_weights": "coco"})
+        trainer = Trainer('visualization_test',
+                          model_checkpoints_location='local',
+                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
+        trainer.build_model("yolox_n", checkpoint_params={"pretrained_weights": "coco"})

         # Simulate one iteration of validation subset
-        valid_loader = model.valid_loader
+        valid_loader = trainer.valid_loader
         batch_i, (imgs, targets) = 0, next(iter(valid_loader))
-        imgs = core_utils.tensor_container_to_device(imgs, model.device)
-        targets = core_utils.tensor_container_to_device(targets, model.device)
-        output = model.net(imgs)
-        output = model.post_prediction_callback(output)
+        imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
+        targets = core_utils.tensor_container_to_device(targets, trainer.device)
+        output = trainer.net(imgs)
+        output = trainer.post_prediction_callback(output)
         # Visualize the batch
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
-                                               COCO_DETECTION_CLASSES_LIST, model.checkpoints_dir_path)
+                                               COCO_DETECTION_CLASSES_LIST, trainer.checkpoints_dir_path)

         # Assert images ware created and delete them
         img_name = '{}/{}_{}.jpg'
         for i in range(4):
-            img_path = img_name.format(model.checkpoints_dir_path, batch_i, i)
+            img_path = img_name.format(trainer.checkpoints_dir_path, batch_i, i)
             self.assertTrue(os.path.exists(img_path))
             os.remove(img_path)

@@ -4,7 +4,7 @@ import unittest

 from super_gradients.training.utils.early_stopping import EarlyStop
 from super_gradients.training.utils.callbacks import Phase
-from super_gradients.training.sg_model import SgModel
+from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.datasets.dataset_interfaces import ClassificationTestDatasetInterface
 from super_gradients.training.models.classification_models.resnet import ResNet18
 from super_gradients.training.metrics import Accuracy, Top5
@@ -60,9 +60,9 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=min metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -72,7 +72,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 5

@@ -84,9 +84,9 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=max metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                    verbose=True)
@@ -98,7 +98,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 6

@@ -108,9 +108,9 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
         """
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", threshold=0.1, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -120,7 +120,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 5
         # count divided by 2, because loss counter used for both train and eval.
@@ -130,9 +130,9 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
         """
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                    verbose=True)
@@ -144,7 +144,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 7

@@ -155,9 +155,9 @@ class EarlyStopTest(unittest.TestCase):
         Test that training stops when monitor value is not a finite number. Test case of NaN and Inf values.
         """
         # test Nan value
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", check_finite=True,
                                     verbose=True)
@@ -168,16 +168,16 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 2

         self.assertEqual(excepted_end_epoch, fake_loss.count // 2)

         # test Inf value
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -187,7 +187,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params = self.train_params.copy()
         train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 3
         # count divided by 2, because loss counter used for both train and eval.
@@ -198,9 +198,9 @@ class EarlyStopTest(unittest.TestCase):
         Test for `min_delta` argument, metric value is considered an improvement only if
         current_value - min_delta > best_value
         """
-        model = SgModel("early_stop_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(self.net)
+        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(self.net)

         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                    min_delta=0.1, verbose=True)
@@ -212,7 +212,7 @@ class EarlyStopTest(unittest.TestCase):
         train_params.update(
             {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})

-        model.train(train_params)
+        trainer.train(train_params)

         excepted_end_epoch = 5

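A minimal sketch of attaching the `EarlyStop` callback to the renamed `Trainer`, condensed from `EarlyStopTest` above. The experiment name, network, epoch counts and patience value are illustrative; the callback, phase and parameter names are the ones used in the diff.

```python
from super_gradients.training.sg_trainer import Trainer
from super_gradients.training.datasets.dataset_interfaces import ClassificationTestDatasetInterface
from super_gradients.training.models.classification_models.resnet import ResNet18
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils.early_stopping import EarlyStop
from super_gradients.training.utils.callbacks import Phase

trainer = Trainer("early_stop_example", model_checkpoints_location='local')  # was SgModel(...)
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}))
trainer.build_model(ResNet18(num_classes=5, arch_params={}))

# Stop when "Loss" has not improved for 3 consecutive validation epochs.
early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)

train_params = {"max_epochs": 10, "lr_mode": "step", "lr_updates": [5], "lr_decay_factor": 0.1,
                "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True,
                "phase_callbacks": [early_stop_loss]}
trainer.train(train_params)
```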
@@ -2,7 +2,7 @@ import unittest

 import torch

-from super_gradients import ClassificationTestDatasetInterface, SgModel
+from super_gradients import ClassificationTestDatasetInterface, Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.models import ResNet18

@@ -10,13 +10,13 @@ from super_gradients.training.models import ResNet18
 class FactoriesTest(unittest.TestCase):

     def test_training_with_factories(self):
-        model = SgModel("test_train_with_factories", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
         dataset_params = {"batch_size": 10}
         dataset = {"classification_test_dataset": {"dataset_params": dataset_params}}
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)

         net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
         train_params = {"max_epochs": 2,
                         "lr_updates": [1],
                         "lr_decay_factor": 0.1,
@@ -32,12 +32,12 @@ class FactoriesTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}

-        model.train(train_params)
+        trainer.train(train_params)

-        self.assertIsInstance(model.train_metrics.Accuracy, Accuracy)
-        self.assertIsInstance(model.valid_metrics.Top5, Top5)
-        self.assertIsInstance(model.dataset_interface, ClassificationTestDatasetInterface)
-        self.assertIsInstance(model.optimizer, torch.optim.ASGD)
+        self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
+        self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
+        self.assertIsInstance(trainer.dataset_interface, ClassificationTestDatasetInterface)
+        self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)


 if __name__ == '__main__':
@@ -1,5 +1,5 @@
 import unittest
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
@@ -34,9 +34,9 @@ class ForwardpassPrepFNTest(unittest.TestCase):

     def test_resizing_with_forward_pass_prep_fn(self):
         # Define Model
-        model = SgModel("ForwardpassPrepFNTest")
-        model.connect_dataset_interface(self.dataset)
-        model.build_model("resnet18", arch_params=self.arch_params)
+        trainer = Trainer("ForwardpassPrepFNTest")
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model("resnet18", arch_params=self.arch_params)

         sizes = []
         phase_callbacks = [TestInputSizesCallback(sizes)]
@@ -49,7 +49,7 @@ class ForwardpassPrepFNTest(unittest.TestCase):
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                         "pre_prediction_callback": test_forward_pass_prep_fn}
-        model.train(train_params)
+        trainer.train(train_params)

         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
@@ -1,9 +1,9 @@
 import unittest
-from super_gradients import SgModel, ClassificationTestDatasetInterface
+from super_gradients import Trainer, ClassificationTestDatasetInterface
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.sg_model_exceptions import IllegalDataloaderInitialization
+from super_gradients.training.exceptions.sg_trainer_exceptions import IllegalDataloaderInitialization


 class InitializeWithDataloadersTest(unittest.TestCase):
@@ -24,12 +24,12 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))

     def test_interface_was_not_broken(self):
-        model = SgModel("test_interface", model_checkpoints_location='local')
+        trainer = Trainer("test_interface", model_checkpoints_location='local')
         dataset_params = {"batch_size": 10}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)

-        model.build_model("efficientnet_b0")
+        trainer.build_model("efficientnet_b0")
         train_params = {"max_epochs": 1, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
                         "optimizer": "SGD",
@@ -37,46 +37,46 @@ class InitializeWithDataloadersTest(unittest.TestCase):
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "metric_to_watch": "Accuracy",
                         "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)

     def test_initialization_rules(self):
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           train_loader=self.testcase_trainloader)
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           train_loader=self.testcase_trainloader, classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, SgModel, "test_name", model_checkpoints_location='local',
+        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                           valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        SgModel("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
+        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
                 valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        SgModel("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
+        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
                 valid_loader=self.testcase_validloader, test_loader=self.testcase_testloader,
                 classes=self.testcase_classes)

     def test_train_with_dataloaders(self):
-        model = SgModel(experiment_name="test_name", model_checkpoints_location="local",
-                        train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader,
-                        classes=self.testcase_classes)
+        trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local",
+                          train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader,
+                          classes=self.testcase_classes)

-        model.build_model("resnet18")
-        model.train(training_params={"max_epochs": 2,
-                                     "lr_updates": [5, 6, 12],
-                                     "lr_decay_factor": 0.01,
-                                     "lr_mode": "step",
-                                     "initial_lr": 0.01,
-                                     "loss": "cross_entropy",
-                                     "optimizer": "SGD",
-                                     "optimizer_params": {"weight_decay": 1e-5, "momentum": 0.9},
-                                     "train_metrics_list": [Accuracy()],
-                                     "valid_metrics_list": [Accuracy()],
-                                     "metric_to_watch": "Accuracy",
-                                     "greater_metric_to_watch_is_better": True})
-        self.assertTrue(0 < model.best_metric.item() < 1)
+        trainer.build_model("resnet18")
+        trainer.train(training_params={"max_epochs": 2,
+                                       "lr_updates": [5, 6, 12],
+                                       "lr_decay_factor": 0.01,
+                                       "lr_mode": "step",
+                                       "initial_lr": 0.01,
+                                       "loss": "cross_entropy",
+                                       "optimizer": "SGD",
+                                       "optimizer_params": {"weight_decay": 1e-5, "momentum": 0.9},
+                                       "train_metrics_list": [Accuracy()],
+                                       "valid_metrics_list": [Accuracy()],
+                                       "metric_to_watch": "Accuracy",
+                                       "greater_metric_to_watch_is_better": True})
+        self.assertTrue(0 < trainer.best_metric.item() < 1)


 if __name__ == '__main__':
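A minimal sketch of constructing the renamed `Trainer` directly from torch `DataLoader`s, condensed from `InitializeWithDataloadersTest` above. The synthetic tensors, class names and experiment name are illustrative; note that `train_loader`, `valid_loader` and `classes` must be supplied together, otherwise `IllegalDataloaderInitialization` (now in `exceptions.sg_trainer_exceptions`) is raised.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
from super_gradients import Trainer  # was: from super_gradients import SgModel
from super_gradients.training.metrics import Accuracy

# Tiny synthetic classification dataset, just to make the sketch self-contained.
inp = torch.randn(20, 3, 32, 32)
label = torch.randint(0, 3, size=(20,))
train_loader = DataLoader(TensorDataset(inp, label), batch_size=4)
valid_loader = DataLoader(TensorDataset(inp, label), batch_size=4)

trainer = Trainer(experiment_name="dataloader_example", model_checkpoints_location="local",
                  train_loader=train_loader, valid_loader=valid_loader, classes=["a", "b", "c"])
trainer.build_model("resnet18")
trainer.train(training_params={"max_epochs": 2, "lr_mode": "step", "lr_updates": [1], "lr_decay_factor": 0.01,
                               "initial_lr": 0.01, "loss": "cross_entropy", "optimizer": "SGD",
                               "optimizer_params": {"weight_decay": 1e-5, "momentum": 0.9},
                               "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                               "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True})
```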
@@ -1,6 +1,6 @@
 import unittest
-from super_gradients.training.sg_model import SgModel
-from super_gradients.training.kd_model.kd_model import KDModel
+from super_gradients.training.sg_trainer import Trainer
+from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
@@ -11,7 +11,7 @@ from super_gradients.training.losses.kd_losses import KDLogitsLoss
 class KDEMATest(unittest.TestCase):
     @classmethod
     def setUp(cls):
-        cls.sg_trained_teacher = SgModel("sg_trained_teacher", device='cpu')
+        cls.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
         cls.dataset_params = {"batch_size": 5}
         cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)

@@ -29,50 +29,50 @@ class KDEMATest(unittest.TestCase):
     def test_teacher_ema_not_duplicated(self):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""

-        kd_model = KDModel("test_teacher_ema_not_duplicated", device='cpu')
-        kd_model.connect_dataset_interface(self.dataset)
-        kd_model.build_model(student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000},
-                             teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                             run_teacher_on_eval=True, )
+        kd_trainer = KDTrainer("test_teacher_ema_not_duplicated", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(student_architecture='resnet18',
+                               teacher_architecture='resnet50',
+                               student_arch_params={'num_classes': 1000},
+                               teacher_arch_params={'num_classes': 1000},
+                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
+                               run_teacher_on_eval=True, )

-        kd_model.train(self.kd_train_params)
+        kd_trainer.train(self.kd_train_params)

-        self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
-        self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
+        self.assertTrue(kd_trainer.ema_model.ema.module.teacher is kd_trainer.net.module.teacher)
+        self.assertTrue(kd_trainer.ema_model.ema.module.student is not kd_trainer.net.module.student)

     def test_kd_ckpt_reload_ema(self):
         """Check that the KD model load correctly from checkpoint when "load_ema_as_net=True"."""

         # Create a KD model and train it
-        kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
-        kd_model.connect_dataset_interface(self.dataset)
-        kd_model.build_model(student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000},
-                             teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                             run_teacher_on_eval=True, )
-        kd_model.train(self.kd_train_params)
-        ema_model = kd_model.ema_model.ema
-        net = kd_model.net
+        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(student_architecture='resnet18',
+                               teacher_architecture='resnet50',
+                               student_arch_params={'num_classes': 1000},
+                               teacher_arch_params={'num_classes': 1000},
+                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
+                               run_teacher_on_eval=True, )
+        kd_trainer.train(self.kd_train_params)
+        ema_model = kd_trainer.ema_model.ema
+        net = kd_trainer.net

         # Load the trained KD model
-        kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
-        kd_model.connect_dataset_interface(self.dataset)
-        kd_model.build_model(student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000},
-                             teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={"load_checkpoint": True, "load_ema_as_net": True},
-                             run_teacher_on_eval=True, )
+        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(student_architecture='resnet18',
+                               teacher_architecture='resnet50',
+                               student_arch_params={'num_classes': 1000},
+                               teacher_arch_params={'num_classes': 1000},
+                               checkpoint_params={"load_checkpoint": True, "load_ema_as_net": True},
+                               run_teacher_on_eval=True, )

         # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        kd_model.train(self.kd_train_params)
-        reloaded_ema_model = kd_model.ema_model.ema
-        reloaded_net = kd_model.net
+        kd_trainer.train(self.kd_train_params)
+        reloaded_ema_model = kd_trainer.ema_model.ema
+        reloaded_net = kd_trainer.net

         # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
         self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
@@ -93,32 +93,32 @@ class KDEMATest(unittest.TestCase):
         """Check that the KD model load correctly from checkpoint when "load_ema_as_net=False"."""

         # Create a KD model and train it
-        kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
-        kd_model.connect_dataset_interface(self.dataset)
-        kd_model.build_model(student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000},
-                             teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={'teacher_pretrained_weights': "imagenet"},
-                             run_teacher_on_eval=True, )
-        kd_model.train(self.kd_train_params)
-        ema_model = kd_model.ema_model.ema
-        net = kd_model.net
+        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(student_architecture='resnet18',
+                               teacher_architecture='resnet50',
+                               student_arch_params={'num_classes': 1000},
+                               teacher_arch_params={'num_classes': 1000},
+                               checkpoint_params={'teacher_pretrained_weights': "imagenet"},
+                               run_teacher_on_eval=True, )
+        kd_trainer.train(self.kd_train_params)
+        ema_model = kd_trainer.ema_model.ema
+        net = kd_trainer.net

         # Load the trained KD model
-        kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
-        kd_model.connect_dataset_interface(self.dataset)
-        kd_model.build_model(student_architecture='resnet18',
-                             teacher_architecture='resnet50',
-                             student_arch_params={'num_classes': 1000},
-                             teacher_arch_params={'num_classes': 1000},
-                             checkpoint_params={"load_checkpoint": True, "load_ema_as_net": False},
-                             run_teacher_on_eval=True, )
+        kd_trainer = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(student_architecture='resnet18',
+                               teacher_architecture='resnet50',
+                               student_arch_params={'num_classes': 1000},
+                               teacher_arch_params={'num_classes': 1000},
+                               checkpoint_params={"load_checkpoint": True, "load_ema_as_net": False},
+                               run_teacher_on_eval=True, )

         # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        kd_model.train(self.kd_train_params)
-        reloaded_ema_model = kd_model.ema_model.ema
-        reloaded_net = kd_model.net
+        kd_trainer.train(self.kd_train_params)
+        reloaded_ema_model = kd_trainer.ema_model.ema
+        reloaded_net = kd_trainer.net

         # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
         self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
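A minimal sketch of the knowledge-distillation rename (KDModel to KDTrainer), condensed from `KDEMATest` above. The experiment name is illustrative, and the training params are abbreviated stand-ins for the `kd_train_params` defined in the hidden part of the test; the `KDLogitsLoss(torch.nn.CrossEntropyLoss())` wrapping is an assumption, since only the `KDLogitsLoss` import is visible in this diff.

```python
import torch
from super_gradients.training.kd_trainer.kd_trainer import KDTrainer  # was kd_model.kd_model.KDModel
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
from super_gradients.training.losses.kd_losses import KDLogitsLoss
from super_gradients.training.metrics import Accuracy

kd_trainer = KDTrainer("kd_example", device='cpu')  # was KDModel(...)
kd_trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 5}))
kd_trainer.build_model(student_architecture='resnet18',
                       teacher_architecture='resnet50',
                       student_arch_params={'num_classes': 1000},
                       teacher_arch_params={'num_classes': 1000},
                       checkpoint_params={'teacher_pretrained_weights': "imagenet"},
                       run_teacher_on_eval=True)

# Illustrative, abbreviated training params (assumed KDLogitsLoss wrapping of the task loss).
kd_train_params = {"max_epochs": 1, "lr_mode": "step", "lr_updates": [1], "lr_decay_factor": 0.1,
                   "initial_lr": 0.01, "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()), "optimizer": "SGD",
                   "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                   "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True}
kd_trainer.train(kd_train_params)
```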
@@ -1,12 +1,12 @@
 import unittest
 import unittest
 import os
 import os
-from super_gradients.training.sg_model import SgModel
-from super_gradients.training.kd_model.kd_model import KDModel
+from super_gradients.training.sg_trainer import Trainer
+from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
 import torch
 import torch
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.kd_model_exceptions import ArchitectureKwargsException, \
+from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException
     TeacherKnowledgeException
 from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
 from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
@@ -15,10 +15,10 @@ from copy import deepcopy
 from super_gradients.training.utils.module_utils import NormalizationAdapter
 from super_gradients.training.utils.module_utils import NormalizationAdapter
 
 
 
 
-class KDModelTest(unittest.TestCase):
+class KDTrainerTest(unittest.TestCase):
     @classmethod
     @classmethod
     def setUp(cls):
     def setUp(cls):
-        cls.sg_trained_teacher = SgModel("sg_trained_teacher", device='cpu')
+        cls.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
         cls.dataset_params = {"batch_size": 5}
         cls.dataset_params = {"batch_size": 5}
         cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
         cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
         cls.sg_trained_teacher.connect_dataset_interface(cls.dataset)
         cls.sg_trained_teacher.connect_dataset_interface(cls.dataset)
@@ -44,86 +44,86 @@ class KDModelTest(unittest.TestCase):
                                "greater_metric_to_watch_is_better": True, "average_best_models": False}
                                "greater_metric_to_watch_is_better": True, "average_best_models": False}
 
 
     def test_build_kd_module_with_pretrained_teacher(self):
     def test_build_kd_module_with_pretrained_teacher(self):
-        kd_model = KDModel("build_kd_module_with_pretrained_teacher", device='cpu')
-        kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer = KDTrainer("build_kd_module_with_pretrained_teacher", device='cpu')
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                              student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                              student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                              checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                              checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                              )
                              )
-        imagenet_resnet50_sg_model = SgModel("pretrained_resnet50")
-        imagenet_resnet50_sg_model.build_model('resnet50', arch_params={'num_classes': 1000},
+        imagenet_resnet50_trainer = Trainer("pretrained_resnet50")
+        imagenet_resnet50_trainer.build_model('resnet50', arch_params={'num_classes': 1000},
                                                checkpoint_params={'pretrained_weights': "imagenet"})
                                                checkpoint_params={'pretrained_weights': "imagenet"})
 
 
-        self.assertTrue(check_models_have_same_weights(kd_model.net.module.teacher,
-                                                       imagenet_resnet50_sg_model.net.module))
+        self.assertTrue(check_models_have_same_weights(kd_trainer.net.module.teacher,
+                                                       imagenet_resnet50_trainer.net.module))
 
 
     def test_build_kd_module_with_pretrained_student(self):
     def test_build_kd_module_with_pretrained_student(self):
-        kd_model = KDModel("build_kd_module_with_pretrained_student", device='cpu')
-        kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer = KDTrainer("build_kd_module_with_pretrained_student", device='cpu')
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                              student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                              student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                              checkpoint_params={'student_pretrained_weights': "imagenet",
                              checkpoint_params={'student_pretrained_weights': "imagenet",
                                                 'teacher_pretrained_weights': "imagenet"}
                                                 'teacher_pretrained_weights': "imagenet"}
                              )
                              )
 
 
-        imagenet_resnet18_sg_model = SgModel("pretrained_resnet18", device='cpu')
-        imagenet_resnet18_sg_model.build_model('resnet18', arch_params={'num_classes': 1000},
+        imagenet_resnet18_trainer = Trainer("pretrained_resnet18", device='cpu')
+        imagenet_resnet18_trainer.build_model('resnet18', arch_params={'num_classes': 1000},
                                                checkpoint_params={'pretrained_weights': "imagenet"})
                                                checkpoint_params={'pretrained_weights': "imagenet"})
 
 
-        self.assertTrue(check_models_have_same_weights(kd_model.net.module.student,
-                                                       imagenet_resnet18_sg_model.net.module))
+        self.assertTrue(check_models_have_same_weights(kd_trainer.net.module.student,
+                                                       imagenet_resnet18_trainer.net.module))
 
 
     def test_build_kd_module_pretrained_student_with_head_replacement(self):
     def test_build_kd_module_pretrained_student_with_head_replacement(self):
         self.sg_trained_teacher.train(self.train_params)
         self.sg_trained_teacher.train(self.train_params)
         teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
         teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
 
 
-        sg_kd_model = KDModel('test_build_kd_module_student_replace_head', device='cpu')
-        sg_kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer = KDTrainer('test_build_kd_module_student_replace_head', device='cpu')
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                                 student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                 student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                 checkpoint_params={'student_pretrained_weights': "imagenet",
                                 checkpoint_params={'student_pretrained_weights': "imagenet",
                                                    "teacher_checkpoint_path": teacher_path}
                                                    "teacher_checkpoint_path": teacher_path}
                                 )
                                 )
 
 
-        self.assertTrue(sg_kd_model.net.module.student.linear.out_features == 5)
+        self.assertTrue(kd_trainer.net.module.student.linear.out_features == 5)
 
 
     def test_build_kd_module_with_sg_trained_teacher(self):
     def test_build_kd_module_with_sg_trained_teacher(self):
         self.sg_trained_teacher.train(self.train_params)
         self.sg_trained_teacher.train(self.train_params)
         teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
         teacher_path = os.path.join(self.sg_trained_teacher.checkpoints_dir_path, 'ckpt_latest.pth')
 
 
-        sg_kd_model = KDModel('test_build_kd_module_with_sg_trained_teacher', device='cpu')
+        kd_trainer = KDTrainer('test_build_kd_module_with_sg_trained_teacher', device='cpu')
 
 
-        sg_kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                                 student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                 student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                 checkpoint_params={"teacher_checkpoint_path": teacher_path}
                                 checkpoint_params={"teacher_checkpoint_path": teacher_path}
                                 )
                                 )
 
 
         self.assertTrue(
         self.assertTrue(
-            check_models_have_same_weights(self.sg_trained_teacher.net.module, sg_kd_model.net.module.teacher))
+            check_models_have_same_weights(self.sg_trained_teacher.net.module, kd_trainer.net.module.teacher))
 
 
     def test_teacher_sg_module_methods(self):
     def test_teacher_sg_module_methods(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
-        sg_kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                                 student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                                 student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                                 checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                                 checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                                 )
                                 )
 
 
-        initial_param_groups = sg_kd_model.net.module.initialize_param_groups(lr=0.1, training_params={})
-        updated_param_groups = sg_kd_model.net.module.update_param_groups(param_groups=initial_param_groups, lr=0.2,
+        initial_param_groups = kd_trainer.net.module.initialize_param_groups(lr=0.1, training_params={})
+        updated_param_groups = kd_trainer.net.module.update_param_groups(param_groups=initial_param_groups, lr=0.2,
                                                                           epoch=0, iter=0, training_params={},
                                                                           epoch=0, iter=0, training_params={},
                                                                           total_batch=None)
                                                                           total_batch=None)
 
 
         self.assertTrue(initial_param_groups[0]['lr'] == 0.2 == updated_param_groups[0]['lr'])
         self.assertTrue(initial_param_groups[0]['lr'] == 0.2 == updated_param_groups[0]['lr'])
 
 
     def test_kd_architecture_kwarg_exception_catching(self):
     def test_kd_architecture_kwarg_exception_catching(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
         with self.assertRaises(ArchitectureKwargsException):
         with self.assertRaises(ArchitectureKwargsException):
-            sg_kd_model.build_model(teacher_architecture='resnet50',
+            kd_trainer.build_model(teacher_architecture='resnet50',
                                     student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                     student_arch_params={'num_classes': 5}, teacher_arch_params={'num_classes': 5},
                                     checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                                     checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                                     )
                                     )
 
 
     def test_kd_unsupported_kdmodel_arg_exception_catching(self):
     def test_kd_unsupported_kdmodel_arg_exception_catching(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
         with self.assertRaises(UnsupportedKDModelArgException):
         with self.assertRaises(UnsupportedKDModelArgException):
-            sg_kd_model.build_model(student_architecture='resnet18',
+            kd_trainer.build_model(student_architecture='resnet18',
                                     teacher_architecture='resnet50',
                                     teacher_architecture='resnet50',
                                     student_arch_params={'num_classes': 1000},
                                     student_arch_params={'num_classes': 1000},
                                     teacher_arch_params={'num_classes': 1000},
                                     teacher_arch_params={'num_classes': 1000},
@@ -131,9 +131,9 @@ class KDModelTest(unittest.TestCase):
                                    )

     def test_kd_unsupported_model_exception_catching(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
         with self.assertRaises(UnsupportedKDArchitectureException):
-            sg_kd_model.build_model(student_architecture='resnet18',
+            kd_trainer.build_model(student_architecture='resnet18',
                                    teacher_architecture='resnet50',
                                    student_arch_params={'num_classes': 1000},
                                    teacher_arch_params={'num_classes': 1000},
@@ -142,90 +142,90 @@ class KDModelTest(unittest.TestCase):
                                    )

     def test_kd_inconsistent_params_exception_catching(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
         with self.assertRaises(InconsistentParamsException):
-            sg_kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+            kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                                    student_arch_params={'num_classes': 10}, teacher_arch_params={'num_classes': 1000},
                                    checkpoint_params={'teacher_pretrained_weights': "imagenet"}
                                    )

     def test_kd_teacher_knowledge_exception_catching(self):
-        sg_kd_model = KDModel("test_teacher_sg_module_methods", device='cpu')
+        kd_trainer = KDTrainer("test_teacher_sg_module_methods", device='cpu')
         with self.assertRaises(TeacherKnowledgeException):
-            sg_kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+            kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                                    student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000}
                                    )

     def test_build_external_models(self):
-        sg_model = KDModel("test_training_with_external_teacher", device='cpu')
+        kd_trainer = KDTrainer("test_training_with_external_teacher", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=10)
         student_model = ResNet18(arch_params={}, num_classes=10)
-        sg_model.build_model(student_architecture=student_model, teacher_architecture=teacher_model,
+        kd_trainer.build_model(student_architecture=student_model, teacher_architecture=teacher_model,
                             student_arch_params={'num_classes': 10}, teacher_arch_params={'num_classes': 10}
                             )

         self.assertTrue(
-            check_models_have_same_weights(teacher_model, sg_model.net.module.teacher))
+            check_models_have_same_weights(teacher_model, kd_trainer.net.module.teacher))
         self.assertTrue(
-            check_models_have_same_weights(student_model, sg_model.net.module.student))
+            check_models_have_same_weights(student_model, kd_trainer.net.module.student))

     def test_train_kd_module_external_models(self):
-        sg_model = KDModel("test_train_kd_module_external_models", device='cpu')
+        kd_trainer = KDTrainer("test_train_kd_module_external_models", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         student_model = ResNet18(arch_params={}, num_classes=5)
-        sg_model.connect_dataset_interface(self.dataset)
-        sg_model.build_model(run_teacher_on_eval=True,
+        kd_trainer.connect_dataset_interface(self.dataset)
+        kd_trainer.build_model(run_teacher_on_eval=True,
                             student_arch_params={'num_classes': 5},
                             teacher_arch_params={'num_classes': 5},
                             student_architecture=deepcopy(student_model),
                             teacher_architecture=deepcopy(teacher_model),
                             )

-        sg_model.train(self.kd_train_params)
+        kd_trainer.train(self.kd_train_params)

         # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
         self.assertTrue(
-            check_models_have_same_weights(teacher_model, sg_model.net.module.teacher))
+            check_models_have_same_weights(teacher_model, kd_trainer.net.module.teacher))

         # STUDENT WEIGHT'S SHOULD NOT REMAIN THE SAME
         self.assertFalse(
-            check_models_have_same_weights(student_model, sg_model.net.module.student))
+            check_models_have_same_weights(student_model, kd_trainer.net.module.student))

     def test_train_kd_module_pretrained_ckpt(self):
-        sg_model = KDModel("test_train_kd_module_pretrained_ckpt", device='cpu')
+        kd_trainer = KDTrainer("test_train_kd_module_pretrained_ckpt", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         teacher_path = '/tmp/teacher.pth'
         torch.save(teacher_model.state_dict(), teacher_path)
-        sg_model.connect_dataset_interface(self.dataset)
+        kd_trainer.connect_dataset_interface(self.dataset)

-        sg_model.build_model(student_arch_params={'num_classes': 5},
+        kd_trainer.build_model(student_arch_params={'num_classes': 5},
                             teacher_arch_params={'num_classes': 5},
                             student_architecture='resnet18',
                             teacher_architecture='resnet50',
                             checkpoint_params={"teacher_checkpoint_path": teacher_path}
                             )
-        sg_model.train(self.kd_train_params)
+        kd_trainer.train(self.kd_train_params)

     def test_build_model_with_input_adapter(self):
         adapter = NormalizationAdapter(mean_original=[0.485, 0.456, 0.406],
                                        std_original=[0.229, 0.224, 0.225],
                                        mean_required=[0.5, 0.5, 0.5],
                                        std_required=[0.5, 0.5, 0.5])
-        kd_model = KDModel("build_kd_module_with_with_input_adapter", device='cpu')
-        kd_model.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
+        kd_trainer = KDTrainer("build_kd_module_with_with_input_adapter", device='cpu')
+        kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                             student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                             checkpoint_params={'teacher_pretrained_weights': "imagenet"},
                             arch_params={"teacher_input_adapter": adapter})
-        self.assertEqual(kd_model.net.module.teacher_input_adapter, adapter)
+        self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)

     def test_load_ckpt_best_for_student(self):
-        sg_model = KDModel("test_load_ckpt_best", device='cpu')
+        kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         teacher_path = '/tmp/teacher.pth'
         torch.save(teacher_model.state_dict(), teacher_path)
-        sg_model.connect_dataset_interface(self.dataset)
+        kd_trainer.connect_dataset_interface(self.dataset)

-        sg_model.build_model(student_arch_params={'num_classes': 5},
+        kd_trainer.build_model(student_arch_params={'num_classes': 5},
                             teacher_arch_params={'num_classes': 5},
                             student_architecture='resnet18',
                             teacher_architecture='resnet50',
@@ -233,24 +233,24 @@ class KDModelTest(unittest.TestCase):
                             )
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
-        sg_model.train(train_params)
-        best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
+        kd_trainer.train(train_params)
+        best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")

-        student_sg_model = SgModel("studnet_sg_model")
-        student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
+        student_trainer = Trainer("studnet_trainer")
+        student_trainer.build_model("resnet18", arch_params={'num_classes': 5},
                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})

         self.assertTrue(
-            check_models_have_same_weights(student_sg_model.net.module, sg_model.net.module.student))
+            check_models_have_same_weights(student_trainer.net.module, kd_trainer.net.module.student))

     def test_load_ckpt_best_for_student_with_ema(self):
-        sg_model = KDModel("test_load_ckpt_best_for_student_with_ema", device='cpu')
+        kd_trainer = KDTrainer("test_load_ckpt_best_for_student_with_ema", device='cpu')
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         teacher_path = '/tmp/teacher.pth'
         torch.save(teacher_model.state_dict(), teacher_path)
-        sg_model.connect_dataset_interface(self.dataset)
+        kd_trainer.connect_dataset_interface(self.dataset)

-        sg_model.build_model(student_arch_params={'num_classes': 5},
+        kd_trainer.build_model(student_arch_params={'num_classes': 5},
                             teacher_arch_params={'num_classes': 5},
                             student_architecture='resnet18',
                             teacher_architecture='resnet50',
@@ -259,14 +259,14 @@ class KDModelTest(unittest.TestCase):
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["ema"] = True
-        sg_model.train(train_params)
-        best_student_ckpt = os.path.join(sg_model.checkpoints_dir_path, "ckpt_best.pth")
+        kd_trainer.train(train_params)
+        best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")

-        student_sg_model = SgModel("studnet_sg_model")
-        student_sg_model.build_model("resnet18", arch_params={'num_classes': 5},
+        student_trainer = Trainer("studnet_trainer")
+        student_trainer.build_model("resnet18", arch_params={'num_classes': 5},
                                     checkpoint_params={"load_checkpoint": True, "external_checkpoint_path": best_student_ckpt})
         self.assertTrue(
-            check_models_have_same_weights(student_sg_model.net.module, sg_model.ema_model.ema.module.student))
+            check_models_have_same_weights(student_trainer.net.module, kd_trainer.ema_model.ema.module.student))


 if __name__ == '__main__':
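
For orientation, a minimal usage sketch of the renamed KD API as it is exercised by the tests above; the KDTrainer import path is an assumption (it is not visible in this hunk), and the training-params dict is left out for brevity:

    # Illustrative sketch only; the KDTrainer import path below is assumed.
    from super_gradients.training.kd_trainer import KDTrainer  # assumed import path
    from super_gradients.training.datasets import ClassificationTestDatasetInterface

    kd_trainer = KDTrainer("kd_quickstart", device='cpu')
    kd_trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 5}))
    # Student and teacher are built together; the teacher loads ImageNet-pretrained weights.
    kd_trainer.build_model(student_architecture='resnet18', teacher_architecture='resnet50',
                           student_arch_params={'num_classes': 1000}, teacher_arch_params={'num_classes': 1000},
                           checkpoint_params={'teacher_pretrained_weights': "imagenet"})
    # kd_trainer.train(kd_train_params)  # training-params dict elided; see the kd_train_params built in setUp above
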
@@ -2,11 +2,11 @@ import shutil
 import tempfile
 import unittest
 import os
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from super_gradients.training.sg_model.sg_model import StrictLoad
+from super_gradients.training.sg_trainer.sg_trainer import StrictLoad


 class Net(nn.Module):
@@ -55,15 +55,15 @@ class LoadCheckpointFromDirectPathTest(unittest.TestCase):
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)

-        # Build the SgModel and load the checkpoint
-        model = SgModel("load_checkpoint_test", model_checkpoints_location='local')
-        model.build_model(new_torch_net, arch_params={'num_classes': 10},
-                          checkpoint_params={'external_checkpoint_path': self.checkpoint_path,
-                                             'load_checkpoint': True,
-                                             'strict_load': StrictLoad.NO_KEY_MATCHING})
+        # Build the Trainer and load the checkpoint
+        trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')
+        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
+                            checkpoint_params={'external_checkpoint_path': self.checkpoint_path,
+                                               'load_checkpoint': True,
+                                               'strict_load': StrictLoad.NO_KEY_MATCHING})

         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(model.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)

     def check_models_have_same_weights(self, model_1, model_2):
         model_1, model_2 = model_1.to('cpu'), model_2.to('cpu')
@@ -1,5 +1,5 @@
 import unittest
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils.utils import check_models_have_same_weights
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
@@ -20,24 +20,24 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
     def test_ema_ckpt_reload(self):
         # Define Model
         net = LeNet()
-        model = SgModel("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')

-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params={'num_classes': 10})
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params={'num_classes': 10})

-        model.train(self.train_params)
+        trainer.train(self.train_params)

-        ema_model = model.ema_model.ema
+        ema_model = trainer.ema_model.ema

         net = LeNet()
-        model = SgModel("ema_ckpt_test", model_checkpoints_location='local')
-        model.build_model(net, arch_params={'num_classes': 10}, checkpoint_params={'load_checkpoint': True})
-        model.connect_dataset_interface(self.dataset)
+        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer.build_model(net, arch_params={'num_classes': 10}, checkpoint_params={'load_checkpoint': True})
+        trainer.connect_dataset_interface(self.dataset)

         # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
-        model.train(self.train_params)
+        trainer.train(self.train_params)

-        reloaded_ema_model = model.ema_model.ema
+        reloaded_ema_model = trainer.ema_model.ema

         # ASSERT RELOADED EMA MODEL HAS THE SAME WEIGHTS AS THE EMA MODEL SAVED IN FIRST PART OF TRAINING
         assert check_models_have_same_weights(ema_model, reloaded_ema_model)
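
For orientation, a minimal sketch of the renamed Trainer workflow used throughout these tests; the training-params dict is elided (its shape is visible in the hunks above):

    # Illustrative sketch only, mirroring the calls exercised in the tests above.
    from super_gradients.training import Trainer
    from super_gradients.training.models import LeNet
    from super_gradients.training.datasets import ClassificationTestDatasetInterface

    trainer = Trainer("ema_ckpt_example", model_checkpoints_location='local')
    trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}))
    # checkpoint_params={'load_checkpoint': True} resumes from the experiment's latest checkpoint.
    trainer.build_model(LeNet(), arch_params={'num_classes': 10}, checkpoint_params={'load_checkpoint': True})
    # trainer.train(train_params)          # training-params dict elided; see the tests above
    # ema_weights = trainer.ema_model.ema  # EMA weights are exposed after training with "ema": True
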
@@ -1,5 +1,5 @@
 import unittest
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.models import LeNet
@@ -15,9 +15,9 @@ class LRCooldownTest(unittest.TestCase):
     def test_lr_cooldown_with_lr_scheduling(self):
         # Define Model
         net = LeNet()
-        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -31,7 +31,7 @@ class LRCooldownTest(unittest.TestCase):
                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}

         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211, 0.4763932022500211, 0.4763932022500211]
-        model.train(train_params)
+        trainer.train(train_params)

         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
@@ -2,7 +2,7 @@ import unittest

 import numpy as np

-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.models import LeNet
@@ -42,9 +42,9 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup(self):
         # Define Model
         net = LeNet()
-        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -58,15 +58,15 @@ class LRWarmupTest(unittest.TestCase):
                        "warmup_mode": "linear_step"}

         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
-        model.train(train_params)
+        trainer.train(train_params)
         self.assertListEqual(lrs, expected_lrs)

     def test_lr_warmup_with_lr_scheduling(self):
         # Define Model
         net = LeNet()
-        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -80,7 +80,7 @@ class LRWarmupTest(unittest.TestCase):
                        "warmup_mode": "linear_step"}

         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
-        model.train(train_params)
+        trainer.train(train_params)

         # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
         # THE LRS AFTER THE UPDATE
@@ -89,9 +89,9 @@ class LRWarmupTest(unittest.TestCase):
     def test_warmup_initial_lr(self):
         # Define Model
         net = LeNet()
-        model = SgModel("test_warmup_initial_lr", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -105,15 +105,15 @@ class LRWarmupTest(unittest.TestCase):
                        "warmup_mode": "linear_step", "initial_lr": 1, "warmup_initial_lr": 4.}

         expected_lrs = [4., 3., 2., 1., 1.]
-        model.train(train_params)
+        trainer.train(train_params)
         self.assertListEqual(lrs, expected_lrs)

     def test_custom_lr_warmup(self):
         # Define Model
         net = LeNet()
-        model = SgModel("custom_lr_warmup_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -127,7 +127,7 @@ class LRWarmupTest(unittest.TestCase):
                        "warmup_mode": ExponentialWarmupLRCallback, "initial_lr": 1., "warmup_initial_lr": 0.1}

         expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
-        model.train(train_params)
+        trainer.train(train_params)
         self.assertListEqual(lrs, expected_lrs)


@@ -1,6 +1,6 @@
 import unittest
 from super_gradients.training.utils.callbacks import PhaseContextTestCallback, Phase
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
     ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.models import ResNet18
@@ -11,13 +11,13 @@ from torchmetrics import MetricCollection

 class PhaseContextTest(unittest.TestCase):
     def context_information_in_train_test(self):
-        model = SgModel("context_information_in_train_test", model_checkpoints_location='local')
+        trainer = Trainer("context_information_in_train_test", model_checkpoints_location='local')
         dataset_params = {"batch_size": 10}
         dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)

         net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)

         phase_callbacks = [PhaseContextTestCallback(Phase.TRAIN_BATCH_END),
                            PhaseContextTestCallback(Phase.TRAIN_BATCH_STEP),
@@ -32,8 +32,8 @@ class PhaseContextTest(unittest.TestCase):
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Top5",
                        "greater_metric_to_watch_is_better": True, "phase_callbacks": phase_callbacks}

-        model.train(train_params)
-        context_callbacks = list(filter(lambda cb: isinstance(cb, PhaseContextTestCallback), model.phase_callbacks))
+        trainer.train(train_params)
+        context_callbacks = list(filter(lambda cb: isinstance(cb, PhaseContextTestCallback), trainer.phase_callbacks))

         # CHECK THAT PHASE CONTEXES HAVE THE EXACT INFORMATION THERY'RE SUPPOSE TO HOLD
         for phase_callback in context_callbacks:
@@ -1,5 +1,5 @@
 import unittest
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.models import LeNet
@@ -8,7 +8,7 @@ from super_gradients.training.utils.callbacks import Phase, PhaseCallback, Phase

 class ContextMethodsCheckerCallback(PhaseCallback):
     """
-    Callback for checking that at a certain phase specific SgModel methods are accessible.
+    Callback for checking that at a certain phase specific Trainer methods are accessible.
     """

     def __init__(self, phase: Phase, accessible_method_names: list, non_accessible_method_names: list):
@@ -35,9 +35,9 @@ class ContextMethodsTest(unittest.TestCase):

     def test_access_to_methods_by_phase(self):
         net = LeNet()
-        model = SgModel("test_access_to_methods_by_phase", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)

         phase_callbacks = []
         for phase in Phase:
@@ -68,7 +68,7 @@ class ContextMethodsTest(unittest.TestCase):
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}

-        model.train(train_params)
+        trainer.train(train_params)
         for phase_callback in phase_callbacks:
             if isinstance(phase_callback, ContextMethodsCheckerCallback):
                 self.assertTrue(phase_callback.result)
@@ -1,7 +1,7 @@
 import unittest
 import super_gradients
 from super_gradients.training import MultiGPUMode
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy
 import os
@@ -16,7 +16,7 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         self.test_dataset = ClassificationTestDatasetInterface(classes=range(1000))

     def test_pretrained_resnet50_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.build_model("resnet50", checkpoint_params={"pretrained_weights": "imagenet"})
@@ -24,7 +24,7 @@ class PretrainedModelsUnitTest(unittest.TestCase):
                      metrics_progress_verbose=True)

     def test_pretrained_regnetY800_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.build_model("regnetY800", checkpoint_params={"pretrained_weights": "imagenet"})
@@ -32,7 +32,7 @@ class PretrainedModelsUnitTest(unittest.TestCase):
                      metrics_progress_verbose=True)

     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = SgModel('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         trainer.connect_dataset_interface(self.test_dataset, data_loader_num_workers=8)
         trainer.build_model("repvgg_a0", checkpoint_params={"pretrained_weights": "imagenet"},
@@ -1,6 +1,6 @@
 import unittest
 import os
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
 from super_gradients.training.metrics import Accuracy, Top5

@@ -18,19 +18,19 @@ class SaveCkptListUnitTest(unittest.TestCase):
                        "greater_metric_to_watch_is_better": True}

         # Define Model
-        model = SgModel("save_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("save_ckpt_test", model_checkpoints_location='local')

         # Connect Dataset
         dataset = ClassificationTestDatasetInterface()
-        model.connect_dataset_interface(dataset, data_loader_num_workers=8)
+        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)

         # Build Model
-        model.build_model("resnet18_cifar")
+        trainer.build_model("resnet18_cifar")

         # Train Model (and save ckpt_epoch_list)
-        model.train(training_params=train_params)
+        trainer.train(training_params=train_params)

-        dir_path = model.checkpoints_dir_path
+        dir_path = trainer.checkpoints_dir_path
         self.file_names_list = [dir_path + f'/ckpt_epoch_{epoch}.pth' for epoch in train_params["save_ckpt_epoch_list"]]

     def test_save_ckpt_epoch_list(self):
@@ -4,12 +4,12 @@ import unittest
 import os
 import os
 
 
 from super_gradients.common.sg_loggers import BaseSGLogger
 from super_gradients.common.sg_loggers import BaseSGLogger
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 
 
-from super_gradients.training.sg_model.sg_model import StrictLoad
+from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 
 
 
 
@@ -51,14 +51,14 @@ class StrictLoadEnumTest(unittest.TestCase):
         # Save the model's state_dict checkpoint with different keys
         # Save the model's state_dict checkpoint with different keys
         torch.save(cls.change_state_dict_keys(cls.original_torch_net.state_dict()), cls.checkpoint_diff_keys_path)
         torch.save(cls.change_state_dict_keys(cls.original_torch_net.state_dict()), cls.checkpoint_diff_keys_path)
 
 
-        # Save the model's state_dict checkpoint in SgModel format
-        cls.sg_model = SgModel("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
-        cls.sg_model.build_model(cls.original_torch_net, arch_params={'num_classes': 10})
+        # Save the model's state_dict checkpoint in Trainer format
+        cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
+        cls.trainer.build_model(cls.original_torch_net, arch_params={'num_classes': 10})
         # FIXME: after uniting init and build_model we should remove this
         # FIXME: after uniting init and build_model we should remove this
-        cls.sg_model.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
-                                              training_params=HpmStruct(max_epochs=10),
-                                              checkpoints_dir_path=cls.sg_model.checkpoints_dir_path)
-        cls.sg_model._save_checkpoint()
+        cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
+                                             training_params=HpmStruct(max_epochs=10),
+                                             checkpoints_dir_path=cls.trainer.checkpoints_dir_path)
+        cls.trainer._save_checkpoint()
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls):
     def tearDownClass(cls):
@@ -96,15 +96,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
 
 
-        # Build the SgModel and load the checkpoint
-        model = SgModel(self.experiment_name, model_checkpoints_location='local',
-                        ckpt_name='ckpt_latest_weights_only.pth')
-        model.build_model(new_torch_net, arch_params={'num_classes': 10},
-                          checkpoint_params={'strict_load': StrictLoad.ON,
-                                             'load_checkpoint': True})
+        # Build the Trainer and load the checkpoint
+        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
+                          ckpt_name='ckpt_latest_weights_only.pth')
+        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
+                            checkpoint_params={'strict_load': StrictLoad.ON,
+                                               'load_checkpoint': True})
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(model.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
 
 
     def test_strict_load_off(self):
     def test_strict_load_off(self):
         # Define Model
         # Define Model
@@ -113,15 +113,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
 
 
-        # Build the SgModel and load the checkpoint
-        model = SgModel(self.experiment_name, model_checkpoints_location='local',
-                        ckpt_name='ckpt_latest_weights_only.pth')
-        model.build_model(new_torch_net, arch_params={'num_classes': 10},
-                          checkpoint_params={'strict_load': StrictLoad.OFF,
-                                             'load_checkpoint': True})
+        # Build the Trainer and load the checkpoint
+        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
+                          ckpt_name='ckpt_latest_weights_only.pth')
+        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
+                            checkpoint_params={'strict_load': StrictLoad.OFF,
+                                               'load_checkpoint': True})
 
 
         # Assert the weights were loaded correctly
         # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(model.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
 
 
     def test_strict_load_no_key_matching_external_checkpoint(self):
     def test_strict_load_no_key_matching_external_checkpoint(self):
         # Define Model
         # Define Model
@@ -130,15 +130,15 @@ class StrictLoadEnumTest(unittest.TestCase):
         # Make sure we initialized a model with different weights
         # Make sure we initialized a model with different weights
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
         assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)
 
 
-        # Build the SgModel and load the checkpoint
-        model = SgModel(self.experiment_name, model_checkpoints_location='local')
-        model.build_model(new_torch_net, arch_params={'num_classes': 10},
-                          checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
-                                             'external_checkpoint_path': self.checkpoint_diff_keys_path,
-                                             'load_checkpoint': True})
+        # Build the Trainer and load the checkpoint
+        trainer = Trainer(self.experiment_name, model_checkpoints_location='local')
+        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
+                            checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
+                                               'external_checkpoint_path': self.checkpoint_diff_keys_path,
+                                               'load_checkpoint': True})
 
 
        # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(model.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
 
 
    def test_strict_load_no_key_matching_sg_checkpoint(self):
        # Define Model
@@ -147,15 +147,15 @@ class StrictLoadEnumTest(unittest.TestCase):
        # Make sure we initialized a model with different weights
        assert not self.check_models_have_same_weights(new_torch_net, self.original_torch_net)

-        # Build the SgModel and load the checkpoint
-        model = SgModel(self.experiment_name, model_checkpoints_location='local',
-                        ckpt_name='ckpt_latest_weights_only.pth')
-        model.build_model(new_torch_net, arch_params={'num_classes': 10},
-                          checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
-                                             'load_checkpoint': True})
+        # Build the Trainer and load the checkpoint
+        trainer = Trainer(self.experiment_name, model_checkpoints_location='local',
+                          ckpt_name='ckpt_latest_weights_only.pth')
+        trainer.build_model(new_torch_net, arch_params={'num_classes': 10},
+                            checkpoint_params={'strict_load': StrictLoad.NO_KEY_MATCHING,
+                                               'load_checkpoint': True})
 
 
        # Assert the weights were loaded correctly
-        assert self.check_models_have_same_weights(model.net, self.original_torch_net)
+        assert self.check_models_have_same_weights(trainer.net, self.original_torch_net)
 
 
 
 
if __name__ == '__main__':
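For reference, a minimal sketch of the checkpoint-loading pattern these tests exercise after the rename. ResNet18 stands in for the test's torch module, and the StrictLoad import path is an assumption, since this diff only shows its call sites.

```python
from super_gradients import Trainer
from super_gradients.training.models import ResNet18
# Assumed location of the StrictLoad enum; the diff does not show where it is imported from.
from super_gradients.common.data_types import StrictLoad

net = ResNet18(num_classes=10, arch_params={})  # stand-in for the test's torch module
trainer = Trainer("strict_load_demo", model_checkpoints_location='local',
                  ckpt_name='ckpt_latest_weights_only.pth')
trainer.build_model(net, arch_params={'num_classes': 10},
                    checkpoint_params={'strict_load': StrictLoad.ON,  # also StrictLoad.OFF / StrictLoad.NO_KEY_MATCHING
                                       'load_checkpoint': True})
```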
@@ -1,7 +1,7 @@
import shutil
import unittest
import os
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
    ClassificationTestDatasetInterface, \
    SegmentationTestDatasetInterface, DetectionTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5
@@ -27,12 +27,12 @@ class TestWithoutTrainTest(unittest.TestCase):

    @staticmethod
    def get_classification_trainer(name=''):
-        model = SgModel(name, model_checkpoints_location='local')
+        trainer = Trainer(name, model_checkpoints_location='local')
        dataset_params = {"batch_size": 4}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
-        model.build_model("resnet18_cifar")
-        return model
+        trainer.connect_dataset_interface(dataset)
+        trainer.build_model("resnet18_cifar")
+        return trainer
 
 
    @staticmethod
    def get_detection_trainer(name=''):
@@ -45,46 +45,46 @@ class TestWithoutTrainTest(unittest.TestCase):
                          "train_collate_fn": DetectionCollateFN(),
                          }

-        model = SgModel(name, model_checkpoints_location='local',
-                        multi_gpu=MultiGPUMode.OFF,
-                        post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer(name, model_checkpoints_location='local',
+                          multi_gpu=MultiGPUMode.OFF,
+                          post_prediction_callback=YoloPostPredictionCallback())
        dataset_interface = DetectionTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset_interface, data_loader_num_workers=4)
-        model.build_model('yolox_s')
-        return model
+        trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=4)
+        trainer.build_model('yolox_s')
+        return trainer
 
 
    @staticmethod
    def get_segmentation_trainer(name=''):
        shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
-        model = SgModel(name, model_checkpoints_location='local', multi_gpu=False)
+        trainer = Trainer(name, model_checkpoints_location='local', multi_gpu=False)
 
 
        dataset_interface = SegmentationTestDatasetInterface()
-        model.connect_dataset_interface(dataset_interface, data_loader_num_workers=8)
-        model.build_model('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
-        return model
+        trainer.connect_dataset_interface(dataset_interface, data_loader_num_workers=8)
+        trainer.build_model('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
+        return trainer
 
 
    def test_test_without_train(self):
-        model = self.get_classification_trainer(self.folder_names[0])
-        assert isinstance(model.test(silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
+        trainer = self.get_classification_trainer(self.folder_names[0])
+        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
 
 
-        model = self.get_detection_trainer(self.folder_names[1])
+        trainer = self.get_detection_trainer(self.folder_names[1])
 
 
-        test_metrics = [DetectionMetrics(post_prediction_callback=model.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
 
 
-        assert isinstance(model.test(silent_mode=True, test_metrics_list=test_metrics), tuple)
+        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=test_metrics), tuple)
 
 
-        model = self.get_segmentation_trainer(self.folder_names[2])
-        assert isinstance(model.test(silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
+        trainer = self.get_segmentation_trainer(self.folder_names[2])
+        assert isinstance(trainer.test(silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
 
 
    def test_test_on_valid_loader_without_train(self):
-        model = self.get_classification_trainer(self.folder_names[0])
-        assert isinstance(model.test(test_loader=model.valid_loader, silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
+        trainer = self.get_classification_trainer(self.folder_names[0])
+        assert isinstance(trainer.test(test_loader=trainer.valid_loader, silent_mode=True, test_metrics_list=[Accuracy(), Top5()]), tuple)
 
 
-        model = self.get_detection_trainer(self.folder_names[1])
+        trainer = self.get_detection_trainer(self.folder_names[1])
 
 
-        test_metrics = [DetectionMetrics(post_prediction_callback=model.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
 
 
-        assert isinstance(model.test(test_loader=model.valid_loader, silent_mode=True, test_metrics_list=test_metrics), tuple)
+        assert isinstance(trainer.test(test_loader=trainer.valid_loader, silent_mode=True, test_metrics_list=test_metrics), tuple)
 
 
        model = self.get_segmentation_trainer(self.folder_names[2])
        assert isinstance(model.test(test_loader=model.valid_loader, silent_mode=True, test_metrics_list=[IoU(21), PixelAccuracy()]), tuple)
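For reference, a minimal sketch of the evaluate-without-training flow these helpers build, using only calls that appear in the diff above; the experiment name is illustrative.

```python
from super_gradients import Trainer, ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5

# Same flow as get_classification_trainer() + test_test_without_train() above.
trainer = Trainer("eval_without_train_demo", model_checkpoints_location='local')
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 4}))
trainer.build_model("resnet18_cifar")

# test() can run on the default test loader or on an explicit loader such as trainer.valid_loader;
# either way it returns a tuple of metric results.
results = trainer.test(silent_mode=True, test_metrics_list=[Accuracy(), Top5()])
```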
@@ -1,5 +1,5 @@
import unittest
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
    ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18
@@ -11,13 +11,13 @@ import shutil

class SgTrainerLoggingTest(unittest.TestCase):
    def test_train_logging(self):
-        model = SgModel("test_train_with_full_log", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_full_log", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -26,9 +26,9 @@ class SgTrainerLoggingTest(unittest.TestCase):
                        "greater_metric_to_watch_is_better": True,
                        "save_full_train_log": True}

-        model.train(train_params)
+        trainer.train(train_params)
 
 
-        logfile_path = model.log_file.replace('.txt', 'full_train_log.log')
+        logfile_path = trainer.log_file.replace('.txt', 'full_train_log.log')
        assert os.path.exists(logfile_path) and os.path.getsize(logfile_path) > 0

        root_logger_handlers = logging.root.handlers
@@ -41,7 +41,7 @@ class SgTrainerLoggingTest(unittest.TestCase):
        if os.path.exists(logs_dir_path):
            shutil.rmtree(logs_dir_path)

-        module_name = 'super_gradients.trainer.sg_model'
+        module_name = 'super_gradients.trainer.sg_trainer'
 
 
        _ = get_logger(module_name, training_log_path=None, logs_dir_path=logs_dir_path)
        root_logger_handlers = logging.root.handlers
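For reference, a short sketch of the logger-name change this test now asserts: the module path moves from 'super_gradients.trainer.sg_model' to 'super_gradients.trainer.sg_trainer'. The get_logger import path is an assumption, since the diff only shows its call site.

```python
# Assumed import location for get_logger; only the call below appears in the diff.
from super_gradients.common.abstractions.abstract_logger import get_logger

logs_dir_path = "/tmp/sg_logs"  # illustrative directory
logger = get_logger('super_gradients.trainer.sg_trainer',  # renamed module, was 'super_gradients.trainer.sg_model'
                    training_log_path=None, logs_dir_path=logs_dir_path)
```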
@@ -1,5 +1,5 @@
import unittest
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
    ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5, ToyTestClassificationMetric
from super_gradients.training.models import ResNet18
@@ -16,51 +16,52 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
    """
    Unit test for training with initialized objects passed as parameters.
    """
+
    def test_train_with_external_criterion(self):
-        model = SgModel("external_criterion_test", model_checkpoints_location='local')
+        trainer = Trainer("external_criterion_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(), "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                        "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
    def test_train_with_external_optimizer(self):
-        model = SgModel("external_optimizer_test", model_checkpoints_location='local')
+        trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
        optimizer = SGD(params=net.parameters(), lr=0.1)
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": optimizer,
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
    def test_train_with_external_scheduler(self):
-        model = SgModel("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        lr = 0.3
        net = ResNet18(num_classes=5, arch_params={})
        optimizer = SGD(params=net.parameters(), lr=lr)
        lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
        phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)]
-        model.build_model(net)
+        trainer.build_model(net)
 
 
        train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
                        "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
@@ -68,18 +69,18 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                        "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
        assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1

    def test_train_with_external_scheduler_class(self):
-        model = SgModel("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
        optimizer = SGD  # a class - not an instance
-        model.build_model(net)
+        trainer.build_model(net)
 
 
        train_params = {"max_epochs": 2,
                        "lr_warmup_epochs": 0, "initial_lr": 0.3, "loss": "cross_entropy", "optimizer": optimizer,
@@ -87,20 +88,20 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                        "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
    def test_train_with_reduce_on_plateau(self):
-        model = SgModel("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        lr = 0.3
        net = ResNet18(num_classes=5, arch_params={})
        optimizer = SGD(params=net.parameters(), lr=lr)
        lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=0)
        phase_callbacks = [LRSchedulerCallback(lr_scheduler, Phase.VALIDATION_EPOCH_END, "ToyTestClassificationMetric")]
-        model.build_model(net)
+        trainer.build_model(net)
 
 
        train_params = {"max_epochs": 2, "phase_callbacks": phase_callbacks,
                        "lr_warmup_epochs": 0, "initial_lr": lr, "loss": "cross_entropy", "optimizer": optimizer,
@@ -109,27 +110,27 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
                        "valid_metrics_list": [Accuracy(), Top5(), ToyTestClassificationMetric()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
        assert lr_scheduler._last_lr[0] == lr * 0.1

    def test_train_with_external_metric(self):
-        model = SgModel("external_metric_test", model_checkpoints_location='local')
+        trainer = Trainer("external_metric_test", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
    def test_train_with_external_dataloaders(self):
-        model = SgModel("external_data_loader_test", model_checkpoints_location='local')
+        trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
 
 
        batch_size = 5
        trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))), torch.LongTensor(np.zeros((10))))
@@ -141,17 +142,17 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size)

        dataset_interface = DatasetInterface(train_loader=train_loader, val_loader=val_loader, classes=classes)
-        model.connect_dataset_interface(dataset_interface)
+        trainer.connect_dataset_interface(dataset_interface)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [F1Score()], "valid_metrics_list": [F1Score()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "F1Score",
                        "greater_metric_to_watch_is_better": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
 
 
if __name__ == '__main__':
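For reference, a minimal sketch that combines the patterns above: an already-constructed loss, optimizer, and scheduler callback handed straight to Trainer.train. The LRSchedulerCallback and Phase import path is an assumption; every train_params key shown comes from the tests in this file.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from super_gradients import Trainer, ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy
from super_gradients.training.models import ResNet18
# Assumed import path; the diff shows only the call sites of these two names.
from super_gradients.training.utils.callbacks import LRSchedulerCallback, Phase

trainer = Trainer("initialized_objects_demo", model_checkpoints_location='local')
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}))

net = ResNet18(num_classes=5, arch_params={})
optimizer = SGD(params=net.parameters(), lr=0.1)
lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[1, 2], gamma=0.1)
trainer.build_model(net)

# Initialized objects go in as-is: the criterion instance as "loss", the optimizer
# instance as "optimizer", and the scheduler wrapped in a phase callback.
trainer.train({"max_epochs": 2,
               "phase_callbacks": [LRSchedulerCallback(lr_scheduler, Phase.TRAIN_EPOCH_END)],
               "lr_warmup_epochs": 0, "initial_lr": 0.1,
               "loss": torch.nn.CrossEntropyLoss(), "optimizer": optimizer,
               "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
               "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
               "metric_to_watch": "Accuracy",
               "greater_metric_to_watch_is_better": True})
```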
@@ -1,5 +1,5 @@
import unittest
-from super_gradients import SgModel, \
+from super_gradients import Trainer, \
    ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18
@@ -11,13 +11,13 @@ class TrainWithPreciseBNTest(unittest.TestCase):
    """

    def test_train_with_precise_bn_explicit_size(self):
-        model = SgModel("test_train_with_precise_bn_explicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_explicit_size", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -25,16 +25,16 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True,
                        "precise_bn": True, "precise_bn_batch_size": 100}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
    def test_train_with_precise_bn_implicit_size(self):
-        model = SgModel("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
        dataset_params = {"batch_size": 10}
        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        model.connect_dataset_interface(dataset)
+        trainer.connect_dataset_interface(dataset)
 
 
        net = ResNet18(num_classes=5, arch_params={})
-        model.build_model(net)
+        trainer.build_model(net)
        train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
@@ -42,7 +42,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True,
                        "precise_bn": True}
-        model.train(train_params)
+        trainer.train(train_params)
 
 
 
 
if __name__ == '__main__':
@@ -1,5 +1,5 @@
import unittest
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
from super_gradients.training.metrics import Accuracy
from super_gradients.training.datasets import ClassificationTestDatasetInterface
from super_gradients.training.models import LeNet
@@ -33,9 +33,9 @@ class UpdateParamGroupsTest(unittest.TestCase):
    def test_lr_scheduling_with_update_param_groups(self):
        # Define Model
        net = TestNet()
-        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
-        model.connect_dataset_interface(self.dataset)
-        model.build_model(net, arch_params=self.arch_params)
+        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer.connect_dataset_interface(self.dataset)
+        trainer.build_model(net, arch_params=self.arch_params)
 
 
        lrs = []
        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -53,6 +53,6 @@ class UpdateParamGroupsTest(unittest.TestCase):
                        }

        expected_lrs = np.array([0.1, 0.2, 0.3])
-        model.train(train_params)
+        trainer.train(train_params)
 
 
        self.assertTrue(np.allclose(np.array(lrs), expected_lrs, rtol=0.0000001))
@@ -1,7 +1,7 @@
import unittest
from super_gradients.training.utils.utils import HpmStruct
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
-from super_gradients import SgModel
+from super_gradients import Trainer
from super_gradients.training.metrics import Accuracy, Top5


@@ -21,10 +21,10 @@ class TestViT(unittest.TestCase):
        """
        Validate vit_base
        """
-        model = SgModel("test_vit_base", device='cpu')
-        model.connect_dataset_interface(self.dataset, data_loader_num_workers=8)
-        model.build_model('vit_base', load_checkpoint=False)
-        model.train(training_params=self.train_params)
+        trainer = Trainer("test_vit_base", device='cpu')
+        trainer.connect_dataset_interface(self.dataset, data_loader_num_workers=8)
+        trainer.build_model('vit_base', load_checkpoint=False)
+        trainer.train(training_params=self.train_params)
 
 
 
 
if __name__ == '__main__':
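For reference, a minimal sketch of the flow this last test exercises with the renamed Trainer: the architecture is selected by name and the trainer is pinned to CPU. The train_params shown are illustrative, since the test's self.train_params is defined outside this diff.

```python
from super_gradients.training import Trainer
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
from super_gradients.training.metrics import Accuracy, Top5

trainer = Trainer("vit_base_demo", device='cpu')
trainer.connect_dataset_interface(ClassificationTestDatasetInterface(dataset_params={"batch_size": 10}),
                                  data_loader_num_workers=8)
trainer.build_model('vit_base', load_checkpoint=False)

# Illustrative training parameters mirroring the other tests in this PR;
# the real self.train_params used by TestViT is not part of this diff.
train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                "metric_to_watch": "Accuracy",
                "greater_metric_to_watch_is_better": True}
trainer.train(training_params=train_params)
```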