Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#356 Feature/sg 216 remove dataset interface

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-216_remove_dataset_interface
100 changed files with 688 additions and 3388 deletions
  1. 3
    5
      src/super_gradients/__init__.py
  2. 0
    26
      src/super_gradients/common/factories/datasets_factory.py
  3. 2
    4
      src/super_gradients/examples/cifar10_training_torch_objects/cifar10_training_torch_objects_example.py
  4. 21
    14
      src/super_gradients/examples/ddrnet_imagenet/ddrnet_classification_example.py
  5. 4
    4
      src/super_gradients/examples/deci_lab_export_example/deci_lab_export_example.py
  6. 6
    7
      src/super_gradients/examples/deci_platform_logger_example/deci_platform_logger_example.py
  7. 6
    7
      src/super_gradients/examples/early_stop/early_stop_example.py
  8. 0
    0
      src/super_gradients/examples/legacy/cifar_resnet/__init__.py
  9. 0
    27
      src/super_gradients/examples/legacy/cifar_resnet/cifar_example.py
  10. 0
    31
      src/super_gradients/examples/legacy/darknet53_example.py
  11. 0
    0
      src/super_gradients/examples/legacy/imagenet_efficientnet/__init__.py
  12. 0
    31
      src/super_gradients/examples/legacy/imagenet_efficientnet/efficientnet_example.py
  13. 0
    0
      src/super_gradients/examples/legacy/imagenet_mobilenetv3/__init__.py
  14. 0
    30
      src/super_gradients/examples/legacy/imagenet_mobilenetv3/mobilenetv3_imagenet_example.py
  15. 0
    0
      src/super_gradients/examples/legacy/imagenet_regnetY800/__init__.py
  16. 0
    29
      src/super_gradients/examples/legacy/imagenet_regnetY800/regnetY800_example.py
  17. 0
    0
      src/super_gradients/examples/legacy/imagenet_repvgg/__init__.py
  18. 0
    24
      src/super_gradients/examples/legacy/imagenet_repvgg/imagenet_repvgg_example.py
  19. 0
    0
      src/super_gradients/examples/legacy/imagenet_resnet/__init__.py
  20. 0
    42
      src/super_gradients/examples/legacy/imagenet_resnet/imagenet_resnet_example.py
  21. 0
    0
      src/super_gradients/examples/legacy/imagenet_resnet_ddp/__init__.py
  22. 0
    68
      src/super_gradients/examples/legacy/imagenet_resnet_ddp/distributed_training_imagenet.py
  23. 0
    81
      src/super_gradients/examples/legacy/shelfnet_lw_example.py
  24. 13
    16
      src/super_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py
  25. 6
    7
      src/super_gradients/examples/resnet_qat/resnet_qat_example.py
  26. 0
    0
      src/super_gradients/examples/shelfnet_lw_pascal_aug/__init__.py
  27. 0
    56
      src/super_gradients/examples/shelfnet_lw_pascal_aug/shelfnet_pascal_aug.py
  28. 0
    0
      src/super_gradients/examples/user_guide_walkthrough_example/__init__.py
  29. 0
    44
      src/super_gradients/examples/user_guide_walkthrough_example/dataset.py
  30. 0
    46
      src/super_gradients/examples/user_guide_walkthrough_example/loss.py
  31. 0
    24
      src/super_gradients/examples/user_guide_walkthrough_example/metrics.py
  32. 0
    154
      src/super_gradients/examples/user_guide_walkthrough_example/model.py
  33. 0
    55
      src/super_gradients/examples/user_guide_walkthrough_example/train.py
  34. 4
    3
      src/super_gradients/recipes/cifar10_resnet.yaml
  35. 4
    50
      src/super_gradients/recipes/cityscapes_ddrnet.yaml
  36. 3
    49
      src/super_gradients/recipes/cityscapes_regseg48.yaml
  37. 2
    33
      src/super_gradients/recipes/cityscapes_stdc_base.yaml
  38. 3
    50
      src/super_gradients/recipes/cityscapes_stdc_seg50.yaml
  39. 3
    48
      src/super_gradients/recipes/cityscapes_stdc_seg75.yaml
  40. 3
    3
      src/super_gradients/recipes/coco2017_ssd_lite_mobilenet_v2.yaml
  41. 3
    7
      src/super_gradients/recipes/coco2017_yolox.yaml
  42. 3
    21
      src/super_gradients/recipes/coco_segmentation_shelfnet_lw.yaml
  43. 0
    5
      src/super_gradients/recipes/dataset_params/cifar100_dataset_params.yaml
  44. 2
    2
      src/super_gradients/recipes/dataset_params/cifar10_dataset_params.yaml
  45. 0
    78
      src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml
  46. 1
    2
      src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml
  47. 2
    21
      src/super_gradients/recipes/dataset_params/imagenet_dataset_params.yaml
  48. 0
    12
      src/super_gradients/recipes/dataset_params/imagenet_efficientnet_dataset_params.yaml
  49. 0
    13
      src/super_gradients/recipes/dataset_params/imagenet_mobilenetv2_dataset_params.yaml
  50. 0
    7
      src/super_gradients/recipes/dataset_params/imagenet_mobilenetv3_dataset_params.yaml
  51. 0
    12
      src/super_gradients/recipes/dataset_params/imagenet_regnetY_dataset_params.yaml
  52. 0
    17
      src/super_gradients/recipes/dataset_params/imagenet_resnet50_dataset_params.yaml
  53. 6
    25
      src/super_gradients/recipes/dataset_params/imagenet_resnet50_kd_dataset_params.yaml
  54. 0
    20
      src/super_gradients/recipes/dataset_params/imagenet_vit_base_dataset_params.yaml
  55. 2
    5
      src/super_gradients/recipes/imagenet_efficientnet.yaml
  56. 3
    4
      src/super_gradients/recipes/imagenet_mobilenetv2.yaml
  57. 2
    6
      src/super_gradients/recipes/imagenet_mobilenetv3_base.yaml
  58. 2
    6
      src/super_gradients/recipes/imagenet_regnetY.yaml
  59. 2
    6
      src/super_gradients/recipes/imagenet_repvgg.yaml
  60. 2
    5
      src/super_gradients/recipes/imagenet_resnet50.yaml
  61. 2
    6
      src/super_gradients/recipes/imagenet_resnet50_kd.yaml
  62. 2
    5
      src/super_gradients/recipes/imagenet_vit_base.yaml
  63. 0
    1
      src/super_gradients/recipes/imagenet_vit_large.yaml
  64. 0
    24
      src/super_gradients/recipes/test_resnet.yaml
  65. 3
    5
      src/super_gradients/training/__init__.py
  66. 31
    0
      src/super_gradients/training/dataloaders/__init__.py
  67. 134
    58
      src/super_gradients/training/dataloaders/dataloaders.py
  68. 2
    15
      src/super_gradients/training/datasets/__init__.py
  69. 0
    89
      src/super_gradients/training/datasets/all_datasets.py
  70. 0
    15
      src/super_gradients/training/datasets/dataset_interfaces/__init__.py
  71. 0
    807
      src/super_gradients/training/datasets/dataset_interfaces/dataset_interface.py
  72. 19
    6
      src/super_gradients/training/kd_trainer/kd_trainer.py
  73. 3
    2
      src/super_gradients/training/params.py
  74. 30
    31
      src/super_gradients/training/sg_trainer/sg_trainer.py
  75. 1
    38
      tests/end_to_end_tests/cifar_trainer_test.py
  76. 0
    135
      tests/end_to_end_tests/external_dataset_e2e.py
  77. 11
    10
      tests/end_to_end_tests/trainer_test.py
  78. 1
    2
      tests/integration_tests/__init__.py
  79. 14
    16
      tests/integration_tests/conversion_callback_test.py
  80. 4
    5
      tests/integration_tests/deci_lab_export_test.py
  81. 4
    4
      tests/integration_tests/ema_train_integration_test.py
  82. 11
    10
      tests/integration_tests/lr_test.py
  83. 217
    338
      tests/integration_tests/pretrained_models_test.py
  84. 9
    8
      tests/integration_tests/qat_integration_test.py
  85. 0
    21
      tests/integration_tests/s3_dataset_test.py
  86. 1
    2
      tests/unit_tests/__init__.py
  87. 1
    1
      tests/unit_tests/cityscapes_dataset_test.py
  88. 1
    1
      tests/unit_tests/coco_segmentation_dataset_test.py
  89. 1
    1
      tests/unit_tests/datalaoder_factory_test.py
  90. 0
    139
      tests/unit_tests/dataset_interface_test.py
  91. 3
    38
      tests/unit_tests/dataset_statistics_test.py
  92. 3
    38
      tests/unit_tests/detection_utils_test.py
  93. 15
    25
      tests/unit_tests/early_stop_test.py
  94. 0
    77
      tests/unit_tests/externel_dataset_interface_test.py
  95. 7
    10
      tests/unit_tests/factories_test.py
  96. 4
    8
      tests/unit_tests/forward_pass_prep_fn_test.py
  97. 2
    20
      tests/unit_tests/initialize_with_dataloaders_test.py
  98. 12
    10
      tests/unit_tests/kd_ema_test.py
  99. 19
    16
      tests/unit_tests/kd_trainer_test.py
  100. 8
    9
      tests/unit_tests/load_ema_ckpt_test.py
@@ -1,5 +1,4 @@
-from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, \
-    TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface, SgModel, KDModel, \
+from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, SgModel, KDModel, \
     Trainer, KDTrainer
     Trainer, KDTrainer
 from super_gradients.common import init_trainer, is_distributed
 from super_gradients.common import init_trainer, is_distributed
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
@@ -7,9 +6,8 @@ from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_
 from super_gradients.sanity_check import env_sanity_check
 from super_gradients.sanity_check import env_sanity_check
 
 
 __all__ = ['ARCHITECTURES', 'losses', 'utils', 'datasets_utils', 'DataAugmentation',
 __all__ = ['ARCHITECTURES', 'losses', 'utils', 'datasets_utils', 'DataAugmentation',
-           'TestDatasetInterface', 'Trainer', 'KDTrainer', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
-           'ClassificationTestDatasetInterface', 'init_trainer', 'is_distributed', 'train_from_recipe', 'train_from_kd_recipe',
+           'Trainer', 'KDTrainer',
+           'init_trainer', 'is_distributed', 'train_from_recipe', 'train_from_kd_recipe',
            'env_sanity_check', 'KDModel', 'SgModel']
            'env_sanity_check', 'KDModel', 'SgModel']
 
 
-
 env_sanity_check()
 env_sanity_check()
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
  1. from super_gradients.common.factories.base_factory import BaseFactory
  2. from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface, ClassificationDatasetInterface, Cifar10DatasetInterface,\
  3. Cifar100DatasetInterface, ImageNetDatasetInterface, TinyImageNetDatasetInterface, CoCoSegmentationDatasetInterface,\
  4. PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface
  5. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import \
  6. ClassificationTestDatasetInterface, CityscapesDatasetInterface, CoCoDetectionDatasetInterface
  7. class DatasetsFactory(BaseFactory):
  8. def __init__(self):
  9. type_dict = {
  10. "classification_test_dataset": ClassificationTestDatasetInterface,
  11. "library_dataset": LibraryDatasetInterface,
  12. "classification_dataset": ClassificationDatasetInterface,
  13. "cifar_10": Cifar10DatasetInterface,
  14. "cifar_100": Cifar100DatasetInterface,
  15. "imagenet": ImageNetDatasetInterface,
  16. "tiny_imagenet": TinyImageNetDatasetInterface,
  17. "coco2017_detection": CoCoDetectionDatasetInterface,
  18. "coco2017_segmentation": CoCoSegmentationDatasetInterface,
  19. "pascal_voc_segmentation": PascalVOC2012SegmentationDataSetInterface,
  20. "pascal_aug_segmentation": PascalAUG2012SegmentationDataSetInterface,
  21. "cityscapes": CityscapesDatasetInterface,
  22. }
  23. super().__init__(type_dict)
Discard
@@ -49,9 +49,7 @@ phase_callbacks = [LRSchedulerCallback(scheduler=rop_lr_scheduler, phase=Phase.V
                    LRSchedulerCallback(scheduler=step_lr_scheduler, phase=Phase.TRAIN_EPOCH_END)]
                    LRSchedulerCallback(scheduler=step_lr_scheduler, phase=Phase.TRAIN_EPOCH_END)]
 
 
 # Bring everything together with Trainer and start training
 # Bring everything together with Trainer and start training
-trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF,
-                  train_loader=train_loader, valid_loader=valid_loader, classes=train_dataset.classes)
-trainer.build_model(net)
+trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF)
 
 
 train_params = {"max_epochs": 300,
 train_params = {"max_epochs": 300,
                 "phase_callbacks": phase_callbacks,
                 "phase_callbacks": phase_callbacks,
@@ -65,4 +63,4 @@ train_params = {"max_epochs": 300,
                 "greater_metric_to_watch_is_better": True,
                 "greater_metric_to_watch_is_better": True,
                 "lr_scheduler_step_type": "epoch"}
                 "lr_scheduler_step_type": "epoch"}
 
 
-trainer.train(training_params=train_params)
+trainer.train(model=net, training_params=train_params, train_loader=train_loader, valid_loader=valid_loader)
Discard
@@ -14,15 +14,14 @@ Paper:              https://arxiv.org/pdf/2101.06085.pdf
 
 
 import torch
 import torch
 
 
-from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface
-
+from super_gradients.common import MultiGPUMode
+from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation
+from torchvision.transforms import RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
 import super_gradients
 import super_gradients
-from super_gradients.training import Trainer, MultiGPUMode
-from super_gradients.training.models import HpmStruct
+from super_gradients.training import Trainer, models, dataloaders
 import argparse
 import argparse
-
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
-
+from super_gradients.training.datasets.data_augmentation import RandomErase
 parser = argparse.ArgumentParser()
 parser = argparse.ArgumentParser()
 super_gradients.init_trainer()
 super_gradients.init_trainer()
 
 
@@ -56,19 +55,27 @@ dataset_params = {"batch_size": args.batch,
                   "random_erase_prob": 0.2,
                   "random_erase_prob": 0.2,
                   "random_erase_value": 'random',
                   "random_erase_value": 'random',
                   "train_interpolation": 'random',
                   "train_interpolation": 'random',
-                  "auto_augment_config_string": 'rand-m9-mstd0.5'
                   }
                   }
 
 
+
+train_transforms = [RandomResizedCropAndInterpolation(size=224, interpolation="random"),
+                    RandomHorizontalFlip(),
+                    ColorJitter(0.4, 0.4, 0.4),
+                    ToTensor(),
+                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                    RandomErase(0.2, "random")
+                    ]
+
 trainer = Trainer(experiment_name=args.experiment_name,
 trainer = Trainer(experiment_name=args.experiment_name,
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
                   device='cuda')
                   device='cuda')
 
 
-dataset = ImageNetDatasetInterface(dataset_params=dataset_params)
-
-trainer.connect_dataset_interface(dataset, data_loader_num_workers=8 * devices)
+train_loader = dataloaders.imagenet_train(dataset_params={"transforms": train_transforms},
+                                          dataloader_params={"batch_size": args.batch})
+valid_loader = dataloaders.imagenet_val()
 
 
-arch_params = HpmStruct(**{"num_classes": 1000, "aux_head": False, "classification_mode": True, 'dropout_prob': 0.3})
+model = models.get("ddrnet_23_slim" if args.slim else "ddrnet_23",
+                   arch_params={"aux_head": False, "classification_mode": True, 'dropout_prob': 0.3},
+                   num_classes=1000)
 
 
-trainer.build_model(architecture="ddrnet_23_slim" if args.slim else "ddrnet_23",
-                    arch_params=arch_params)
-trainer.train(training_params=train_params_ddr)
+trainer.train(model=model, training_params=train_params_ddr, train_loader=train_loader, valid_loader=valid_loader)
Discard
@@ -4,7 +4,8 @@ Deci-lab model export example.
 The main purpose of this code is to demonstrate how to upload the model to the platform, optimize and download it
 The main purpose of this code is to demonstrate how to upload the model to the platform, optimize and download it
  after training is complete, using DeciPlatformCallback.
  after training is complete, using DeciPlatformCallback.
 """
 """
-from super_gradients import Trainer, ClassificationTestDatasetInterface
+from super_gradients import Trainer
+from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils.callbacks import DeciLabUploadCallback, ModelConversionCheckCallback
 from super_gradients.training.utils.callbacks import DeciLabUploadCallback, ModelConversionCheckCallback
 from deci_lab_client.models import (
 from deci_lab_client.models import (
@@ -28,8 +29,6 @@ def main(architecture_name: str):
         model_checkpoints_location="local",
         model_checkpoints_location="local",
         ckpt_root_dir=checkpoint_dir,
         ckpt_root_dir=checkpoint_dir,
     )
     )
-    dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
-    trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
 
 
     trainer.build_model(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
     trainer.build_model(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
 
 
@@ -91,7 +90,8 @@ def main(architecture_name: str):
 
 
     # RUN TRAINING. ONCE ALL EPOCHS ARE DONE THE OPTIMIZED MODEL FILE WILL BE LOCATED IN THE EXPERIMENT'S
     # RUN TRAINING. ONCE ALL EPOCHS ARE DONE THE OPTIMIZED MODEL FILE WILL BE LOCATED IN THE EXPERIMENT'S
     # CHECKPOINT DIRECTORY
     # CHECKPOINT DIRECTORY
-    trainer.train(train_params)
+    trainer.train(train_params, train_loader=classification_test_dataloader(),
+                  valid_loader=classification_test_dataloader())
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard
@@ -1,15 +1,12 @@
 import os
 import os
-from super_gradients.training import Trainer
-from super_gradients.training.datasets.dataset_interfaces import Cifar10DatasetInterface
+from super_gradients.training import Trainer, models
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
-
+from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
 os.environ["DECI_PLATFORM_TOKEN"] = "XXX"  # Replace XXX with your token
 os.environ["DECI_PLATFORM_TOKEN"] = "XXX"  # Replace XXX with your token
 
 
 
 
 trainer = Trainer(experiment_name='demo-deci-platform-logger')
 trainer = Trainer(experiment_name='demo-deci-platform-logger')
-dataset = Cifar10DatasetInterface(dataset_params={"batch_size": 256, "val_batch_size": 512})
-trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
-trainer.build_model("resnet18")
+model = models.get("resnet18", num_classes=10)
 
 
 trainer.train(training_params={"max_epochs": 20,
 trainer.train(training_params={"max_epochs": 20,
                                "lr_updates": [5, 10, 15],
                                "lr_updates": [5, 10, 15],
@@ -23,4 +20,6 @@ trainer.train(training_params={"max_epochs": 20,
                                "valid_metrics_list": [Accuracy(), Top5()],
                                "valid_metrics_list": [Accuracy(), Top5()],
                                "metric_to_watch": "Accuracy",
                                "metric_to_watch": "Accuracy",
                                "greater_metric_to_watch_is_better": True,
                                "greater_metric_to_watch_is_better": True,
-                               "sg_logger": "deci_platform_sg_logger"})
+                               "sg_logger": "deci_platform_sg_logger"},
+              train_loader=cifar10_train(),
+              valid_loader=cifar10_val())
Discard
@@ -2,10 +2,11 @@
 # Reaches ~94.9 Accuracy after 250 Epochs
 # Reaches ~94.9 Accuracy after 250 Epochs
 import super_gradients
 import super_gradients
 from super_gradients import Trainer
 from super_gradients import Trainer
-from super_gradients.training.datasets.dataset_interfaces.dataset_interface import Cifar10DatasetInterface
+from super_gradients.training import models, dataloaders
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.utils.early_stopping import EarlyStop
 from super_gradients.training.utils.early_stopping import EarlyStop
 from super_gradients.training.utils.callbacks import Phase
 from super_gradients.training.utils.callbacks import Phase
+
 # Define Parameters
 # Define Parameters
 super_gradients.init_trainer()
 super_gradients.init_trainer()
 
 
@@ -22,10 +23,8 @@ train_params = {"max_epochs": 250, "lr_updates": [100, 150, 200], "lr_decay_fact
 # Define Model
 # Define Model
 trainer = Trainer("Callback_Example")
 trainer = Trainer("Callback_Example")
 
 
-# Connect Dataset
-dataset = Cifar10DatasetInterface()
-trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
-
 # Build Model
 # Build Model
-trainer.build_model("resnet18_cifar")
-trainer.train(training_params=train_params)
+model = models.get("resnet18_cifar", num_classes=10)
+
+trainer.train(model=model, training_params=train_params,
+              train_loader=dataloaders.cifar10_train(), valid_loader=dataloaders.cifar10_val())
Discard
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    1. # Cifar10 Classification Training:
    2. # Reaches ~94.9 Accuracy after 250 Epochs
    3. import super_gradients
    4. from omegaconf import DictConfig
    5. import hydra
    6. import pkg_resources
    7. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="cifar10_resnet_conf")
    8. def train(cfg: DictConfig) -> None:
    9. # INSTANTIATE ALL OBJECTS IN CFG
    10. cfg = hydra.utils.instantiate(cfg)
    11. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
    12. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
    13. # BUILD NETWORK
    14. cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
    15. # TRAIN
    16. cfg.trainer.train(training_params=cfg.training_params)
    17. if __name__ == "__main__":
    18. super_gradients.init_trainer()
    19. train()
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    1. # Darknet53 Backbone Training on HAM10000 Dataset
    2. from super_gradients.training import MultiGPUMode
    3. from super_gradients.training import Trainer
    4. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationDatasetInterface
    5. # Define Parameters
    6. train_params = {"max_epochs": 110, "lr_updates": [30, 60, 90, 100], "lr_decay_factor": 0.1, "lr_mode": "step",
    7. "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
    8. "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9}}
    9. arch_params = {'backbone_mode': False, 'num_classes': 7}
    10. dataset_params = {"batch_size": 16, "test_batch_size": 16, 'dataset_dir': '/data/HAM10000'}
    11. # Define Model
    12. trainer = Trainer("Darknet53_Backbone_HAM10000",
    13. model_checkpoints_location='local',
    14. device='cuda',
    15. multi_gpu=MultiGPUMode.DATA_PARALLEL)
    16. # Connect Dataset
    17. dataset = ClassificationDatasetInterface(normalization_mean=(0.7483, 0.5154, 0.5353),
    18. normalization_std=(0.1455, 0.1691, 0.1879),
    19. resolution=416,
    20. dataset_params=dataset_params)
    21. trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
    22. # Build Model
    23. trainer.build_model("darknet53", arch_params=arch_params)
    24. # Start Training
    25. trainer.train(training_params=train_params)
    Discard
      Discard
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      1. """EfficientNet-b0 training on Imagenet
      2. TODO: This example code is the STARTING POINT for training EfficientNet - IT DIDN'T ACHIEVE THE PAPER'S ACCURACY!!!
      3. Training params are set according to https://github.com/rwightman/pytorch-image-models/issues/11
      4. Training on 4 GPUs with initial LR = 0.0032 achieves ~74.7%, (Paper=77.1% Timm=77.69%)
      5. The Tensorboards of the previous attempts: 's3/deci-model-repository-research/enet_reproduce_attempts'
      6. """
      7. import super_gradients
      8. from omegaconf import DictConfig
      9. import hydra
      10. import pkg_resources
      11. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_efficientnet_conf")
      12. def train(cfg: DictConfig) -> None:
      13. # INSTANTIATE ALL OBJECTS IN CFG
      14. cfg = hydra.utils.instantiate(cfg)
      15. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
      16. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
      17. # BUILD NETWORK
      18. cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
      19. # TRAIN
      20. cfg.trainer.train(training_params=cfg.training_params)
      21. if __name__ == "__main__":
      22. super_gradients.init_trainer()
      23. train()
      Discard
        Discard
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        1. # MobileNetV3 Large Imagenet classification training:
        2. # This example trains with batch_size = 128 * 2 GPUs, total 256.
        3. # Training time on 2 X GeForce RTX 2080 Ti is 19min / epoch, total time ~ 50 hours.
        4. # Reach 73.79 Top1 accuracy.
        5. # Training parameters are for MobileNet Large
        6. import super_gradients
        7. from omegaconf import DictConfig
        8. import hydra
        9. import pkg_resources
        10. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_mobilenetv3_conf")
        11. def train(cfg: DictConfig) -> None:
        12. # INSTANTIATE ALL OBJECTS IN CFG
        13. cfg = hydra.utils.instantiate(cfg)
        14. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
        15. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
        16. # BUILD NETWORK
        17. cfg.trainer .build_model(cfg.architecture, load_checkpoint=cfg.load_checkpoint)
        18. # TRAIN
        19. cfg.trainer.train(training_params=cfg.training_params)
        20. if __name__ == "__main__":
        21. super_gradients.init_trainer()
        22. train()
        Discard
          Discard
          1
          2
          3
          4
          5
          6
          7
          8
          9
          10
          11
          12
          13
          14
          15
          16
          17
          18
          19
          20
          21
          22
          23
          24
          25
          26
          27
          28
          29
          1. # Imagenet classification training:
          2. # For RegnetY800 => 76.1 accuracy
          3. # the hyper-parameters are tailored for training on Single 2080Ti GPU.
          4. import super_gradients
          5. from omegaconf import DictConfig
          6. import hydra
          7. import pkg_resources
          8. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_regnetY800_conf")
          9. def train(cfg: DictConfig) -> None:
          10. # INSTANTIATE ALL OBJECTS IN CFG
          11. cfg = hydra.utils.instantiate(cfg)
          12. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
          13. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
          14. # BUILD NETWORK
          15. cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
          16. # TRAIN
          17. cfg.trainer.train(training_params=cfg.training_params)
          18. if __name__ == "__main__":
          19. super_gradients.init_trainer()
          20. train()
          Discard
            Discard
            1
            2
            3
            4
            5
            6
            7
            8
            9
            10
            11
            12
            13
            14
            15
            16
            17
            18
            19
            20
            21
            22
            23
            24
            1. import super_gradients
            2. from omegaconf import DictConfig
            3. import hydra
            4. import pkg_resources
            5. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_repvgg_conf")
            6. def train(cfg: DictConfig) -> None:
            7. # INSTANTIATE ALL OBJECTS IN CFG
            8. cfg = hydra.utils.instantiate(cfg)
            9. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
            10. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
            11. # BUILD NETWORK
            12. cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
            13. # TRAIN
            14. cfg.trainer.train(training_params=cfg.training_params)
            15. if __name__ == "__main__":
            16. super_gradients.init_trainer()
            17. train()
            Discard
              Discard
              1
              2
              3
              4
              5
              6
              7
              8
              9
              10
              11
              12
              13
              14
              15
              16
              17
              18
              19
              20
              21
              22
              23
              24
              25
              26
              27
              28
              29
              30
              31
              32
              33
              34
              35
              36
              37
              38
              39
              40
              41
              42
              1. """
              2. ResNet50 Imagenet classification training:
              3. This example trains with batch_size = 64 * 4 GPUs, total 256.
              4. Training times:
              5. ResNet18: 36 hours with 4 X NVIDIA RTX A5000.
              6. ResNet34: 36 hours with 4 X NVIDIA RTX A5000.
              7. ResNet50: 46 hours with 4 X GeForce RTX 3090 Ti.
              8. Top1, Top5 results:
              9. ResNet18: Top1: 70.60 Top5: 89.64
              10. ResNet34: Top1: 74.13 Top5: 91.70
              11. ResNet50: Top1: 76.30 Top5: 93.03
              12. BE AWARE THAT THIS RECIPE USE DATA_PARALLEL, WHEN USING DDP FOR DISTRIBUTED TRAINING THIS RECIPE REACH ONLY 75.4 TOP1
              13. ACCURACY.
              14. """
              15. import super_gradients
              16. from omegaconf import DictConfig
              17. import hydra
              18. import pkg_resources
              19. @hydra.main(config_path=pkg_resources.resource_filename("conf", ""), config_name="imagenet_resnet50_conf")
              20. def train(cfg: DictConfig) -> None:
              21. # INSTANTIATE ALL OBJECTS IN CFG
              22. cfg = hydra.utils.instantiate(cfg)
              23. # CONNECT THE DATASET INTERFACE WITH DECI MODEL
              24. cfg.trainer .connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
              25. # BUILD NETWORK
              26. cfg.trainer .build_model(cfg.architecture, arch_params=cfg.arch_params, load_checkpoint=cfg.load_checkpoint)
              27. # TRAIN
              28. cfg.trainer.train(training_params=cfg.training_params)
              29. if __name__ == "__main__":
              30. super_gradients.init_trainer()
              31. train()
              Discard
                Discard
                1
                2
                3
                4
                5
                6
                7
                8
                9
                10
                11
                12
                13
                14
                15
                16
                17
                18
                19
                20
                21
                22
                23
                24
                25
                26
                27
                28
                29
                30
                31
                32
                33
                34
                35
                36
                37
                38
                39
                40
                41
                42
                43
                44
                45
                46
                47
                48
                49
                50
                51
                52
                53
                54
                55
                56
                57
                58
                59
                60
                61
                62
                63
                64
                65
                66
                67
                68
                1. #!/usr/bin/env python
                2. """ Single node distributed training.
                3. The program will dispatch distributed training on all available GPUs residing in a single node.
                4. Usage:
                5. python -m torch.distributed.launch --nproc_per_node=n distributed_training_imagenet.py
                6. where n is the number of GPUs required, e.g., n=8
                7. Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
                8. Whatever learning rate and schedule you specify will be applied to the each GPU individually.
                9. Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
                10. batch you specify times the number of GPUs. In the literature there are several "best practices" to set
                11. learning rates and schedules for large batch sizes.
                12. Should be checked with. (2) The training protocol specified in this file for 8 GPUs are far from optimal.
                13. The best protocol should use cosine schedule.
                14. In the example below: for ImageNet training using Resnet50, when applied with n=8 should compute an Eopch in about
                15. 5min20sec with 8 V100 GPUs.
                16. Todo: (1) the code is more or less ready for multiple nodes, but I have not experimented with it at all.
                17. (2) detection and segmentation codes were not modified and should not work properly.
                18. Specifically, the analogue changes done in sg_classification_model should be done also in
                19. deci_segmentation_model and deci_detection_model
                20. """
                21. import super_gradients
                22. import torch.distributed
                23. from super_gradients.training.sg_trainer import MultiGPUMode
                24. from super_gradients.training import Trainer
                25. from super_gradients.training.datasets.dataset_interfaces import ImageNetDatasetInterface
                26. from super_gradients.common.aws_connection.aws_secrets_manager_connector import AWSSecretsManagerConnector
                27. from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
                28. torch.backends.cudnn.benchmark = True
                29. super_gradients.init_trainer()
                30. # TODO - VALIDATE THE HYPER PARAMETERS WITH RAN TO FIX THIS EXAMPLE CODE
                31. train_params = {"max_epochs": 110,
                32. "lr_updates": [30, 60, 90],
                33. "lr_decay_factor": 0.1,
                34. "initial_lr": 0.6,
                35. "loss": "cross_entropy",
                36. "lr_mode": "step",
                37. # "initial_lr": 0.05 * 2,
                38. "lr_warmup_epochs": 5,
                39. # "criterion_params":{"smooth_eps":0.1}}
                40. "mixed_precision": True,
                41. # "mixed_precision_opt_level": "O3",
                42. "optimizer_params": {"weight_decay": 0.000, "momentum": 0.9},
                43. # "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9}
                44. "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                45. "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                46. "greater_metric_to_watch_is_better": True}
                47. dataset_params = {"batch_size": 128}
                48. model_repo_bucket_name = AWSSecretsManagerConnector.get_secret_value_for_secret_key(aws_env='research',
                49. secret_name='training_secrets',
                50. secret_key='S3.MODEL_REPOSITORY_BUCKET_NAME')
                51. trainer = Trainer("test_checkpoints_resnet_8_gpus",
                52. model_checkpoints_location='s3://' + model_repo_bucket_name,
                53. multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
                54. )
                55. # FOR AWS
                56. dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params=dataset_params)
                57. trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
                58. trainer.build_model("resnet50")
                59. trainer.train(training_params=train_params)
                Discard
                1
                2
                3
                4
                5
                6
                7
                8
                9
                10
                11
                12
                13
                14
                15
                16
                17
                18
                19
                20
                21
                22
                23
                24
                25
                26
                27
                28
                29
                30
                31
                32
                33
                34
                35
                36
                37
                38
                39
                40
                41
                42
                43
                44
                45
                46
                47
                48
                49
                50
                51
                52
                53
                54
                55
                56
                57
                58
                59
                60
                61
                62
                63
                64
                65
                66
                67
                68
                69
                70
                71
                72
                73
                74
                75
                76
                77
                78
                79
                80
                81
                1. # ShelfNet LW 34 training on CoCo Segmentation Dataset:
                2. # mIOU on CoCo Seg: ~0.66
                3. # Since the code is training on a Subset of COCO Seg, there is an initial creation process for the "Sub-DataSet"
                4. # this training process is optimized to enable fine-tuning on PASCAL VOC 2012 Dataset that has only 21 Classes...
                5. # IMPORTANT: The code is optimized for a fixed initial LR since the Polynomial Loss is pretty sensitive, so we keep the
                6. # same LR by dividing by the number of GPUs (since our code base multiplies it automatically)
                7. # P.S. - Use the relevant training params dict if you are running on TZAG or on V100
                8. import torch
                9. from super_gradients.training import Trainer, MultiGPUMode
                10. from super_gradients.training.datasets import CoCoSegmentationDatasetInterface
                11. from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
                12. from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
                13. model_size_str = '34'
                14. coco_sub_classes_inclusion_tuples_list = [(0, 'background'), (5, 'airplane'), (2, 'bicycle'), (16, 'bird'),
                15. (9, 'boat'),
                16. (44, 'bottle'), (6, 'bus'), (3, 'car'), (17, 'cat'), (62, 'chair'),
                17. (21, 'cow'),
                18. (67, 'dining table'), (18, 'dog'), (19, 'horse'), (4, 'motorcycle'),
                19. (1, 'person'),
                20. (64, 'potted plant'), (20, 'sheep'), (63, 'couch'), (7, 'train'),
                21. (72, 'tv')]
                22. coco_seg_dataset_tzag_params = {
                23. "batch_size": 24,
                24. "test_batch_size": 24,
                25. "dataset_dir": "/data/coco/",
                26. "s3_link": None,
                27. "img_size": 608,
                28. "crop_size": 512
                29. }
                30. coco_seg_dataset_v100_params = {
                31. "batch_size": 32,
                32. "test_batch_size": 32,
                33. "dataset_dir": "/home/ubuntu/data/coco/",
                34. "s3_link": None,
                35. "img_size": 608,
                36. "crop_size": 512
                37. }
                38. shelfnet_coco_training_params = {
                39. "max_epochs": 150, "initial_lr": 5e-3, "loss": "shelfnet_ohem_loss",
                40. "optimizer": "SGD", "mixed_precision": True, "lr_mode": "poly",
                41. "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4, "nesterov": False},
                42. "load_opt_params": False, "train_metrics_list": [PixelAccuracy(), IoU(21)],
                43. "valid_metrics_list": [PixelAccuracy(), IoU(21)],
                44. "loss_logging_items_names": ["Loss1/4", "Loss1/8", "Loss1/16", "Loss"], "metric_to_watch": "IoU",
                45. "greater_metric_to_watch_is_better": True}
                46. shelfnet_lw_arch_params = {"num_classes": 21, "load_checkpoint": True, "strict_load": StrictLoad.ON,
                47. "multi_gpu_mode": "data_parallel", "load_weights_only": True,
                48. "load_backbone": True, "source_ckpt_folder_name": 'resnet' + model_size_str}
                49. data_loader_num_workers = 8 * torch.cuda.device_count()
                50. # BUILD THE LIGHT-WEIGHT SHELFNET ARCHITECTURE FOR TRAINING
                51. experiment_name_prefix = 'shelfnet_lw_'
                52. experiment_name_dataset_suffix = '_coco_seg_' + str(
                53. shelfnet_coco_training_params['max_epochs']) + '_epochs_train_example'
                54. experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix
                55. trainer = Trainer(experiment_name,
                56. multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
                57. ckpt_name='ckpt_best.pth')
                58. coco_seg_datasaet_interface = CoCoSegmentationDatasetInterface(dataset_params=coco_seg_dataset_tzag_params,
                59. cache_labels=False,
                60. dataset_classes_inclusion_tuples_list=coco_sub_classes_inclusion_tuples_list)
                61. trainer.connect_dataset_interface(coco_seg_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
                62. trainer.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params)
                63. print('Training ShelfNet-LW model: ' + experiment_name)
                64. trainer.train(training_params=shelfnet_coco_training_params)
                Discard
                @@ -1,32 +1,29 @@
                -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import SuperviselyPersonsDatasetInterface
                +from super_gradients.training import models, dataloaders
                +
                 from super_gradients.training.sg_trainer import Trainer
                 from super_gradients.training.sg_trainer import Trainer
                 from super_gradients.training.metrics import BinaryIOU
                 from super_gradients.training.metrics import BinaryIOU
                 from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
                 from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
                     PadShortToCropSize, ColorJitterSeg
                     PadShortToCropSize, ColorJitterSeg
                 from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
                 from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
                -from torchvision import transforms
                 
                 
                 # DEFINE DATA TRANSFORMATIONS
                 # DEFINE DATA TRANSFORMATIONS
                -dataset_params = {
                -    "image_mask_transforms_aug": transforms.Compose([ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),
                -                                                     RandomFlip(),
                -                                                     RandomRescale(scales=[0.25, 1.]),
                -                                                     PadShortToCropSize([320, 480]),
                -                                                     CropImageAndMask(crop_size=[320, 480],
                -                                                                      mode="random")]),
                -    "image_mask_transforms": transforms.Compose([ResizeSeg(h=480, w=320)])
                -}
                -
                -dataset_interface = SuperviselyPersonsDatasetInterface(dataset_params)
                 
                 
                -trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")
                +dl_train = dataloaders.supervisely_persons_train(
                +    dataset_params={"transforms": [ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),
                +                                   RandomFlip(),
                +                                   RandomRescale(scales=[0.25, 1.]),
                +                                   PadShortToCropSize([320, 480]),
                +                                   CropImageAndMask(crop_size=[320, 480],
                +                                                    mode="random")]})
                 
                 
                -# CONNECTING THE DATASET INTERFACE WILL SET SGMODEL'S CLASSES ATTRIBUTE ACCORDING TO SUPERVISELY
                -trainer.connect_dataset_interface(dataset_interface)
                +dl_val = dataloaders.supervisely_persons_val(dataset_params={"transforms": [ResizeSeg(h=480, w=320)]})
                +
                +trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")
                 
                 
                 # THIS IS WHERE THE MAGIC HAPPENS- SINCE SGMODEL'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
                 # THIS IS WHERE THE MAGIC HAPPENS- SINCE SGMODEL'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
                 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
                 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
                 # TO OUR BINARY SEGMENTATION DATASET
                 # TO OUR BINARY SEGMENTATION DATASET
                +model = models.get("regseg48", pretrained_weights="cityscapes", num_classes=1)
                 trainer.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})
                 trainer.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})
                 
                 
                 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
                 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
                Discard
                @@ -16,9 +16,7 @@ Once triggered, the following will happen:
 Finally, once training is over - we trigger a post-training callback that will export the ONNX files.
 Finally, once training is over - we trigger a post-training callback that will export the ONNX files.
                 
                 
                 """
                 """
                -
                -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface
                -from super_gradients.training import Trainer, MultiGPUMode
                +from super_gradients.training import Trainer, MultiGPUMode, models, dataloaders
                 from super_gradients.training.metrics.classification_metrics import Accuracy
                 from super_gradients.training.metrics.classification_metrics import Accuracy
                 
                 
                 import super_gradients
                 import super_gradients
                @@ -26,13 +24,14 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
                 
                 
                 super_gradients.init_trainer()
                 super_gradients.init_trainer()
                 
                 
                -dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params={"batch_size": 128})
                 trainer = Trainer("resnet18_qat_example",
                 trainer = Trainer("resnet18_qat_example",
                                   model_checkpoints_location='local',
                                   model_checkpoints_location='local',
                                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
                                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
                 
                 
                -trainer.connect_dataset_interface(dataset)
                -trainer.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
                +train_loader = dataloaders.imagenet_train()
                +valid_loader = dataloaders.imagenet_val()
                +
                +model = models.get("resnet18", pretrained_weights="imagenet")
                 
                 
                 train_params = {"max_epochs": 1,
                 train_params = {"max_epochs": 1,
                                 "lr_mode": "step",
                                 "lr_mode": "step",
                @@ -58,4 +57,4 @@ train_params = {"max_epochs": 1,
                                 "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
                                 "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
                                 }
                                 }
                 
                 
                -trainer.train(training_params=train_params)
                +trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=valid_loader)
                Discard
                  Discard
                  1
                  2
                  3
                  4
                  5
                  6
                  7
                  8
                  9
                  10
                  11
                  12
                  13
                  14
                  15
                  16
                  17
                  18
                  19
                  20
                  21
                  22
                  23
                  24
                  25
                  26
                  27
                  28
                  29
                  30
                  31
                  32
                  33
                  34
                  35
                  36
                  37
                  38
                  39
                  40
                  41
                  42
                  43
                  44
                  45
                  46
                  47
                  48
                  49
                  50
                  51
                  52
                  53
                  54
                  55
                  56
                  1. # TODO: REFACTOR AS YAML FILES RECIPE
                  2. import super_gradients
                  3. import torch
                  4. from super_gradients.training.datasets import PascalAUG2012SegmentationDataSetInterface
                  5. from super_gradients.training import Trainer, MultiGPUMode
                  6. from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
                  7. from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
                  8. super_gradients.init_trainer()
                  9. pascal_aug_dataset_params = {"batch_size": 16,
                  10. "test_batch_size": 16,
                  11. "dataset_dir": "/data/pascal_voc_2012/VOCaug/dataset/",
                  12. "s3_link": None,
                  13. "img_size": 512,
                  14. "train_loader_drop_last": True,
                  15. }
                  16. shelfnet_lw_pascal_aug_training_params = {"max_epochs": 250, "initial_lr": 1e-2, "loss": "shelfnet_ohem_loss",
                  17. "optimizer": "SGD", "mixed_precision": False, "lr_mode": "poly",
                  18. "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4,
                  19. "nesterov": False},
                  20. "load_opt_params": False, "train_metrics_list": [PixelAccuracy(), IoU(21)],
                  21. "valid_metrics_list": [PixelAccuracy(), IoU(21)],
                  22. "loss_logging_items_names": ["Loss1/4", "Loss1/8", "Loss1/16", "Loss"],
                  23. "metric_to_watch": "IoU",
                  24. "greater_metric_to_watch_is_better": True}
                  25. shelfnet_lw_arch_params = {"num_classes": 21, "strict_load": StrictLoad.ON,
                  26. "multi_gpu_mode": MultiGPUMode.OFF}
                  27. checkpoint_params = {"load_checkpoint": True, "load_weights_only": True,
                  28. "load_backbone": True, "source_ckpt_folder_name": 'resnet_backbones'}
                  29. if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                  30. data_loader_num_workers = 16
                  31. shelfnet_lw_pascal_aug_training_params["initial_lr"] = shelfnet_lw_pascal_aug_training_params["initial_lr"] / 2.
                  32. else:
                  33. # SINGLE GPU TRAINING
                  34. data_loader_num_workers = 8
                  35. # SET THE *LIGHT-WEIGHT* SHELFNET ARCHITECTURE SIZE (UN-COMMENT TO TRAIN)
                  36. model_size_str = '34'
                  37. # model_size_str = '18'
                  38. # BUILD THE LIGHT-WEIGHT SHELFNET ARCHITECTURE FOR TRAINING
                  39. experiment_name_prefix = 'shelfnet_lw_'
                  40. experiment_name_dataset_suffix = '_pascal_aug_encoding_dataset_train_250_epochs_no_batchnorm_decoder'
                  41. experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix
                  42. trainer = Trainer(experiment_name, model_checkpoints_location='local', multi_gpu=True,
                  43. ckpt_name='resnet' + model_size_str + '.pth')
                  44. pascal_aug_datasaet_interface = PascalAUG2012SegmentationDataSetInterface(
                  45. dataset_params=pascal_aug_dataset_params,
                  46. cache_labels=False)
                  47. trainer.connect_dataset_interface(pascal_aug_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
                  48. trainer.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params, checkpoint_params=checkpoint_params)
                  49. print('Training ShelfNet-LW model: ' + experiment_name)
                  50. trainer.train(training_params=shelfnet_lw_pascal_aug_training_params)
                  Discard
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    1. """
                    2. This file is used to define the Dataset used for the Training.
                    3. """
                    4. import torchvision.datasets as datasets
                    5. import torchvision.transforms as transforms
                    6. from super_gradients.training import utils as core_utils
                    7. from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
                    8. class UserDataset(DatasetInterface):
                    9. """
                    10. The user's dataset inherits from SuperGradient's DatasetInterface and must
                    11. contain a trainset and testset from which the the data will be loaded using.
                    12. All augmentations, resizing and parsing must be done in this class.
                    13. - Augmentations are defined below and will be carried out in the order they are given.
                    14. super_gradients provides additional dataset reading tools such as ListDataset given a list of files
                    15. corresponding to the images and labels.
                    16. """
                    17. def __init__(self, name="cifar10", dataset_params={}):
                    18. super(UserDataset, self).__init__(dataset_params)
                    19. self.dataset_name = name
                    20. self.lib_dataset_params = {'mean': (0.4914, 0.4822, 0.4465), 'std': (0.2023, 0.1994, 0.2010)}
                    21. crop_size = core_utils.get_param(self.dataset_params, 'crop_size', default_val=32)
                    22. transform_train = transforms.Compose([
                    23. transforms.RandomCrop(crop_size, padding=4),
                    24. transforms.RandomHorizontalFlip(),
                    25. transforms.ToTensor(),
                    26. transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
                    27. ])
                    28. transform_test = transforms.Compose([
                    29. transforms.ToTensor(),
                    30. transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
                    31. ])
                    32. self.trainset = datasets.CIFAR10(root=self.dataset_params.dataset_dir, train=True, download=True,
                    33. transform=transform_train)
                    34. self.testset = datasets.CIFAR10(root=self.dataset_params.dataset_dir, train=False, download=True,
                    35. transform=transform_test)
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    1. """
                    2. The loss must be of torch.nn.modules.loss._Loss class.
                    3. For commonly used losses, import from deci.core.ADNN.losses
                    4. -IMPORTANT: forward(...) should return (loss, loss_items) where loss is the tensor used for backprop (i.e what your
                    5. original loss function returns), and loss_items should be a tensor of shape (n_items), of values computed during
                    6. the forward pass which we desire to log over the entire epoch. For example- the loss itself should always be logged.
7. Another example is a scenario where the computed loss is the sum of a few components we would like to log - these
                    8. entries in loss_items).
                    9. -When training, set the "loss_logging_items_names" parameter in train_params to be a list of strings, of length
                    10. n_items who's ith element is the name of the ith entry in loss_items. Then each item will be logged, rendered on
                    11. tensorboard and "watched" (i.e saving model checkpoints according to it).
                    12. -Since running logs will save the loss_items in some internal state, it is recommended that loss_items are detached
                    13. from their computational graph for memory efficiency.
                    14. """
                    15. import torch.nn as nn
                    16. from super_gradients.training.losses.label_smoothing_cross_entropy_loss import cross_entropy
                    17. class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
                    18. """
                    19. LabelSmoothingCrossEntropyLoss - POC loss class, uses SuperGradient's cross entropy which support distribution as targets.
                    20. """
                    21. def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None,
                    22. from_logits=True):
                    23. super(LabelSmoothingCrossEntropyLoss, self).__init__(weight=weight,
                    24. ignore_index=ignore_index, reduction=reduction)
                    25. self.smooth_eps = smooth_eps
                    26. self.smooth_dist = smooth_dist
                    27. self.from_logits = from_logits
                    28. def forward(self, input, target, smooth_dist=None):
                    29. if smooth_dist is None:
                    30. smooth_dist = self.smooth_dist
                    31. loss = cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
                    32. reduction=self.reduction, smooth_eps=self.smooth_eps,
                    33. smooth_dist=smooth_dist, from_logits=self.from_logits)
                    34. loss_items = loss.detach().unsqueeze(0)
                    35. return loss, loss_items
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    1. """
                    2. This file is used to define the Metrics used for training.
                    3. The metrics object must be of torchmetrics.Metric type. For more information on how to use torchmetric.Metric objects and
                    4. implement your own metrics see https://torchmetrics.readthedocs.io/en/latest/pages/overview.html
                    5. """
                    6. import torchmetrics
                    7. import torch
                    8. class Accuracy(torchmetrics.Accuracy):
                    9. def __init__(self, dist_sync_on_step=False):
                    10. super().__init__(dist_sync_on_step=dist_sync_on_step, top_k=1)
                    11. def update(self, preds: torch.Tensor, target: torch.Tensor):
                    12. super().update(preds=preds.softmax(1), target=target)
                    13. class Top5(torchmetrics.Accuracy):
                    14. def __init__(self, dist_sync_on_step=False):
                    15. super().__init__(dist_sync_on_step=dist_sync_on_step, top_k=5)
                    16. def update(self, preds: torch.Tensor, target: torch.Tensor):
                    17. super().update(preds=preds.softmax(1), target=target)
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    78
                    79
                    80
                    81
                    82
                    83
                    84
                    85
                    86
                    87
                    88
                    89
                    90
                    91
                    92
                    93
                    94
                    95
                    96
                    97
                    98
                    99
                    100
                    101
                    102
                    103
                    104
                    105
                    106
                    107
                    108
                    109
                    110
                    111
                    112
                    113
                    114
                    115
                    116
                    117
                    118
                    119
                    120
                    121
                    122
                    123
                    124
                    125
                    126
                    127
                    128
                    129
                    130
                    131
                    132
                    133
                    134
                    135
                    136
                    137
                    138
                    139
                    140
                    141
                    142
                    143
                    144
                    145
                    146
                    147
                    148
                    149
                    150
                    151
                    152
                    153
                    154
                    1. """
                    2. This file is used to define the model used for training. For example, in this template, we define ResNet50.
                    3. One may use existing models from torchvision as well (e.g., torchvision.models.resnet50)
                    4. """
                    5. import torch.nn as nn
                    6. import torch.nn.functional as F
                    7. from collections import OrderedDict
                    8. class BasicBlock(nn.Module):
                    9. expansion = 1
                    10. def __init__(self, in_planes, planes, stride=1):
                    11. super(BasicBlock, self).__init__()
                    12. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
                    13. self.bn1 = nn.BatchNorm2d(planes)
                    14. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
                    15. self.bn2 = nn.BatchNorm2d(planes)
                    16. self.shortcut = nn.Sequential()
                    17. if stride != 1 or in_planes != self.expansion * planes:
                    18. self.shortcut = nn.Sequential(
                    19. nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    20. nn.BatchNorm2d(self.expansion * planes)
                    21. )
                    22. def forward(self, x):
                    23. out = F.relu(self.bn1(self.conv1(x)))
                    24. out = self.bn2(self.conv2(out))
                    25. out += self.shortcut(x)
                    26. out = F.relu(out)
                    27. return out
                    28. class Bottleneck(nn.Module):
                    29. expansion = 4
                    30. def __init__(self, in_planes, planes, stride=1):
                    31. super(Bottleneck, self).__init__()
                    32. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
                    33. self.bn1 = nn.BatchNorm2d(planes)
                    34. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
                    35. self.bn2 = nn.BatchNorm2d(planes)
                    36. self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
                    37. self.bn3 = nn.BatchNorm2d(self.expansion * planes)
                    38. self.shortcut = nn.Sequential()
                    39. if stride != 1 or in_planes != self.expansion * planes:
                    40. self.shortcut = nn.Sequential(
                    41. nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    42. nn.BatchNorm2d(self.expansion * planes)
                    43. )
                    44. def forward(self, x):
                    45. out = F.relu(self.bn1(self.conv1(x)))
                    46. out = F.relu(self.bn2(self.conv2(out)))
                    47. out = self.bn3(self.conv3(out))
                    48. out += self.shortcut(x)
                    49. out = F.relu(out)
                    50. return out
                    51. def width_multiplier(original, factor):
                    52. return int(original * factor)
                    53. class ResNet(nn.Module):
                    54. def __init__(self, block, num_blocks: list, num_classes: int = 10, width_mult: float = 1,
                    55. input_batchnorm: bool = False, backbone_mode: bool = False):
                    56. super(ResNet, self).__init__()
                    57. self.backbone_mode = backbone_mode
                    58. self.structure = [num_blocks, width_mult]
                    59. self.in_planes = width_multiplier(64, width_mult)
                    60. self.input_batchnorm = input_batchnorm
                    61. if self.input_batchnorm:
                    62. self.bn0 = nn.BatchNorm2d(3)
                    63. self.conv1 = nn.Conv2d(3, width_multiplier(64, width_mult), kernel_size=7, stride=2, padding=3, bias=False)
                    64. self.bn1 = nn.BatchNorm2d(width_multiplier(64, width_mult))
                    65. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                    66. self.layer1 = self._make_layer(block, width_multiplier(64, width_mult), num_blocks[0], stride=1)
                    67. self.layer2 = self._make_layer(block, width_multiplier(128, width_mult), num_blocks[1], stride=2)
                    68. self.layer3 = self._make_layer(block, width_multiplier(256, width_mult), num_blocks[2], stride=2)
                    69. self.layer4 = self._make_layer(block, width_multiplier(512, width_mult), num_blocks[3], stride=2)
                    70. if not self.backbone_mode:
                    71. # IF RESNET IS IN BACK_BONE MODE WE DON'T NEED THE FINAL CLASSIFIER LAYERS, BUT ONLY THE NET BLOCK STRUCTURE
                    72. self.linear = nn.Linear(width_multiplier(512, width_mult) * block.expansion, num_classes)
                    73. self.avgpool = nn.AdaptiveAvgPool2d(1)
                    74. def _make_layer(self, block, planes, num_blocks, stride):
                    75. strides = [stride] + [1] * (num_blocks - 1)
                    76. layers = []
                    77. if num_blocks == 0:
                    78. # When the number of blocks is zero but spatial dimension and/or number of filters about to change we put 1
                    79. # 3X3 conv layer to make this change to the new dimensions.
                    80. if stride != 1 or self.in_planes != planes:
                    81. layers.append(nn.Sequential(
                    82. nn.Conv2d(self.in_planes, planes, kernel_size=3, stride=stride, bias=False, padding=1),
                    83. nn.BatchNorm2d(planes))
                    84. )
                    85. self.in_planes = planes
                    86. else:
                    87. for stride in strides:
                    88. layers.append(block(self.in_planes, planes, stride))
                    89. self.in_planes = planes * block.expansion
                    90. return nn.Sequential(*layers)
                    91. def forward(self, x):
                    92. if self.input_batchnorm:
                    93. x = self.bn0(x)
                    94. out = F.relu(self.bn1(self.conv1(x)))
                    95. out = self.maxpool(out)
                    96. out = self.layer1(out)
                    97. out = self.layer2(out)
                    98. out = self.layer3(out)
                    99. out = self.layer4(out)
                    100. if not self.backbone_mode:
                    101. # IF RESNET IS *NOT* IN BACK_BONE MODE WE NEED THE FINAL CLASSIFIER LAYERS OUTPUTS
                    102. out = self.avgpool(out)
                    103. out = out.squeeze(dim=2).squeeze(dim=2)
                    104. out = self.linear(out)
                    105. return out
                    106. def load_state_dict(self, state_dict, strict=True):
                    107. """
                    108. load_state_dict - Overloads the base method and calls it to load a modified dict for usage as a backbone
                    109. :param state_dict: The state_dict to load
                    110. :param strict: strict loading (see super() docs)
                    111. """
                    112. pretrained_model_weights_dict = state_dict.copy()
                    113. if self.backbone_mode:
                    114. # FIRST LET'S POP THE LAST TWO LAYERS - NO NEED TO LOAD THEIR VALUES SINCE THEY ARE IRRELEVANT AS A BACKBONE
                    115. pretrained_model_weights_dict.popitem()
                    116. pretrained_model_weights_dict.popitem()
                    117. pretrained_backbone_weights_dict = OrderedDict()
                    118. for layer_name, weights in pretrained_model_weights_dict.items():
                    119. # GET THE LAYER NAME WITHOUT THE 'module.' PREFIX
                    120. name_without_module_prefix = layer_name.split('module.')[1]
                    121. # MAKE SURE THESE ARE NOT THE FINAL LAYERS
                    122. pretrained_backbone_weights_dict[name_without_module_prefix] = weights
                    123. # RETURNING THE UNMODIFIED/MODIFIED STATE DICT DEPENDING ON THE backbone_mode VALUE
                    124. super().load_state_dict(pretrained_backbone_weights_dict, strict)
                    125. else:
                    126. super().load_state_dict(pretrained_model_weights_dict, strict)
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    1. from super_gradients.training import Trainer
                    2. from super_gradients.training import MultiGPUMode
                    3. from dataset import UserDataset
                    4. from model import ResNet, BasicBlock
                    5. from loss import LabelSmoothingCrossEntropyLoss
                    6. from metrics import Accuracy, Top5
                    7. def main():
                    8. # ------------------ Loading The Model From Model.py----------------
                    9. arch_params = {'num_classes': 10}
                    10. model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=arch_params['num_classes'])
                    11. trainer = Trainer('client_model_training',
                    12. model_checkpoints_location='local',
                    13. multi_gpu=MultiGPUMode.OFF)
                    14. # if a torch.nn.Module is provided when building the model, the model will be integrated into deci model class
                    15. trainer.build_model(model, arch_params=arch_params)
                    16. # ------------------ Loading The Dataset From Dataset.py----------------
                    17. dataset_params = {"batch_size": 256}
                    18. dataset = UserDataset(dataset_params)
                    19. trainer.connect_dataset_interface(dataset)
                    20. # ------------------ Loading The Loss From Loss.py -----------------
                    21. loss = LabelSmoothingCrossEntropyLoss()
                    22. # ------------------ Defining the metrics we wish to log -----------------
                    23. train_metrics_list = [Accuracy(), Top5()]
                    24. valid_metrics_list = [Accuracy(), Top5()]
                    25. # ------------------ Training -----------------
                    26. train_params = {"max_epochs": 250,
                    27. "lr_updates": [100, 150, 200],
                    28. "lr_decay_factor": 0.1,
                    29. "lr_mode": "step",
                    30. "lr_warmup_epochs": 0,
                    31. "initial_lr": 0.1,
                    32. "loss": loss,
                    33. "criterion_params": {},
                    34. "optimizer": "SGD",
                    35. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                    36. "launch_tensorboard": False,
                    37. "train_metrics_list": train_metrics_list,
                    38. "valid_metrics_list": valid_metrics_list,
                    39. "loss_logging_items_names": ["Loss"],
                    40. "metric_to_watch": "Accuracy",
                    41. "greater_metric_to_watch_is_better": True}
                    42. trainer.train(train_params)
                    43. if __name__ == '__main__':
                    44. main()
                    Discard
                    @@ -12,9 +12,8 @@ defaults:
                       - arch_params: resnet18_cifar_arch_params
                       - arch_params: resnet18_cifar_arch_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    -dataset_interface:
                    -  cifar_10:
                    -    dataset_params: ${dataset_params}
                    +train_dataloader: cifar10_train
                    +val_dataloader: cifar10_val
                     
                     
                     data_loader_num_workers: 8
                     data_loader_num_workers: 8
                     
                     
                    @@ -27,3 +26,5 @@ model_checkpoints_location: local
                     ckpt_root_dir:
                     ckpt_root_dir:
                     
                     
                     architecture: resnet18_cifar
                     architecture: resnet18_cifar
                    +
                    +experiment_name: resnet18_cifar
                    Discard
                    @@ -38,57 +38,12 @@
                     
                     
                     defaults:
                     defaults:
                       - training_hyperparams: cityscapes_default_train_params
                       - training_hyperparams: cityscapes_default_train_params
                    -  #  - dataset_params: cityscapes_ddrnet_dataset_params # TODO: uncomment after DatasetInterface refactor
                    +  - dataset_params: cityscapes_ddrnet_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - _self_
                       - _self_
                     
                     
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -dataset_params:
                    -  _convert_: all
                    -  batch_size: 6
                    -  val_batch_size: 6
                    -  dataset_dir: /data/cityscapes
                    -  crop_size: [ 1024, 1024 ]
                    -  img_size: 1024
                    -  train_loader_drop_last: True
                    -  color_jitter: 0.5
                    -  random_scales: [ 0.5, 2. ]
                    -  eval_scale: 1.
                    -  cityscapes_ignored_label: 19
                    -
                    -  image_mask_transforms_aug:
                    -    Compose:
                    -      transforms:
                    -        - ColorJitterSeg:
                    -            brightness: ${dataset_params.color_jitter}
                    -            contrast: ${dataset_params.color_jitter}
                    -            saturation: ${dataset_params.color_jitter}
                    -
                    -        - RandomFlipSeg
                    -
                    -        - RandomRescaleSeg:
                    -            scales: ${dataset_params.random_scales}
                    -
                    -        - PadShortToCropSizeSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            fill_mask: ${dataset_params.cityscapes_ignored_label}
                    -
                    -        - CropImageAndMaskSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            mode: random
                    -
                    -  image_mask_transforms:
                    -    Compose:
                    -      transforms:
                    -        - RescaleSeg:
                    -            scale_factor: ${dataset_params.eval_scale}
                    -
                    -dataset_interface:
                    -  cityscapes:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 8
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    +train_dataloader: cityscapes_train
                    +val_dataloader: cityscapes_val
                     
                     
                     architecture: ddrnet_23
                     architecture: ddrnet_23
                     
                     
                    @@ -107,13 +62,13 @@ training_hyperparams:
                           edge_kernel: 5
                           edge_kernel: 5
                       loss_logging_items_names: [ main_loss, aux_loss1, loss ]
                       loss_logging_items_names: [ main_loss, aux_loss1, loss ]
                     
                     
                    -
                     arch_params:
                     arch_params:
                       num_classes: 19
                       num_classes: 19
                       aux_head: True
                       aux_head: True
                       sync_bn: True
                       sync_bn: True
                     
                     
                     
                     
                    +
                     load_checkpoint: False
                     load_checkpoint: False
                     checkpoint_params:
                     checkpoint_params:
                       load_checkpoint: ${load_checkpoint}
                       load_checkpoint: ${load_checkpoint}
                    @@ -121,7 +76,6 @@ checkpoint_params:
                       load_backbone: True
                       load_backbone: True
                       strict_load: no_key_matching
                       strict_load: no_key_matching
                     
                     
                    -
                     experiment_name: ${architecture}_cityscapes
                     experiment_name: ${architecture}_cityscapes
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                    Discard
                    @@ -27,58 +27,12 @@
                     
                     
                     defaults:
                     defaults:
                       - training_hyperparams: default_train_params
                       - training_hyperparams: default_train_params
                    -  #  - dataset_params: cityscapes_regseg48_dataset_params # TODO: uncomment after DatasetInterface refactor
                    +  - dataset_params: cityscapes_regseg48_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - _self_
                       - _self_
                     
                     
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -dataset_params:
                    -  _convert_: all
                    -  batch_size: 4
                    -  val_batch_size: 4
                    -  dataset_dir: /data/cityscapes
                    -  crop_size: 1024
                    -  img_size: 1024
                    -  train_loader_drop_last: True
                    -  color_jitter: 0.1
                    -  random_scales: [ 0.4, 1.6 ]
                    -  cityscapes_ignored_label: 19
                    -
                    -  image_mask_transforms_aug:
                    -    Compose:
                    -      transforms:
                    -        - ColorJitterSeg:
                    -            brightness: ${dataset_params.color_jitter}
                    -            contrast: ${dataset_params.color_jitter}
                    -            saturation: ${dataset_params.color_jitter}
                    -
                    -        - RandomFlipSeg
                    -
                    -        - RandomRescaleSeg:
                    -            scales: ${dataset_params.random_scales}
                    -
                    -        - PadShortToCropSizeSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            fill_image:
                    -              - ${dataset_params.cityscapes_ignored_label}
                    -              - 0
                    -              - 0
                    -            fill_mask: ${dataset_params.cityscapes_ignored_label}
                    -
                    -        - CropImageAndMaskSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            mode: random
                    -
                    -  image_mask_transforms:
                    -    Compose:
                    -      transforms: [ ]
                    -
                    -dataset_interface:
                    -  cityscapes:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 8
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    +train_dataloader: cityscapes_train
                    +val_dataloader: cityscapes_val
                     
                     
                     cityscapes_ignored_label: 19    # convenience parameter since it is used in many places in the YAML
                     cityscapes_ignored_label: 19    # convenience parameter since it is used in many places in the YAML
                     
                     
                    Discard
                    @@ -5,39 +5,8 @@ defaults:
                       - dataset_params: cityscapes_dataset_params
                       - dataset_params: cityscapes_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    -
                    -dataset_params:
                    -  color_jitter: 0.5
                    -  image_mask_transforms_aug:
                    -    Compose:
                    -      transforms:
                    -        - ColorJitterSeg:
                    -            brightness: ${dataset_params.color_jitter}
                    -            contrast: ${dataset_params.color_jitter}
                    -            saturation: ${dataset_params.color_jitter}
                    -
                    -        - RandomFlipSeg
                    -
                    -        - RandomRescaleSeg:
                    -            scales: ${dataset_params.random_scales}
                    -
                    -        - PadShortToCropSizeSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            fill_mask: ${dataset_params.cityscapes_ignored_label}
                    -
                    -        - CropImageAndMaskSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            mode: random
                    -
                    -  image_mask_transforms:
                    -    Compose:
                    -      transforms:
                    -        - RescaleSeg:
                    -            scale_factor: ${dataset_params.eval_scale}
                    -
                    -dataset_interface:
                    -  cityscapes:
                    -    dataset_params: ${dataset_params}
                    +train_dataloader: cityscapes_train
                    +val_dataloader: cityscapes_val
                     
                     
                     data_loader_num_workers: 10
                     data_loader_num_workers: 10
                     
                     
                    Discard
                    @@ -39,62 +39,15 @@
                     
                     
                     defaults:
                     defaults:
                       - training_hyperparams: cityscapes_default_train_params
                       - training_hyperparams: cityscapes_default_train_params
                    -  #  - dataset_params: cityscapes_stdc_seg50_dataset_params # TODO: uncomment after DatasetInterface refactor
                     +  - dataset_params: cityscapes_stdc_seg50_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - _self_
                       - _self_
                     
                     
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -dataset_params:
                    -  _convert_: all
                    -  batch_size: 16
                    -  val_batch_size: 16
                    -  dataset_dir: /data/cityscapes
                    -  crop_size: [ 1024, 512 ]
                    -  img_size: 1024
                    -  train_loader_drop_last: True
                    -  color_jitter: 0.5
                    -  random_scales: [ 0.125, 1.5 ]
                    -  eval_scale: 0.5
                    -  cityscapes_ignored_label: 19
                    -
                    -  image_mask_transforms_aug:
                    -    Compose:
                    -      transforms:
                    -        - ColorJitterSeg:
                    -            brightness: ${dataset_params.color_jitter}
                    -            contrast: ${dataset_params.color_jitter}
                    -            saturation: ${dataset_params.color_jitter}
                    -
                    -        - RandomFlipSeg
                    -
                    -        - RandomRescaleSeg:
                    -            scales: ${dataset_params.random_scales}
                    -
                    -        - PadShortToCropSizeSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            fill_mask: ${dataset_params.cityscapes_ignored_label}
                    -
                    -        - CropImageAndMaskSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            mode: random
                    -
                    -  image_mask_transforms:
                    -    Compose:
                    -      transforms:
                    -        - RescaleSeg:
                    -            scale_factor: ${dataset_params.eval_scale}
                    -
                    -dataset_interface:
                    -  cityscapes:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 10
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -
                    +train_dataloader: cityscapes_train
                    +val_dataloader: cityscapes_val
                     
                     
                     architecture: stdc1_seg
                     architecture: stdc1_seg
                     
                     
                    -
                     arch_params:
                     arch_params:
                       num_classes: 19
                       num_classes: 19
                       use_aux_heads: True
                       use_aux_heads: True
                    Discard
                    @@ -42,57 +42,12 @@
                     
                     
                     defaults:
                     defaults:
                       - training_hyperparams: cityscapes_default_train_params
                       - training_hyperparams: cityscapes_default_train_params
                    -  #  - dataset_params: cityscapes_stdc_seg75_dataset_params # TODO: uncomment after DatasetInterface refactor
                    +  - dataset_params: cityscapes_stdc_seg75_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - _self_
                       - _self_
                     
                     
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -dataset_params:
                    -  _convert_: all
                    -  batch_size: 4
                    -  val_batch_size: 4
                    -  dataset_dir: /data/cityscapes
                    -  crop_size: [ 1536, 768 ]
                    -  img_size: 1024
                    -  train_loader_drop_last: True
                    -  color_jitter: 0.5
                    -  random_scales: [ 0.125, 1.5 ]
                    -  eval_scale: 0.75
                    -  cityscapes_ignored_label: 19
                    -
                    -  image_mask_transforms_aug:
                    -    Compose:
                    -      transforms:
                    -        - ColorJitterSeg:
                    -            brightness: ${dataset_params.color_jitter}
                    -            contrast: ${dataset_params.color_jitter}
                    -            saturation: ${dataset_params.color_jitter}
                    -
                    -        - RandomFlipSeg
                    -
                    -        - RandomRescaleSeg:
                    -            scales: ${dataset_params.random_scales}
                    -
                    -        - PadShortToCropSizeSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            fill_mask: ${dataset_params.cityscapes_ignored_label}
                    -
                    -        - CropImageAndMaskSeg:
                    -            crop_size: ${dataset_params.crop_size}
                    -            mode: random
                    -
                    -  image_mask_transforms:
                    -    Compose:
                    -      transforms:
                    -        - RescaleSeg:
                    -            scale_factor: ${dataset_params.eval_scale}
                    -
                    -dataset_interface:
                    -  cityscapes:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 10
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    +train_dataloader: cityscapes_train
                    +val_dataloader: cityscapes_val
                     
                     
                     
                     
                     architecture: stdc1_seg
                     architecture: stdc1_seg
                    Discard
                    @@ -28,6 +28,9 @@ defaults:
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - anchors: ssd_anchors
                       - anchors: ssd_anchors
                     
                     
                    +train_dataloader: coco2017_train
                    +val_dataloader: coco2017_val
                    +
                     architecture: ssd_lite_mobilenet_v2
                     architecture: ssd_lite_mobilenet_v2
                     
                     
                     data_loader_num_workers: 8
                     data_loader_num_workers: 8
                    @@ -43,9 +46,6 @@ arch_params:
                       num_classes: 80
                       num_classes: 80
                       anchors: ${dboxes}
                       anchors: ${dboxes}
                     
                     
                    -dataset_interface:
                    -  coco2017_detection:
                    -    dataset_params: ${dataset_params}
                     resume: False
                     resume: False
                     training_hyperparams:
                     training_hyperparams:
                       resume: ${resume}
                       resume: ${resume}
                    Discard
                    @@ -28,12 +28,8 @@ defaults:
                       - arch_params: yolox_s_arch_params
                       - arch_params: yolox_s_arch_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    -dataset_interface:
                    -  coco2017_detection:
                    -    dataset_params: ${dataset_params}
                    -
                    -
                    -data_loader_num_workers: 8
                    +train_dataloader: coco2017_train
                    +val_dataloader: coco2017_val
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     
                     
                    @@ -46,7 +42,7 @@ architecture: yolox_s
                     
                     
                     multi_gpu: DDP
                     multi_gpu: DDP
                     
                     
                    -experiment_suffix: res${dataset_params.train_image_size}
                    +experiment_suffix: res${dataset_params.train_dataset_params.input_dim}
                     experiment_name: ${architecture}_coco2017_${experiment_suffix}
                     experiment_name: ${architecture}_coco2017_${experiment_suffix}
                     
                     
                     ckpt_root_dir:
                     ckpt_root_dir:
                    Discard
                    @@ -16,31 +16,13 @@
                     
                     
                     defaults:
                     defaults:
                       - training_hyperparams: coco_segmentation_shelfnet_lw_train_params
                       - training_hyperparams: coco_segmentation_shelfnet_lw_train_params
                    -  #  - dataset_params: coco_segmentation_dataset_params # TODO: uncomment after DatasetInterface refactor
                    +  - dataset_params: coco_segmentation_dataset_params
                       - arch_params: shelfnet34_lw_arch_params
                       - arch_params: shelfnet34_lw_arch_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                       - _self_
                       - _self_
                     
                     
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -dataset_params:
                    -  batch_size: 8
                    -  val_batch_size: 24
                    -  dataset_dir: "/data/coco/"
                    -  img_size: 608
                    -  crop_size: 512
                    -  train_loader_drop_last: True
                    -
                    -sub_classes:
                    -  _target_: super_gradients.training.utils.segmentation_utils.coco_sub_classes_inclusion_tuples_list
                    -
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.CoCoSegmentationDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  dataset_classes_inclusion_tuples_list: ${sub_classes}
                    -
                    -data_loader_num_workers: 8
                    -# ------------------------------------- legacy dataset params ------------------------------------- #
                    -
                    +train_dataloader: coco_segmentation_train
                    +val_dataloader: coco_segmentation_val
                     
                     
                     checkpoint_params:
                     checkpoint_params:
                       strict_load: True
                       strict_load: True
                    Discard
                    @@ -1,8 +1,3 @@
                    -batch_size: 256 # batch size for trainset
                    -val_batch_size: 512 # batch size for valset in DatasetInterface
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/cifar100
                       root: /data/cifar100
                       train: True
                       train: True
                    Discard
                    @@ -4,7 +4,7 @@ val_batch_size: 512 # batch size for valset in DatasetInterface
                     # TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                     # TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                     
                     
                     train_dataset_params:
                     train_dataset_params:
                    -  root: /data/cifar10
                    +  root: ./data/cifar10
                       train: True
                       train: True
                       transforms:
                       transforms:
                         - RandomCrop:
                         - RandomCrop:
                    @@ -31,7 +31,7 @@ train_dataloader_params:
                       pin_memory: True
                       pin_memory: True
                     
                     
                     val_dataset_params:
                     val_dataset_params:
                    -  root: /data/cifar10
                    +  root: ./data/cifar10
                       train: False
                       train: False
                       transforms:
                       transforms:
                         - ToTensor
                         - ToTensor
                    Discard
                    @@ -1,81 +1,3 @@
                    -data_dir: /data/coco # root path to coco data
                    -train_subdir: images/train2017 # sub directory path of data_dir containing the train data.
                    -val_subdir: images/val2017 # sub directory path of data_dir containing the validation data.
                    -train_json_file: instances_train2017.json # path to coco train json file, data_dir/annotations/train_json_file.
                    -val_json_file: instances_val2017.json # path to coco validation json file, data_dir/annotations/val_json_file.
                    -
                    -cache_dir: # path to a directory that will be used for caching (with numpy.memmap).
                    -cache_train_images: False
                    -cache_val_images: False
                    -
                    -batch_size: 16 # batch size for trainset
                    -val_batch_size: 64 # batch size for valset
                    -train_image_size: 640
                    -val_image_size: 640
                    -train_input_dim:
                    -  - ${dataset_params.train_image_size}
                    -  - ${dataset_params.train_image_size}
                    -val_input_dim:
                    -  - ${dataset_params.val_image_size}
                    -  - ${dataset_params.val_image_size}
                    -
                    -filter_box_candidates: False
                    -targets_format:
                    -  _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat # targets format
                    -  value: LABEL_CXCYWH
                    -
                    -tight_box_rotation: False
                    -train_transforms:
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionMosaic
                    -    input_dim: ${dataset_params.train_input_dim}
                    -    prob: 1.
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionRandomAffine
                    -    degrees: 10.                  # rotation degrees, randomly sampled from [-degrees, degrees]
                    -    translate: 0.1                # image translation fraction
                    -    scales: [0.1, 2]              # random rescale range (keeps size by padding/cropping) after mosaic transform.
                    -    shear: 2.0                    # shear degrees, randomly sampled from [-degrees, degrees]
                    -    target_size: ${dataset_params.train_input_dim}
                    -    filter_box_candidates: False  # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
                    -    wh_thr: 2                     # edge size threshold when filter_box_candidates = True (pixels)
                    -    area_thr: 0.1                 # threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True
                    -    ar_thr: 20                    # aspect ratio threshold when filter_box_candidates = True
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionMixup
                    -    input_dim: ${dataset_params.train_input_dim}
                    -    mixup_scale: [0.5, 1.5]         # random rescale range for the additional sample in mixup
                    -    prob: 1.0                       # probability to apply per-sample mixup
                    -    flip_prob: 0.5                  # probability to apply horizontal flip
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionHSV
                    -    prob: 1.0                       # probability to apply HSV transform
                    -    hgain: 5                        # HSV transform hue gain (randomly sampled from [-hgain, hgain])
                    -    sgain: 30                       # HSV transform saturation gain (randomly sampled from [-sgain, sgain])
                    -    vgain: 30                       # HSV transform value gain (randomly sampled from [-vgain, vgain])
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionHorizontalFlip
                    -    prob: 0.5                       # probability to apply horizontal flip
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionPaddedRescale
                    -    input_dim: ${dataset_params.train_input_dim}
                    -    max_targets: 120
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionTargetsFormatTransform
                    -    output_format: ${dataset_params.targets_format}
                    -
                    -val_transforms:
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionPaddedRescale
                    -    input_dim: ${dataset_params.val_input_dim}
                    -  - _target_: super_gradients.training.transforms.transforms.DetectionTargetsFormatTransform
                    -    max_targets: 50
                    -    output_format: ${dataset_params.targets_format}
                    -
                    -val_collate_fn: # collate function for valset
                    -  _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
                    -train_collate_fn: # collate function for trainset
                    -  _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
                    -
                    -class_inclusion_list: # If not None,every class not included will be ignored.
                    -train_max_num_samples: # If not None, only specified number of samples will be loaded in train dataset
                    -val_max_num_samples:   # If not None, only specified number of samples will be loaded in test dataset
                    -with_crowd: False     # Whether to return "crowd" labels in validation
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -
                     train_dataset_params:
                     train_dataset_params:
                       data_dir: /data/coco # root path to coco data
                       data_dir: /data/coco # root path to coco data
                       subdir: images/train2017 # sub directory path of data_dir containing the train data.
                       subdir: images/train2017 # sub directory path of data_dir containing the train data.
                    Discard
                    @@ -5,7 +5,6 @@ cache_dir: # path to a directory that will be used for caching (with numpy.memma
                     cache_train_images: False
                     cache_train_images: False
                     cache_val_images: False
                     cache_val_images: False
                     
                     
                    -
                     batch_size: 32
                     batch_size: 32
                     val_batch_size: 16
                     val_batch_size: 16
                     train_image_size: 320
                     train_image_size: 320
                    @@ -126,7 +125,7 @@ val_dataset_params:
                       cache: False
                       cache: False
                       transforms:
                       transforms:
                         - DetectionPaddedRescale:
                         - DetectionPaddedRescale:
                    -      input_dim: ${dataset_params.val_dataset_params.input_dim}
                    +        input_dim: ${dataset_params.val_dataset_params.input_dim}
                         - DetectionTargetsFormatTransform:
                         - DetectionTargetsFormatTransform:
                             max_targets: 50
                             max_targets: 50
                             output_format:
                             output_format:
                    Discard
                    @@ -1,25 +1,6 @@
                    -batch_size: 64 # batch size for trainset in DatasetInterface
                    -val_batch_size: 200 # batch size for valset in DatasetInterface
                    -dataset_dir: /data/Imagenet # path to imagenet directory (local)
                    -traindir: train # dirname inside dataset_dir holding trainset files
                    -valdir: val # dirname inside dataset_dir holding valset files
                    -img_mean: [0.485, 0.456, 0.406] # mean for normalization
                    -img_std: [0.229, 0.224, 0.225] # std for normalization
                    -crop_size: 224 # crop size (size of net's input)
                    -resize_size: 256 # loaded image resize size (appplied first among preprocessing transforms)
                    -color_jitter: 0.0 # color jitter augmentation (applied only to trainset)
                    -imagenet_pca_aug: 0.0 # imagenet pca augmentation (applied only to trainset)
                    -train_interpolation: default # interpolation mode
                    -rand_augment_config_string: # randaugment config string (see super_gradients/training/datasets/auto_augment.py)
                    -random_erase_prob: 0.0 # random erase probability (applied only to trainset)
                    -aug_repeat_count: 0 # amount of repetitions (each repetition of an example is augmented differently) for a trainset example.
                    -
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                      # Base recipe for ImageNet Datasets and Dataloaders.
                      # Base recipe for ImageNet Datasets and Dataloaders.
                    -#img_mean: [0.485, 0.456, 0.406] # mean for normalization
                    -#img_std: [0.229, 0.224, 0.225]  # std  for normalization
                    +img_mean: [0.485, 0.456, 0.406] # mean for normalization
                    +img_std: [0.229, 0.224, 0.225]  # std  for normalization
                     
                     
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                    Discard
                    @@ -1,18 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -batch_size: 64
                    -color_jitter: 0.4
                    -random_erase_prob: 0.2
                    -random_erase_value: random
                    -train_interpolation: random
                    -auto_augment_config_string: rand-m9-mstd0.5
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                       transforms:
                       transforms:
                    Discard
                    @@ -1,19 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -train_loader_drop_last: True
                    -batch_size: 256
                    -val_batch_size: 256
                    -random_erase_prob: 0.2
                    -random_erase_value: random
                    -train_interpolation: random
                    -config_string: rand-m9-mstd0.5
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                       transforms:
                       transforms:
                    Discard
                    @@ -1,13 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -batch_size: 128
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataloader_params:
                     train_dataloader_params:
                       batch_size: 128
                       batch_size: 128
                       num_workers: 16
                       num_workers: 16
                    Discard
                    @@ -1,18 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -batch_size: 256
                    -color_jitter: 0.4
                    -random_erase_prob: 0.2
                    -random_erase_value: random
                    -train_interpolation: random
                    -auto_augment_config_string: rand-m9-mstd0.5
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                       transforms:
                       transforms:
                    Discard
                    @@ -1,23 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -resize_size: 236
                    -random_erase_prob: 0
                    -random_erase_value: random
                    -train_interpolation: random
                    -config_string: rand-m7-mstd0.5
                    -cutmix: True
                    -cutmix_params:
                    -  mixup_alpha: 0.2
                    -  cutmix_alpha: 1.0
                    -  label_smoothing: 0.1
                    -
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                       transforms:
                       transforms:
                    Discard
                    @@ -1,31 +1,12 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -batch_size: 192
                    -val_batch_size: 256
                    -random_erase_prob: 0
                    -random_erase_value: random
                    -train_interpolation: random
                    -config_string: rand-m7-mstd0.5
                    -cutmix: True
                    -cutmix_params:
                    -  mixup_alpha: 0.2
                    -  cutmix_alpha: 1.0
                    -  label_smoothing: 0.1
                    -aug_repeat_count: 3
                    -
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -#
                    -#train_transform_args:
                    -#  interpolation: random
                    -#  color_jitter: [0.4, 0.4, 0.4]
                    -#  random_erase_prob: 0.
                    -#  random_erase_value: random
                    -#  auto_augment_config_string: rand-m7-mstd0.5
                    +train_transform_args:
                    +  interpolation: random
                    +  color_jitter: [0.4, 0.4, 0.4]
                    +  random_erase_prob: 0.
                    +  random_erase_value: random
                    +  auto_augment_config_string: rand-m7-mstd0.5
                     
                     
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                    Discard
                    @@ -1,26 +1,6 @@
                     defaults:
                     defaults:
                       - imagenet_dataset_params
                       - imagenet_dataset_params
                     
                     
                    -resize_size: 249
                    -batch_size: 64
                    -random_erase_prob: 0
                    -random_erase_value: random
                    -train_interpolation: random
                    -config_string: rand-m7-mstd0.5
                    -cutmix: True
                    -cutmix_params:
                    -  mixup_alpha: 0.2
                    -  cutmix_alpha: 1.0
                    -  label_smoothing: 0.1
                    -img_mean: [0.5, 0.5, 0.5]
                    -img_std: [0.5, 0.5, 0.5]
                    -
                    -
                    -# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
                    -# TODO: UNCOMMENT BELOW WHEN ABOVE IS REMOVED
                    -#defaults:
                    -#  - imagenet_dataset_params
                    -
                     train_dataset_params:
                     train_dataset_params:
                       root: /data/Imagenet/train
                       root: /data/Imagenet/train
                       transforms:
                       transforms:
                    Discard
                    @@ -17,12 +17,9 @@ defaults:
                     arch_params:
                     arch_params:
                       num_classes: 1000
                       num_classes: 1000
                     
                     
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.ImageNetDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  data_dir: /data/Imagenet
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                    -data_loader_num_workers: 8
                     
                     
                     resume: False
                     resume: False
                     training_hyperparams:
                     training_hyperparams:
                    Discard
                    @@ -15,14 +15,13 @@ defaults:
                       - arch_params: mobilenet_v2_arch_params
                       - arch_params: mobilenet_v2_arch_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                    +
                     arch_params:
                     arch_params:
                       num_classes: 1000
                       num_classes: 1000
                       dropout: 0.2
                       dropout: 0.2
                     
                     
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.ImageNetDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  data_dir: /data/Imagenet
                     
                     
                     data_loader_num_workers: 8
                     data_loader_num_workers: 8
                     
                     
                    Discard
                    @@ -5,12 +5,8 @@ defaults:
                       - dataset_params: imagenet_mobilenetv3_dataset_params
                       - dataset_params: imagenet_mobilenetv3_dataset_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.ImageNetDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  data_dir: /data/Imagenet
                    -
                    -data_loader_num_workers: 16
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     resume: False
                     resume: False
                    Discard
                    @@ -33,12 +33,8 @@ arch_params:
                       dropout_prob: 0.5
                       dropout_prob: 0.5
                       droppath_prob: 0.0
                       droppath_prob: 0.0
                     
                     
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.ImageNetDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  data_dir: /data/Imagenet
                    -
                    -data_loader_num_workers: 8
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     load_checkpoint: False
                     load_checkpoint: False
                    Discard
                    @@ -21,12 +21,8 @@ arch_params:
                       num_classes: 1000
                       num_classes: 1000
                       build_residual_branches: True
                       build_residual_branches: True
                     
                     
                    -dataset_interface:
                    -  _target_: super_gradients.training.datasets.dataset_interfaces.dataset_interface.ImageNetDatasetInterface
                    -  dataset_params: ${dataset_params}
                    -  data_dir: /data/Imagenet
                    -
                    -data_loader_num_workers: 8
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                    Discard
                    @@ -21,11 +21,8 @@ defaults:
                     arch_params:
                     arch_params:
                       droppath_prob: 0.05
                       droppath_prob: 0.05
                     
                     
                    -dataset_interface:
                    -  imagenet:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 8
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     resume: False
                     resume: False
                    Discard
                    @@ -18,6 +18,8 @@ defaults:
                       - arch_params: default_arch_params
                       - arch_params: default_arch_params
                       - checkpoint_params: default_checkpoint_params
                       - checkpoint_params: default_checkpoint_params
                     
                     
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     resume: False
                     resume: False
                     training_hyperparams:
                     training_hyperparams:
                    @@ -64,12 +66,6 @@ student_checkpoint_params:
                       pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenet").
                       pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenet").
                     
                     
                     
                     
                    -dataset_interface:
                    -  imagenet:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 8
                    -
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     
                     
                     
                     
                    Discard
                    @@ -18,11 +18,8 @@ defaults:
                       - arch_params: vit_base_arch_params
                       - arch_params: vit_base_arch_params
                       - checkpoint_params: vit_base_imagenet_checkpoint_params
                       - checkpoint_params: vit_base_imagenet_checkpoint_params
                     
                     
                    -dataset_interface:
                    -  imagenet:
                    -    dataset_params: ${dataset_params}
                    -
                    -data_loader_num_workers: 8
                    +train_dataloader: imagenet_train
                    +val_dataloader: imagenet_val
                     
                     
                     model_checkpoints_location: local
                     model_checkpoints_location: local
                     
                     
                    Discard
                    @@ -16,7 +16,6 @@ defaults:
                       - imagenet_vit_base
                       - imagenet_vit_base
                     
                     
                     dataset_params:
                     dataset_params:
                    -  batch_size: 32
                       train_dataloader_params:
                       train_dataloader_params:
                         batch_size: 32
                         batch_size: 32
                     
                     
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    1. # This recipe is for testing purpose only
                    2. defaults:
                    3. - training_hyperparams: cifar10_resnet_train_params
                    4. - arch_params: resnet18_cifar_arch_params
                    5. - checkpoint_params: default_checkpoint_params
                    6. dataset_interface:
                    7. classification_test_dataset:
                    8. dataset_params:
                    9. batch_size: 10
                    10. data_loader_num_workers: 1
                    11. resume: False
                    12. training_hyperparams:
                     13. resume: ${resume}
                    14. experiment_name: test
                    15. model_checkpoints_location: local
                    16. architecture: resnet18
                    Discard
                    @@ -1,7 +1,6 @@
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                     import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
                     import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
                    -from super_gradients.training.datasets import datasets_utils, DataAugmentation, TestDatasetInterface, SegmentationTestDatasetInterface, \
                    -    DetectionTestDatasetInterface, ClassificationTestDatasetInterface
                    +from super_gradients.training.datasets import datasets_utils, DataAugmentation
                     from super_gradients.training.models import ARCHITECTURES
                     from super_gradients.training.models import ARCHITECTURES
                     from super_gradients.training.sg_trainer import Trainer
                     from super_gradients.training.sg_trainer import Trainer
                     from super_gradients.training.kd_trainer import KDTrainer
                     from super_gradients.training.kd_trainer import KDTrainer
                    @@ -9,6 +8,5 @@ from super_gradients.training.sg_model import SgModel
                     from super_gradients.training.kd_model import KDModel
                     from super_gradients.training.kd_model import KDModel
                     from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
                     from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
                     
                     
                    -__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation', 'TestDatasetInterface',
                    -           'ARCHITECTURES', 'Trainer', 'KDTrainer', 'MultiGPUMode', 'TestDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
                    -           'ClassificationTestDatasetInterface', 'StrictLoad', 'SgModel', 'EvaluationType', 'KDModel']
                    +__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation',
                    +           'ARCHITECTURES', 'Trainer', 'KDTrainer', 'MultiGPUMode', 'StrictLoad', 'SgModel', 'EvaluationType', 'KDModel']
                    Discard
                    @@ -0,0 +1,31 @@
                    +from .dataloaders import coco2017_train, coco2017_val, coco2017_train_yolox, coco2017_val_yolox, \
                    +    coco2017_train_ssd_lite_mobilenet_v2, coco2017_val_ssd_lite_mobilenet_v2, imagenet_train, imagenet_val, \
                    +    imagenet_efficientnet_train, imagenet_efficientnet_val, imagenet_mobilenetv2_train, imagenet_mobilenetv2_val, \
                    +    imagenet_mobilenetv3_train, imagenet_mobilenetv3_val, imagenet_regnetY_train, imagenet_regnetY_val, \
                    +    imagenet_resnet50_train, imagenet_resnet50_val, imagenet_resnet50_kd_train, imagenet_resnet50_kd_val, \
                    +    imagenet_vit_base_train, imagenet_vit_base_val, tiny_imagenet_train, tiny_imagenet_val, cifar10_train, cifar10_val, \
                    +    cifar100_train, cifar100_val, cityscapes_train, cityscapes_val, cityscapes_stdc_seg50_train, \
                    +    cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_train, cityscapes_stdc_seg75_val, cityscapes_regseg48_train, \
                    +    cityscapes_regseg48_val, cityscapes_ddrnet_train, cityscapes_ddrnet_val, coco_segmentation_train, \
                    +    coco_segmentation_val, pascal_aug_segmentation_train, pascal_aug_segmentation_val, pascal_voc_segmentation_train, \
                    +    pascal_voc_segmentation_val, supervisely_persons_train, supervisely_persons_val, pascal_voc_detection_train, \
                    +    pascal_voc_detection_val, get_data_loader, get
                    +
                    +__all__ = ["coco2017_train", "coco2017_val", "coco2017_train_yolox", "coco2017_val_yolox",
                    +           "coco2017_train_ssd_lite_mobilenet_v2", "coco2017_val_ssd_lite_mobilenet_v2", "imagenet_train",
                    +           "imagenet_val",
                    +           "imagenet_efficientnet_train", "imagenet_efficientnet_val", "imagenet_mobilenetv2_train",
                    +           "imagenet_mobilenetv2_val",
                    +           "imagenet_mobilenetv3_train", "imagenet_mobilenetv3_val", "imagenet_regnetY_train", "imagenet_regnetY_val",
                    +           "imagenet_resnet50_train", "imagenet_resnet50_val", "imagenet_resnet50_kd_train", "imagenet_resnet50_kd_val",
                    +           "imagenet_vit_base_train", "imagenet_vit_base_val", "tiny_imagenet_train", "tiny_imagenet_val",
                    +           "cifar10_train", "cifar10_val",
                    +           "cifar100_train", "cifar100_val", "cityscapes_train", "cityscapes_val", "cityscapes_stdc_seg50_train",
                    +           "cityscapes_stdc_seg50_val", "cityscapes_stdc_seg75_train", "cityscapes_stdc_seg75_val",
                    +           "cityscapes_regseg48_train",
                    +           "cityscapes_regseg48_val", "cityscapes_ddrnet_train", "cityscapes_ddrnet_val", "coco_segmentation_train",
                    +           "coco_segmentation_val", "pascal_aug_segmentation_train", "pascal_aug_segmentation_val",
                    +           "pascal_voc_segmentation_train",
                    +           "pascal_voc_segmentation_val", "supervisely_persons_train", "supervisely_persons_val",
                    +           "pascal_voc_detection_train",
                    +           "pascal_voc_detection_val", "get_data_loader", "get"]
                    Discard
                    @@ -29,7 +29,7 @@ from super_gradients.training.utils.utils import override_default_params_without
                     logger = get_logger(__name__)
                     logger = get_logger(__name__)
                     
                     
                     
                     
                    -def get_data_loader(config_name, dataset_cls, train, dataset_params={}, dataloader_params={}):
                    +def get_data_loader(config_name, dataset_cls, train, dataset_params=None, dataloader_params=None):
                         """
                         """
                         Class for creating dataloaders for taking defaults from yaml files in src/super_gradients/recipes.
                         Class for creating dataloaders for taking defaults from yaml files in src/super_gradients/recipes.
                     
                     
                    @@ -44,6 +44,11 @@ def get_data_loader(config_name, dataset_cls, train, dataset_params={}, dataload
                         :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
                         :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
                         :return: DataLoader
                         :return: DataLoader
                         """
                         """
                    +    if dataloader_params is None:
                    +        dataloader_params = dict()
                    +    if dataset_params is None:
                    +        dataset_params = dict()
                    +
                         GlobalHydra.instance().clear()
                         GlobalHydra.instance().clear()
                         with initialize_config_dir(config_dir=pkg_resources.resource_filename("super_gradients.recipes", "")):
                         with initialize_config_dir(config_dir=pkg_resources.resource_filename("super_gradients.recipes", "")):
                             # config is relative to a module
                             # config is relative to a module
                    @@ -100,7 +105,7 @@ def _instantiate_sampler(dataset, dataloader_params):
                         return dataloader_params
                         return dataloader_params
                     
                     
                     
                     
                    -def coco2017_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_train(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return get_data_loader(config_name="coco_detection_dataset_params",
                         return get_data_loader(config_name="coco_detection_dataset_params",
                                                dataset_cls=COCODetectionDataset,
                                                dataset_cls=COCODetectionDataset,
                                                train=True,
                                                train=True,
                    @@ -109,7 +114,7 @@ def coco2017_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                                                )
                                                )
                     
                     
                     
                     
                    -def coco2017_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_val(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return get_data_loader(config_name="coco_detection_dataset_params",
                         return get_data_loader(config_name="coco_detection_dataset_params",
                                                dataset_cls=COCODetectionDataset,
                                                dataset_cls=COCODetectionDataset,
                                                train=False,
                                                train=False,
                    @@ -118,15 +123,15 @@ def coco2017_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                                                )
                                                )
                     
                     
                     
                     
                    -def coco2017_train_yolox(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_train_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return coco2017_train(dataset_params, dataloader_params)
                         return coco2017_train(dataset_params, dataloader_params)
                     
                     
                     
                     
                    -def coco2017_val_yolox(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_val_yolox(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return coco2017_val(dataset_params, dataloader_params)
                         return coco2017_val(dataset_params, dataloader_params)
                     
                     
                     
                     
                    -def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return get_data_loader(config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
                         return get_data_loader(config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
                                                dataset_cls=COCODetectionDataset,
                                                dataset_cls=COCODetectionDataset,
                                                train=True,
                                                train=True,
                    @@ -135,7 +140,7 @@ def coco2017_train_ssd_lite_mobilenet_v2(dataset_params: Dict = {}, dataloader_p
                                                )
                                                )
                     
                     
                     
                     
                    -def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                    +def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = None, dataloader_params: Dict = None):
                         return get_data_loader(config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
                         return get_data_loader(config_name="coco_detection_ssd_lite_mobilenet_v2_dataset_params",
                                                dataset_cls=COCODetectionDataset,
                                                dataset_cls=COCODetectionDataset,
                                                train=False,
                                                train=False,
                    @@ -144,7 +149,7 @@ def coco2017_val_ssd_lite_mobilenet_v2(dataset_params: Dict = {}, dataloader_par
                                                )
                                                )
                     
                     
                     
                     
                    -def imagenet_train(dataset_params={}, dataloader_params={}, config_name="imagenet_dataset_params"):
                    +def imagenet_train(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
                         return get_data_loader(config_name=config_name,
                         return get_data_loader(config_name=config_name,
                                                dataset_cls=ImageNetDataset,
                                                dataset_cls=ImageNetDataset,
                                                train=True,
                                                train=True,
                    @@ -152,7 +157,7 @@ def imagenet_train(dataset_params={}, dataloader_params={}, config_name="imagene
                                                dataloader_params=dataloader_params)
                                                dataloader_params=dataloader_params)
                     
                     
                     
                     
                    -def imagenet_val(dataset_params={}, dataloader_params={}, config_name="imagenet_dataset_params"):
                    +def imagenet_val(dataset_params=None, dataloader_params=None, config_name="imagenet_dataset_params"):
                         return get_data_loader(config_name=config_name,
                         return get_data_loader(config_name=config_name,
                                                dataset_cls=ImageNetDataset,
                                                dataset_cls=ImageNetDataset,
                                                train=False,
                                                train=False,
                    @@ -160,63 +165,63 @@ def imagenet_val(dataset_params={}, dataloader_params={}, config_name="imagenet_
                                                dataloader_params=dataloader_params)
                                                dataloader_params=dataloader_params)
                     
                     
                     
                     
                    -def imagenet_efficientnet_train(dataset_params={}, dataloader_params={}):
                    +def imagenet_efficientnet_train(dataset_params=None, dataloader_params=None):
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_efficientnet_dataset_params")
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_efficientnet_dataset_params")
                     
                     
                     
                     
                    -def imagenet_efficientnet_val(dataset_params={}, dataloader_params={}):
                    +def imagenet_efficientnet_val(dataset_params=None, dataloader_params=None):
                         return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_efficientnet_dataset_params")
                         return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_efficientnet_dataset_params")
                     
                     
                     
                     
                    -def imagenet_mobilenetv2_train(dataset_params={}, dataloader_params={}):
                    +def imagenet_mobilenetv2_train(dataset_params=None, dataloader_params=None):
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_mobilenetv2_dataset_params")
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_mobilenetv2_dataset_params")
                     
                     
                     
                     
                    -def imagenet_mobilenetv2_val(dataset_params={}, dataloader_params={}):
                    +def imagenet_mobilenetv2_val(dataset_params=None, dataloader_params=None):
                         return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_mobilenetv2_dataset_params")
                         return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_mobilenetv2_dataset_params")
                     
                     
                     
                     
                    -def imagenet_mobilenetv3_train(dataset_params={}, dataloader_params={}):
                    +def imagenet_mobilenetv3_train(dataset_params=None, dataloader_params=None):
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_mobilenetv3_dataset_params")
                         return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_mobilenetv3_dataset_params")
                     
                     
                     
                     
-def imagenet_mobilenetv3_val(dataset_params={}, dataloader_params={}):
+def imagenet_mobilenetv3_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_mobilenetv3_dataset_params")
 
 
-def imagenet_regnetY_train(dataset_params={}, dataloader_params={}):
+def imagenet_regnetY_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
 
 
-def imagenet_regnetY_val(dataset_params={}, dataloader_params={}):
+def imagenet_regnetY_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_regnetY_dataset_params")
 
 
-def imagenet_resnet50_train(dataset_params={}, dataloader_params={}):
+def imagenet_resnet50_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_resnet50_dataset_params")
 
 
-def imagenet_resnet50_val(dataset_params={}, dataloader_params={}):
+def imagenet_resnet50_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_resnet50_dataset_params")
 
 
-def imagenet_resnet50_kd_train(dataset_params={}, dataloader_params={}):
+def imagenet_resnet50_kd_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_resnet50_kd_dataset_params")
 
 
-def imagenet_resnet50_kd_val(dataset_params={}, dataloader_params={}):
+def imagenet_resnet50_kd_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_resnet50_kd_dataset_params")
 
 
-def imagenet_vit_base_train(dataset_params={}, dataloader_params={}):
+def imagenet_vit_base_train(dataset_params=None, dataloader_params=None):
     return imagenet_train(dataset_params, dataloader_params, config_name="imagenet_vit_base_dataset_params")
 
 
-def imagenet_vit_base_val(dataset_params={}, dataloader_params={}):
+def imagenet_vit_base_val(dataset_params=None, dataloader_params=None):
     return imagenet_val(dataset_params, dataloader_params, config_name="imagenet_vit_base_dataset_params")
 
 
-def tiny_imagenet_train(dataset_params={}, dataloader_params={}, config_name="tiny_imagenet_dataset_params"):
+def tiny_imagenet_train(dataset_params=None, dataloader_params=None, config_name="tiny_imagenet_dataset_params"):
     return get_data_loader(config_name=config_name,
                            dataset_cls=ImageNetDataset,
                            train=True,
@@ -224,7 +229,7 @@ def tiny_imagenet_train(dataset_params={}, dataloader_params={}, config_name="ti
                            dataloader_params=dataloader_params)
 
 
-def tiny_imagenet_val(dataset_params={}, dataloader_params={}, config_name="tiny_imagenet_dataset_params"):
+def tiny_imagenet_val(dataset_params=None, dataloader_params=None, config_name="tiny_imagenet_dataset_params"):
     return get_data_loader(config_name=config_name,
                            dataset_cls=ImageNetDataset,
                            train=False,
@@ -232,7 +237,7 @@ def tiny_imagenet_val(dataset_params={}, dataloader_params={}, config_name="tiny
                            dataloader_params=dataloader_params)
 
 
-def cifar10_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cifar10_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cifar10_dataset_params",
                            dataset_cls=Cifar10,
                            train=True,
@@ -241,7 +246,7 @@ def cifar10_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def cifar10_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cifar10_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cifar10_dataset_params",
                            dataset_cls=Cifar10,
                            train=False,
@@ -250,7 +255,7 @@ def cifar10_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def cifar100_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cifar100_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cifar100_dataset_params",
                            dataset_cls=Cifar100,
                            train=True,
@@ -259,7 +264,7 @@ def cifar100_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def cifar100_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cifar100_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cifar100_dataset_params",
                            dataset_cls=Cifar100,
                            train=False,
@@ -268,28 +273,31 @@ def cifar100_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def classification_test_dataloader(batch_size: int = 5, image_size: int = 32) -> DataLoader:
-    images = torch.Tensor(np.zeros((batch_size, 3, image_size, image_size)))
-    ground_truth = torch.LongTensor(np.zeros((batch_size)))
+def classification_test_dataloader(batch_size: int = 5, image_size: int = 32, dataset_size=None) -> DataLoader:
+    dataset_size = dataset_size or batch_size
+    images = torch.Tensor(np.zeros((dataset_size, 3, image_size, image_size)))
+    ground_truth = torch.LongTensor(np.zeros((dataset_size)))
     dataset = TensorDataset(images, ground_truth)
     return DataLoader(dataset=dataset, batch_size=batch_size)
 
 
-def detection_test_dataloader(batch_size: int = 5, image_size: int = 320) -> DataLoader:
-    images = torch.Tensor(np.zeros((batch_size, 3, image_size, image_size)))
-    ground_truth = torch.LongTensor(np.zeros((batch_size, 6)))
+def detection_test_dataloader(batch_size: int = 5, image_size: int = 320, dataset_size=None) -> DataLoader:
+    dataset_size = dataset_size or batch_size
+    images = torch.Tensor(np.zeros((dataset_size, 3, image_size, image_size)))
+    ground_truth = torch.Tensor(np.zeros((dataset_size, 6)))
     dataset = TensorDataset(images, ground_truth)
     return DataLoader(dataset=dataset, batch_size=batch_size)
 
 
-def segmentation_test_dataloader(batch_size: int = 5, image_size: int = 512) -> DataLoader:
-    images = torch.Tensor(np.zeros((batch_size, 3, image_size, image_size)))
-    ground_truth = torch.LongTensor(np.zeros((batch_size, image_size, image_size)))
+def segmentation_test_dataloader(batch_size: int = 5, image_size: int = 512, dataset_size=None) -> DataLoader:
+    dataset_size = dataset_size or batch_size
+    images = torch.Tensor(np.zeros((dataset_size, 3, image_size, image_size)))
+    ground_truth = torch.LongTensor(np.zeros((dataset_size, image_size, image_size)))
     dataset = TensorDataset(images, ground_truth)
     return DataLoader(dataset=dataset, batch_size=batch_size)
 
 
-def cityscapes_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=True,
@@ -298,7 +306,7 @@ def cityscapes_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def cityscapes_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=False,
@@ -307,7 +315,7 @@ def cityscapes_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
                            )
 
 
-def cityscapes_stdc_seg50_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_stdc_seg50_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_stdc_seg50_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=True,
@@ -316,7 +324,7 @@ def cityscapes_stdc_seg50_train(dataset_params: Dict = {}, dataloader_params: Di
                            )
 
 
-def cityscapes_stdc_seg50_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_stdc_seg50_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_stdc_seg50_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=False,
@@ -325,7 +333,7 @@ def cityscapes_stdc_seg50_val(dataset_params: Dict = {}, dataloader_params: Dict
                            )
 
 
-def cityscapes_stdc_seg75_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_stdc_seg75_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_stdc_seg75_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=True,
@@ -334,7 +342,7 @@ def cityscapes_stdc_seg75_train(dataset_params: Dict = {}, dataloader_params: Di
                            )
 
 
-def cityscapes_stdc_seg75_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_stdc_seg75_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_stdc_seg75_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=False,
@@ -343,7 +351,7 @@ def cityscapes_stdc_seg75_val(dataset_params: Dict = {}, dataloader_params: Dict
                            )
 
 
-def cityscapes_regseg48_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_regseg48_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_regseg48_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=True,
@@ -352,7 +360,7 @@ def cityscapes_regseg48_train(dataset_params: Dict = {}, dataloader_params: Dict
                            )
 
 
-def cityscapes_regseg48_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_regseg48_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_regseg48_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=False,
@@ -361,7 +369,7 @@ def cityscapes_regseg48_val(dataset_params: Dict = {}, dataloader_params: Dict =
                            )
 
 
-def cityscapes_ddrnet_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_ddrnet_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_ddrnet_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=True,
@@ -370,7 +378,7 @@ def cityscapes_ddrnet_train(dataset_params: Dict = {}, dataloader_params: Dict =
                            )
 
 
-def cityscapes_ddrnet_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def cityscapes_ddrnet_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="cityscapes_ddrnet_dataset_params",
                            dataset_cls=CityscapesDataset,
                            train=False,
@@ -379,7 +387,7 @@ def cityscapes_ddrnet_val(dataset_params: Dict = {}, dataloader_params: Dict = {
                            )
 
 
-def coco_segmentation_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def coco_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="coco_segmentation_dataset_params",
                            dataset_cls=CoCoSegmentationDataSet,
                            train=True,
@@ -388,7 +396,7 @@ def coco_segmentation_train(dataset_params: Dict = {}, dataloader_params: Dict =
                            )
 
 
-def coco_segmentation_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def coco_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="coco_segmentation_dataset_params",
                            dataset_cls=CoCoSegmentationDataSet,
                            train=False,
@@ -397,7 +405,7 @@ def coco_segmentation_val(dataset_params: Dict = {}, dataloader_params: Dict = {
                            )
 
 
-def pascal_aug_segmentation_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_aug_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_aug_segmentation_dataset_params",
                            dataset_cls=PascalAUG2012SegmentationDataSet,
                            train=True,
@@ -406,7 +414,7 @@ def pascal_aug_segmentation_train(dataset_params: Dict = {}, dataloader_params:
                            )
 
 
-def pascal_aug_segmentation_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_aug_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_aug_segmentation_dataset_params",
                            dataset_cls=PascalAUG2012SegmentationDataSet,
                            train=False,
@@ -415,7 +423,7 @@ def pascal_aug_segmentation_val(dataset_params: Dict = {}, dataloader_params: Di
                            )
 
 
-def pascal_voc_segmentation_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_voc_segmentation_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_voc_segmentation_dataset_params",
                            dataset_cls=PascalVOC2012SegmentationDataSet,
                            train=True,
@@ -424,7 +432,7 @@ def pascal_voc_segmentation_train(dataset_params: Dict = {}, dataloader_params:
                            )
 
 
-def pascal_voc_segmentation_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_voc_segmentation_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_voc_segmentation_dataset_params",
                            dataset_cls=PascalVOC2012SegmentationDataSet,
                            train=False,
@@ -433,7 +441,7 @@ def pascal_voc_segmentation_val(dataset_params: Dict = {}, dataloader_params: Di
                            )
 
 
-def supervisely_persons_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="supervisely_persons_dataset_params",
                            dataset_cls=SuperviselyPersonsDataset,
                            train=True,
@@ -441,7 +449,7 @@ def supervisely_persons_train(dataset_params: Dict = {}, dataloader_params: Dict
                            dataloader_params=dataloader_params)
 
 
-def supervisely_persons_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def supervisely_persons_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="supervisely_persons_dataset_params",
                            dataset_cls=SuperviselyPersonsDataset,
                            train=False,
@@ -449,7 +457,7 @@ def supervisely_persons_val(dataset_params: Dict = {}, dataloader_params: Dict =
                            dataloader_params=dataloader_params)
 
 
-def pascal_voc_detection_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_voc_detection_train(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_voc_detection_dataset_params",
                            dataset_cls=PascalVOCUnifiedDetectionTrainDataset,
                            train=True,
@@ -458,10 +466,78 @@ def pascal_voc_detection_train(dataset_params: Dict = {}, dataloader_params: Dic
                            )
 
 
-def pascal_voc_detection_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+def pascal_voc_detection_val(dataset_params: Dict = None, dataloader_params: Dict = None):
     return get_data_loader(config_name="pascal_voc_detection_dataset_params",
                            dataset_cls=PascalVOCDetectionDataset,
                            train=False,
                            dataset_params=dataset_params,
                            dataloader_params=dataloader_params
                            )
                    +
                    +
                    +ALL_DATALOADERS = {"coco2017_train": coco2017_train,
                    +                   "coco2017_val": coco2017_val,
                    +                   "coco2017_train_yolox": coco2017_train_yolox,
                    +                   "coco2017_val_yolox": coco2017_val_yolox,
                    +                   "coco2017_train_ssd_lite_mobilenet_v2": coco2017_train_ssd_lite_mobilenet_v2,
                    +                   "coco2017_val_ssd_lite_mobilenet_v2": coco2017_val_ssd_lite_mobilenet_v2,
                    +                   "imagenet_train": imagenet_train,
                    +                   "imagenet_val": imagenet_val,
                    +                   "imagenet_efficientnet_train": imagenet_efficientnet_train,
                    +                   "imagenet_efficientnet_val": imagenet_efficientnet_val,
                    +                   "imagenet_mobilenetv2_train": imagenet_mobilenetv2_train,
                    +                   "imagenet_mobilenetv2_val": imagenet_mobilenetv2_val,
                    +                   "imagenet_mobilenetv3_train": imagenet_mobilenetv3_train,
                    +                   "imagenet_mobilenetv3_val": imagenet_mobilenetv3_val,
                    +                   "imagenet_regnetY_train": imagenet_regnetY_train,
                    +                   "imagenet_regnetY_val": imagenet_regnetY_val,
                    +                   "imagenet_resnet50_train": imagenet_resnet50_train,
                    +                   "imagenet_resnet50_val": imagenet_resnet50_val,
                    +                   "imagenet_resnet50_kd_train": imagenet_resnet50_kd_train,
                    +                   "imagenet_resnet50_kd_val": imagenet_resnet50_kd_val,
                    +                   "imagenet_vit_base_train": imagenet_vit_base_train,
                    +                   "imagenet_vit_base_val": imagenet_vit_base_val,
                    +                   "tiny_imagenet_train": tiny_imagenet_train,
                    +                   "tiny_imagenet_val": tiny_imagenet_val,
                    +                   "cifar10_train": cifar10_train,
                    +                   "cifar10_val": cifar10_val,
                    +                   "cifar100_train": cifar100_train,
                    +                   "cifar100_val": cifar100_val,
                    +                   "cityscapes_train": cityscapes_train,
                    +                   "cityscapes_val": cityscapes_val,
                    +                   "cityscapes_stdc_seg50_train": cityscapes_stdc_seg50_train,
                    +                   "cityscapes_stdc_seg50_val": cityscapes_stdc_seg50_val,
                    +                   "cityscapes_stdc_seg75_train": cityscapes_stdc_seg75_train,
                    +                   "cityscapes_stdc_seg75_val": cityscapes_stdc_seg75_val,
                    +                   "cityscapes_regseg48_train": cityscapes_regseg48_train,
                    +                   "cityscapes_regseg48_val": cityscapes_regseg48_val,
                    +                   "cityscapes_ddrnet_train": cityscapes_ddrnet_train,
                    +                   "cityscapes_ddrnet_val": cityscapes_ddrnet_val,
                    +                   "coco_segmentation_train": coco_segmentation_train,
                    +                   "coco_segmentation_val": coco_segmentation_val,
                    +                   "pascal_aug_segmentation_train": pascal_aug_segmentation_train,
                    +                   "pascal_aug_segmentation_val": pascal_aug_segmentation_val,
                    +                   "pascal_voc_segmentation_train": pascal_voc_segmentation_train,
                    +                   "pascal_voc_segmentation_val": pascal_voc_segmentation_val,
                    +                   "supervisely_persons_train": supervisely_persons_train,
                    +                   "supervisely_persons_val": supervisely_persons_val,
                    +                   "pascal_voc_detection_train": pascal_voc_detection_train,
                    +                   "pascal_voc_detection_val": pascal_voc_detection_val
                    +                   }
                    +
                    +
                    +def get(name: str, dataset_params: Dict = None, dataloader_params: Dict = None):
                    +    """
                    +
                    +    
                    +    :param name: 
                    +    :param dataset_params: 
                    +    :param dataloader_params: 
                    +    :return: 
                    +    """
                    +
                    +    if name not in ALL_DATALOADERS.keys():
                    +        raise ValueError("Unsupported dataloader: " + str(name))
                    +
                    +    dataloader_cls = ALL_DATALOADERS[name]
                    +    return dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)
                    Discard
                    @@ -3,8 +3,6 @@ import cv2
                     
                     
                     from super_gradients.training.datasets.data_augmentation import DataAugmentation
                     from super_gradients.training.datasets.data_augmentation import DataAugmentation
                     from super_gradients.training.datasets.sg_dataset import ListDataset, DirectoryDataSet
                     from super_gradients.training.datasets.sg_dataset import ListDataset, DirectoryDataSet
                    -from super_gradients.training.datasets.all_datasets import CLASSIFICATION_DATASETS, OBJECT_DETECTION_DATASETS, \
                    -    SEMANTIC_SEGMENTATION_DATASETS
                     from super_gradients.training.datasets.classification_datasets import ImageNetDataset, Cifar10, Cifar100
                     from super_gradients.training.datasets.classification_datasets import ImageNetDataset, Cifar10, Cifar100
                     from super_gradients.training.datasets.detection_datasets import DetectionDataset, COCODetectionDataset, PascalVOCDetectionDataset
                     from super_gradients.training.datasets.detection_datasets import DetectionDataset, COCODetectionDataset, PascalVOCDetectionDataset
                     from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
                     from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
                    @@ -13,23 +11,12 @@ from super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmenta
                     from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
                     from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
                     from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import SuperviselyPersonsDataset
                     from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import SuperviselyPersonsDataset
                     
                     
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import TestDatasetInterface, DatasetInterface, \
                    -    Cifar10DatasetInterface, CoCoSegmentationDatasetInterface, \
                    -    PascalVOC2012SegmentationDataSetInterface, PascalAUG2012SegmentationDataSetInterface, \
                    -    TestYoloDetectionDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, \
                    -    ClassificationTestDatasetInterface, ImageNetDatasetInterface
                     
                     
                     cv2.setNumThreads(0)
                     cv2.setNumThreads(0)
                     
                     
                     
                     
                    -__all__ = ['DataAugmentation', 'ListDataset', 'DirectoryDataSet', 'CLASSIFICATION_DATASETS', 'OBJECT_DETECTION_DATASETS',
                    -           'SEMANTIC_SEGMENTATION_DATASETS', 'SegmentationDataSet',
                    +__all__ = ['DataAugmentation', 'ListDataset', 'DirectoryDataSet', 'SegmentationDataSet',
                                'PascalVOC2012SegmentationDataSet',
                                'PascalVOC2012SegmentationDataSet',
                    -           'PascalAUG2012SegmentationDataSet', 'CoCoSegmentationDataSet', 'TestDatasetInterface', 'DatasetInterface',
                    -           'Cifar10DatasetInterface', 'CoCoSegmentationDatasetInterface',
                    -           'PascalVOC2012SegmentationDataSetInterface', 'PascalAUG2012SegmentationDataSetInterface',
                    -           'TestYoloDetectionDatasetInterface', 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface',
                    -           'SegmentationTestDatasetInterface',
                    -           'ImageNetDatasetInterface',
                    +           'PascalAUG2012SegmentationDataSet', 'CoCoSegmentationDataSet',
                                'DetectionDataset', 'COCODetectionDataset', 'PascalVOCDetectionDataset', 'ImageNetDataset',
                                'DetectionDataset', 'COCODetectionDataset', 'PascalVOCDetectionDataset', 'ImageNetDataset',
                                'Cifar10', 'Cifar100', 'SuperviselyPersonsDataset']
                                'Cifar10', 'Cifar100', 'SuperviselyPersonsDataset']
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    78
                    79
                    80
                    81
                    82
                    83
                    84
                    85
                    86
                    87
                    88
                    89
                    1. from collections import defaultdict
                    2. from typing import Dict, List, Type
                    3. from super_gradients.training.datasets.dataset_interfaces import DatasetInterface, TestDatasetInterface, \
                    4. LibraryDatasetInterface, \
                    5. ClassificationDatasetInterface, Cifar10DatasetInterface, Cifar100DatasetInterface, \
                    6. ImageNetDatasetInterface, TinyImageNetDatasetInterface, CoCoSegmentationDatasetInterface,\
                    7. PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface
                    8. from super_gradients.common.data_types.enum.deep_learning_task import DeepLearningTask
                    9. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
                    10. CLASSIFICATION_DATASETS = {
                    11. "test_dataset": TestDatasetInterface,
                    12. "library_dataset": LibraryDatasetInterface,
                    13. "classification_dataset": ClassificationDatasetInterface,
                    14. "cifar_10": Cifar10DatasetInterface,
                    15. "cifar_100": Cifar100DatasetInterface,
                    16. "imagenet": ImageNetDatasetInterface,
                    17. "tiny_imagenet": TinyImageNetDatasetInterface
                    18. }
                    19. OBJECT_DETECTION_DATASETS = {
                    20. "coco": CoCoDetectionDatasetInterface,
                    21. }
                    22. SEMANTIC_SEGMENTATION_DATASETS = {
                    23. "coco": CoCoSegmentationDatasetInterface,
                    24. "pascal_voc": PascalVOC2012SegmentationDataSetInterface,
                    25. "pascal_aug": PascalAUG2012SegmentationDataSetInterface
                    26. }
                    27. class DataSetDoesNotExistException(Exception):
                    28. """
                    29. The requested dataset does not exist, or is not implemented.
                    30. """
                    31. pass
                    32. class SgLibraryDatasets(object):
                    33. """
                    34. Holds all of the different library dataset dictionaries, by DL Task mapping
                    35. Attributes:
                    36. CLASSIFICATION Dictionary of Classification Data sets
                    37. OBJECT_DETECTION Dictionary of Object Detection Data sets
                    38. SEMANTIC_SEGMENTATION Dictionary of Semantic Segmentation Data sets
                    39. """
                    40. CLASSIFICATION = CLASSIFICATION_DATASETS
                    41. OBJECT_DETECTION = OBJECT_DETECTION_DATASETS
                    42. SEMANTIC_SEGMENTATION = SEMANTIC_SEGMENTATION_DATASETS
                    43. _datasets_mapping = {
                    44. DeepLearningTask.CLASSIFICATION: CLASSIFICATION,
                    45. DeepLearningTask.SEMANTIC_SEGMENTATION: SEMANTIC_SEGMENTATION,
                    46. DeepLearningTask.OBJECT_DETECTION: OBJECT_DETECTION,
                    47. }
                    48. @staticmethod
                    49. def get_all_available_datasets() -> Dict[str, List[str]]:
                    50. """
                    51. Gets all the available datasets.
                    52. """
                    53. all_datasets: Dict[str, List[str]] = defaultdict(list)
                    54. for dl_task, task_datasets in SgLibraryDatasets._datasets_mapping.items():
                    55. for dataset_name, dataset_interface in task_datasets.items():
                    56. all_datasets[dl_task].append(dataset_name)
                    57. # TODO: Return Dataset Metadata list from the dataset interfaces objects
                    58. # TODO: Transform DatasetInterface -> DataSetMetadata
                    59. return all_datasets
                    60. @staticmethod
                    61. def get_dataset(dl_task: str, dataset_name: str) -> Type[DatasetInterface]:
                    62. """
                    63. Get's a dataset with a given name for a given deep learning task.
                    64. examp:
                    65. >>> SgLibraryDatasets.get_dataset(dl_task='classification', dataset_name='cifar_100')
                    66. >>> <Cifar100DatasetInterface instance>
                    67. """
                    68. task_datasets: Dict[str, DatasetInterface] = SgLibraryDatasets._datasets_mapping.get(dl_task)
                    69. if not task_datasets:
                    70. raise ValueError(f"Invalid Deep Learining Task: {dl_task}")
                    71. dataset: DatasetInterface = task_datasets.get(dataset_name)
                    72. if not dataset:
                    73. raise DataSetDoesNotExistException(dataset_name)
                    74. return dataset
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    1. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import DatasetInterface, TestDatasetInterface, \
                    2. LibraryDatasetInterface, \
                    3. ClassificationDatasetInterface, Cifar10DatasetInterface, Cifar100DatasetInterface, \
                    4. ImageNetDatasetInterface, TinyImageNetDatasetInterface, CoCoSegmentationDatasetInterface, \
                    5. PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface, \
                    6. TestYoloDetectionDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface,\
                    7. CoCoDetectionDatasetInterface, PascalVOCUnifiedDetectionDatasetInterface
                    8. __all__ = ['DatasetInterface', 'TestDatasetInterface', 'LibraryDatasetInterface', 'ClassificationDatasetInterface', 'Cifar10DatasetInterface',
                    9. 'Cifar100DatasetInterface', 'ImageNetDatasetInterface', 'TinyImageNetDatasetInterface',
                    10. 'CoCoSegmentationDatasetInterface', 'PascalAUG2012SegmentationDataSetInterface',
                    11. 'PascalVOC2012SegmentationDataSetInterface', 'TestYoloDetectionDatasetInterface', 'SegmentationTestDatasetInterface',
                    12. 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface', 'CoCoDetectionDatasetInterface',
                    13. 'PascalVOCUnifiedDetectionDatasetInterface']
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    78
                    79
                    80
                    81
                    82
                    83
                    84
                    85
                    86
                    87
                    88
                    89
                    90
                    91
                    92
                    93
                    94
                    95
                    96
                    97
                    98
                    99
                    100
                    101
                    102
                    103
                    104
                    105
                    106
                    107
                    108
                    109
                    110
                    111
                    112
                    113
                    114
                    115
                    116
                    117
                    118
                    119
                    120
                    121
                    122
                    123
                    124
                    125
                    126
                    127
                    128
                    129
                    130
                    131
                    132
                    133
                    134
                    135
                    136
                    137
                    138
                    139
                    140
                    141
                    142
                    143
                    144
                    145
                    146
                    147
                    148
                    149
                    150
                    151
                    152
                    153
                    154
                    155
                    156
                    157
                    158
                    159
                    160
                    161
                    162
                    163
                    164
                    165
                    166
                    167
                    168
                    169
                    170
                    171
                    172
                    173
                    174
                    175
                    176
                    177
                    178
                    179
                    180
                    181
                    182
                    183
                    184
                    185
                    186
                    187
                    188
                    189
                    190
                    191
                    192
                    193
                    194
                    195
                    196
                    197
                    198
                    199
                    200
                    201
                    202
                    203
                    204
                    205
                    206
                    207
                    208
                    209
                    210
                    211
                    212
                    213
                    214
                    215
                    216
                    217
                    218
                    219
                    220
                    221
                    222
                    223
                    224
                    225
                    226
                    227
                    228
                    229
                    230
                    231
                    232
                    233
                    234
                    235
                    236
                    237
                    238
                    239
                    240
                    241
                    242
                    243
                    244
                    245
                    246
                    247
                    248
                    249
                    250
                    251
                    252
                    253
                    254
                    255
                    256
                    257
                    258
                    259
                    260
                    261
                    262
                    263
                    264
                    265
                    266
                    267
                    268
                    269
                    270
                    271
                    272
                    273
                    274
                    275
                    276
                    277
                    278
                    279
                    280
                    281
                    282
                    283
                    284
                    285
                    286
                    287
                    288
                    289
                    290
                    291
                    292
                    293
                    294
                    295
                    296
                    297
                    298
                    299
                    300
                    301
                    302
                    303
                    304
                    305
                    306
                    307
                    308
                    309
                    310
                    311
                    312
                    313
                    314
                    315
                    316
                    317
                    318
                    319
                    320
                    321
                    322
                    323
                    324
                    325
                    326
                    327
                    328
                    329
                    330
                    331
                    332
                    333
                    334
                    335
                    336
                    337
                    338
                    339
                    340
                    341
                    342
                    343
                    344
                    345
                    346
                    347
                    348
                    349
                    350
                    351
                    352
                    353
                    354
                    355
                    356
                    357
                    358
                    359
                    360
                    361
                    362
                    363
                    364
                    365
                    366
                    367
                    368
                    369
                    370
                    371
                    372
                    373
                    374
                    375
                    376
                    377
                    378
                    379
                    380
                    381
                    382
                    383
                    384
                    385
                    386
                    387
                    388
                    389
                    390
                    391
                    392
                    393
                    394
                    395
                    396
                    397
                    398
                    399
                    400
                    401
                    402
                    403
                    404
                    405
                    406
                    407
                    408
                    409
                    410
                    411
                    412
                    413
                    414
                    415
                    416
                    417
                    418
                    419
                    420
                    421
                    422
                    423
                    424
                    425
                    426
                    427
                    428
                    429
                    430
                    431
                    432
                    433
                    434
                    435
                    436
                    437
                    438
                    439
                    440
                    441
                    442
                    443
                    444
                    445
                    446
                    447
                    448
                    449
                    450
                    451
                    452
                    453
                    454
                    455
                    456
                    457
                    458
                    459
                    460
                    461
                    462
                    463
                    464
                    465
                    466
                    467
                    468
                    469
                    470
                    471
                    472
                    473
                    474
                    475
                    476
                    477
                    478
                    479
                    480
                    481
                    482
                    483
                    484
                    485
                    486
                    487
                    488
                    489
                    490
                    491
                    492
                    493
                    494
                    495
                    496
                    497
                    498
                    499
                    500
                    501
                    502
                    503
                    504
                    505
                    506
                    507
                    508
                    509
                    510
                    511
                    512
                    513
                    514
                    515
                    516
                    517
                    518
                    519
                    520
                    521
                    522
                    523
                    524
                    525
                    526
                    527
                    528
                    529
                    530
                    531
                    532
                    533
                    534
                    535
                    536
                    537
                    538
                    539
                    540
                    541
                    542
                    543
                    544
                    545
                    546
                    547
                    548
                    549
                    550
                    551
                    552
                    553
                    554
                    555
                    556
                    557
                    558
                    559
                    560
                    561
                    562
                    563
                    564
                    565
                    566
                    567
                    568
                    569
                    570
                    571
                    572
                    573
                    574
                    575
                    576
                    577
                    578
                    579
                    580
                    581
                    582
                    583
                    584
                    585
                    586
                    587
                    588
                    589
                    590
                    591
                    592
                    593
                    594
                    595
                    596
                    597
                    598
                    599
                    600
                    601
                    602
                    603
                    604
                    605
                    606
                    607
                    608
                    609
                    610
                    611
                    612
                    613
                    614
                    615
                    616
                    617
                    618
                    619
                    620
                    621
                    622
                    623
                    624
                    625
                    626
                    627
                    628
                    629
                    630
                    631
                    632
                    633
                    634
                    635
                    636
                    637
                    638
                    639
                    640
                    641
                    642
                    643
                    644
                    645
                    646
                    647
                    648
                    649
                    650
                    651
                    652
                    653
                    654
                    655
                    656
                    657
                    658
                    659
                    660
                    661
                    662
                    663
                    664
                    665
                    666
                    667
                    668
                    669
                    670
                    671
                    672
                    673
                    674
                    675
                    676
                    677
                    678
                    679
                    680
                    681
                    682
                    683
                    684
                    685
                    686
                    687
                    688
                    689
                    690
                    691
                    692
                    693
                    694
                    695
                    696
                    697
                    698
                    699
                    700
                    701
                    702
                    703
                    704
                    705
                    706
                    707
                    708
                    709
                    710
                    711
                    712
                    713
                    714
                    715
                    716
                    717
                    718
                    719
                    720
                    721
                    722
                    723
                    724
                    725
                    726
                    727
                    728
                    729
                    730
                    731
                    732
                    733
                    734
                    735
                    736
                    737
                    738
                    739
                    740
                    741
                    742
                    743
                    744
                    745
                    746
                    747
                    748
                    749
                    750
                    751
                    752
                    753
                    754
                    755
                    756
                    757
                    758
                    759
                    760
                    761
                    762
                    763
                    764
                    765
                    766
                    767
                    768
                    769
                    770
                    771
                    772
                    773
                    774
                    775
                    776
                    777
                    778
                    779
                    780
                    781
                    782
                    783
                    784
                    785
                    786
                    787
                    788
                    789
                    790
                    791
                    792
                    793
                    794
                    795
                    796
                    797
                    798
                    799
                    800
                    801
                    802
                    803
                    804
                    805
                    806
                    807
                    1. import os
                    2. import numpy as np
                    3. import torch
                    4. import torchvision
                    5. import torchvision.datasets as datasets
                    6. import torchvision.transforms as transforms
                    7. from torch.utils.data import ConcatDataset, BatchSampler, DataLoader
                    8. from torch.utils.data.distributed import DistributedSampler
                    9. from super_gradients.common import DatasetDataInterface
                    10. from super_gradients.common.abstractions.abstract_logger import get_logger
                    11. from super_gradients.common.environment import AWS_ENV_NAME
                    12. from super_gradients.training import utils as core_utils
                    13. from super_gradients.training.datasets import datasets_utils, DataAugmentation
                    14. from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
                    15. from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
                    16. from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, worker_init_reset_seed
                    17. from super_gradients.training.datasets.detection_datasets import COCODetectionDataset, PascalVOCDetectionDataset
                    18. from super_gradients.training.datasets.mixup import CollateMixup
                    19. from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
                    20. from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
                    21. from super_gradients.training.datasets.segmentation_datasets import PascalVOC2012SegmentationDataSet, \
                    22. PascalAUG2012SegmentationDataSet, CoCoSegmentationDataSet
                    23. from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
                    24. from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import \
                    25. SuperviselyPersonsDataset
                    26. from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
                    27. from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, CropImageAndMask, \
                    28. PadShortToCropSize
                    29. from super_gradients.training.utils import get_param
                    30. from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
# Fallback hyper-parameters applied when the caller's dataset_params does not override them.
default_dataset_params = {"batch_size": 64, "val_batch_size": 200, "test_batch_size": 200, "dataset_dir": "./data/",
                          "s3_link": None}

# Torchvision datasets servable through LibraryDatasetInterface, keyed by name, with their
# per-channel normalization statistics. mean/std of None means "compute from the train split"
# (see LibraryDatasetInterface.__init__, which does exactly that for SVHN).
LIBRARY_DATASETS = {
    "cifar10": {'class': datasets.CIFAR10, 'mean': (0.4914, 0.4822, 0.4465), 'std': (0.2023, 0.1994, 0.2010)},
    "cifar100": {'class': datasets.CIFAR100, 'mean': (0.5071, 0.4865, 0.4409), 'std': (0.2673, 0.2564, 0.2762)},
    "SVHN": {'class': datasets.SVHN, 'mean': None, 'std': None}
}

# Module-level logger, following the project convention of get_logger(__name__).
logger = get_logger(__name__)
                    39. class DatasetInterface:
                    40. """
                    41. DatasetInterface - This class manages all of the "communiation" the Model has with the Data Sets
                    42. """
                    43. def __init__(self, dataset_params={}, train_loader=None, val_loader=None, test_loader=None, classes=None):
                    44. """
                    45. @param train_loader: torch.utils.data.Dataloader (optional) dataloader for training.
                    46. @param test_loader: torch.utils.data.Dataloader (optional) dataloader for testing.
                    47. @param classes: list of classes.
                    48. Note: the above parameters will be discarded in case dataset_params is passed.
                    49. @param dataset_params:
                    50. - `batch_size` : int (default=64)
                    51. Number of examples per batch for training. Large batch sizes are recommended.
                    52. - `val_batch_size` : int (default=200)
                    53. Number of examples per batch for validation. Large batch sizes are recommended.
                    54. - `dataset_dir` : str (default="./data/")
                    55. Directory location for the data. Data will be downloaded to this directory when getting it from a
                    56. remote url.
                    57. - `s3_link` : str (default=None)
                    58. remote s3 link to download the data (optional).
                    59. - `aug_repeat_count` : int (default=0)
                    60. amount of repetitions (each repetition of an example is augmented differently) for each
                    61. example for the trainset.
                    62. """
                    63. self.dataset_params = core_utils.HpmStruct(**default_dataset_params)
                    64. self.dataset_params.override(**dataset_params)
                    65. self.trainset, self.valset, self.testset = None, None, None
                    66. self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader
                    67. self.classes = classes
                    68. self.batch_size_factor = 1
                    69. if self.dataset_params.s3_link is not None:
                    70. self.download_from_cloud()
                    71. def download_from_cloud(self):
                    72. if self.dataset_params.s3_link is not None:
                    73. env_name = AWS_ENV_NAME
                    74. downloader = DatasetDataInterface(env=env_name)
                    75. target_dir = self.dataset_params.dataset_dir
                    76. if not os.path.exists(target_dir):
                    77. os.mkdir(target_dir)
                    78. downloader.load_remote_dataset_file(self.dataset_params.s3_link, target_dir)
                    79. def build_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None,
                    80. test_batch_size=None, distributed_sampler: bool = False):
                    81. """
                    82. define train, val (and optionally test) loaders. The method deals separately with distributed training and standard
                    83. (non distributed, or parallel training). In the case of distributed training we need to rely on distributed
                    84. samplers.
                    85. :param batch_size_factor: int - factor to multiply the batch size (usually for multi gpu)
                    86. :param num_workers: int - number of workers (parallel processes) for dataloaders
                    87. :param train_batch_size: int - batch size for train loader, if None will be taken from dataset_params
                    88. :param val_batch_size: int - batch size for val loader, if None will be taken from dataset_params
                    89. :param distributed_sampler: boolean flag for distributed training mode
                    90. :return: train_loader, val_loader, classes: list of classes
                    91. """
                    92. # CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUTED TRAINING MODE
                    93. # IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS
                    94. # NO SHUFFLE IN DISTRIBUTED TRAINING
                    95. aug_repeat_count = get_param(self.dataset_params, "aug_repeat_count", 0)
                    96. if aug_repeat_count > 0 and not distributed_sampler:
                    97. raise IllegalDatasetParameterException("repeated augmentation is only supported with DDP.")
                    98. if distributed_sampler:
                    99. self.batch_size_factor = 1
                    100. train_sampler = RepeatAugSampler(self.trainset,
                    101. num_repeats=aug_repeat_count) if aug_repeat_count > 0 else DistributedSampler(
                    102. self.trainset)
                    103. val_sampler = DistributedSampler(self.valset)
                    104. test_sampler = DistributedSampler(self.testset) if self.testset is not None else None
                    105. train_shuffle = False
                    106. else:
                    107. self.batch_size_factor = batch_size_factor
                    108. train_sampler = None
                    109. val_sampler = None
                    110. test_sampler = None
                    111. train_shuffle = True
                    112. if train_batch_size is None:
                    113. train_batch_size = self.dataset_params.batch_size * self.batch_size_factor
                    114. if val_batch_size is None:
                    115. val_batch_size = self.dataset_params.val_batch_size * self.batch_size_factor
                    116. if test_batch_size is None:
                    117. test_batch_size = self.dataset_params.test_batch_size * self.batch_size_factor
                    118. train_loader_drop_last = core_utils.get_param(self.dataset_params, 'train_loader_drop_last', default_val=False)
                    119. cutmix = core_utils.get_param(self.dataset_params, 'cutmix', False)
                    120. cutmix_params = core_utils.get_param(self.dataset_params, 'cutmix_params')
                    121. # WRAPPING collate_fn
                    122. train_collate_fn = core_utils.get_param(self.trainset, 'collate_fn')
                    123. val_collate_fn = core_utils.get_param(self.valset, 'collate_fn')
                    124. test_collate_fn = core_utils.get_param(self.testset, 'collate_fn')
                    125. if cutmix and train_collate_fn is not None:
                    126. raise IllegalDatasetParameterException("cutmix and collate function cannot be used together")
                    127. if cutmix:
                    128. # FIXME - cutmix should be available only in classification dataset. once we make sure all classification
                    129. # datasets inherit from the same super class, we should move cutmix code to that class
                    130. logger.warning("Cutmix/mixup was enabled. This feature is currently supported only "
                    131. "for classification datasets.")
                    132. train_collate_fn = CollateMixup(**cutmix_params)
                    133. # FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED
                    134. # train_sampler = DistributedSampler(self.trainset,
                    135. # num_replicas=distributed_gpus_num) if distributed_sampler else None
                    136. # val_sampler = DistributedSampler(self.valset,
                    137. # num_replicas=distributed_gpus_num) if distributed_sampler else None
                    138. self.train_loader = torch.utils.data.DataLoader(self.trainset,
                    139. batch_size=train_batch_size,
                    140. shuffle=train_shuffle,
                    141. num_workers=num_workers,
                    142. pin_memory=True,
                    143. sampler=train_sampler,
                    144. collate_fn=train_collate_fn,
                    145. drop_last=train_loader_drop_last)
                    146. self.val_loader = torch.utils.data.DataLoader(self.valset,
                    147. batch_size=val_batch_size,
                    148. shuffle=False,
                    149. num_workers=num_workers,
                    150. pin_memory=True,
                    151. sampler=val_sampler,
                    152. collate_fn=val_collate_fn)
                    153. if self.testset is not None:
                    154. self.test_loader = torch.utils.data.DataLoader(self.testset,
                    155. batch_size=test_batch_size,
                    156. shuffle=False,
                    157. num_workers=num_workers,
                    158. pin_memory=True,
                    159. sampler=test_sampler,
                    160. collate_fn=test_collate_fn)
                    161. self.classes = self.trainset.classes
                    162. def get_data_loaders(self, **kwargs):
                    163. """
                    164. Get self.train_loader, self.val_loader, self.test_loader, self.classes.
                    165. If the data loaders haven't been initialized yet, build them first.
                    166. :param kwargs: kwargs are passed to build_data_loaders.
                    167. """
                    168. if self.train_loader is None and self.val_loader is None:
                    169. self.build_data_loaders(**kwargs)
                    170. return self.train_loader, self.val_loader, self.test_loader, self.classes
                    171. def get_val_sample(self, num_samples=1):
                    172. if num_samples > len(self.valset):
                    173. raise Exception("Tried to load more samples than val-set size")
                    174. if num_samples == 1:
                    175. return self.valset[0]
                    176. else:
                    177. return self.valset[0:num_samples]
                    178. def get_dataset_params(self):
                    179. return self.dataset_params
                    180. def print_dataset_details(self):
                    181. logger.info("{} training samples, {} val samples, {} classes".format(len(self.trainset), len(self.valset),
                    182. len(self.trainset.classes)))
                    183. class ExternalDatasetInterface(DatasetInterface):
                    184. def __init__(self, train_loader, val_loader, num_classes, dataset_params={}):
                    185. """
                    186. ExternalDatasetInterface - A wrapper for external dataset interface that gets dataloaders from keras/TF
                    187. and converts them to Torch-like dataloaders that return torch.Tensors after
                    188. optional collate_fn while maintaining the same interface (connect_dataset_interface etc.)
                    189. :train_loader: The external train_loader
                    190. :val_loader: The external val_loader
                    191. :num_classes: The number of classes
                    192. :dataset_params The dict that includes the batch_size and/or the collate_fn
                    193. :return: DataLoaders that generate torch.Tensors batches after collate_fn
                    194. """
                    195. super().__init__(dataset_params)
                    196. self.train_loader = train_loader
                    197. self.val_loader = val_loader
                    198. self.classes = num_classes
                    199. def get_data_loaders(self, batch_size_factor: int = 1, num_workers: int = 8, train_batch_size: int = None,
                    200. val_batch_size: int = None, distributed_sampler: bool = False):
                    201. # CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUED TRAINING MODE
                    202. # IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS
                    203. # NO SHUFFLE IN DISTRIBUTED TRAINING
                    204. if distributed_sampler:
                    205. self.batch_size_factor = 1
                    206. train_sampler = DistributedSampler(self.trainset, shuffle=True)
                    207. val_sampler = DistributedSampler(self.valset)
                    208. train_shuffle = False
                    209. else:
                    210. self.batch_size_factor = batch_size_factor
                    211. train_sampler = None
                    212. val_sampler = None
                    213. train_shuffle = True
                    214. if train_batch_size is None:
                    215. train_batch_size = self.dataset_params.batch_size * self.batch_size_factor
                    216. if val_batch_size is None:
                    217. val_batch_size = self.dataset_params.val_batch_size * self.batch_size_factor
                    218. train_loader_drop_last = core_utils.get_param(self.dataset_params, 'train_loader_drop_last', default_val=False)
                    219. # WRAPPING collate_fn
                    220. train_collate_fn = core_utils.get_param(self.dataset_params, 'train_collate_fn')
                    221. val_collate_fn = core_utils.get_param(self.dataset_params, 'val_collate_fn')
                    222. # FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED
                    223. # train_sampler = DistributedSampler(self.trainset,
                    224. # num_replicas=distributed_gpus_num) if distributed_sampler else None
                    225. # val_sampler = DistributedSampler(self.valset,
                    226. # num_replicas=distributed_gpus_num) if distributed_sampler else None
                    227. self.torch_train_loader = torch.utils.data.DataLoader(self.train_loader,
                    228. batch_size=train_batch_size,
                    229. shuffle=train_shuffle,
                    230. num_workers=num_workers,
                    231. pin_memory=True,
                    232. sampler=train_sampler,
                    233. collate_fn=train_collate_fn,
                    234. drop_last=train_loader_drop_last)
                    235. self.torch_val_loader = torch.utils.data.DataLoader(self.val_loader,
                    236. batch_size=val_batch_size,
                    237. shuffle=False,
                    238. num_workers=num_workers,
                    239. pin_memory=True,
                    240. sampler=val_sampler,
                    241. collate_fn=val_collate_fn)
                    242. return self.torch_train_loader, self.torch_val_loader, None, self.classes
                    243. class LibraryDatasetInterface(DatasetInterface):
                    244. def __init__(self, name="cifar10", dataset_params={}, to_cutout=False):
                    245. super(LibraryDatasetInterface, self).__init__(dataset_params)
                    246. self.dataset_name = name
                    247. if self.dataset_name not in LIBRARY_DATASETS.keys():
                    248. raise Exception('dataset not found')
                    249. self.lib_dataset_params = LIBRARY_DATASETS[self.dataset_name]
                    250. if self.lib_dataset_params['mean'] is None:
                    251. trainset = torchvision.datasets.SVHN(root=self.dataset_params.dataset_dir, split='train', download=True,
                    252. transform=transforms.ToTensor())
                    253. self.lib_dataset_params['mean'], self.lib_dataset_params['std'] = datasets_utils.get_mean_and_std(trainset)
                    254. # OVERWRITE MEAN AND STD IF DEFINED IN DATASET PARAMS
                    255. self.lib_dataset_params['mean'] = core_utils.get_param(self.dataset_params, 'img_mean',
                    256. default_val=self.lib_dataset_params['mean'])
                    257. self.lib_dataset_params['std'] = core_utils.get_param(self.dataset_params, 'img_std',
                    258. default_val=self.lib_dataset_params['std'])
                    259. crop_size = core_utils.get_param(self.dataset_params, 'crop_size', default_val=32)
                    260. if to_cutout:
                    261. transform_train = transforms.Compose([
                    262. transforms.RandomCrop(crop_size, padding=4),
                    263. transforms.RandomHorizontalFlip(),
                    264. DataAugmentation.normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
                    265. DataAugmentation.cutout(16),
                    266. DataAugmentation.to_tensor()
                    267. ])
                    268. else:
                    269. transform_train = transforms.Compose([
                    270. transforms.RandomCrop(crop_size, padding=4),
                    271. transforms.RandomHorizontalFlip(),
                    272. transforms.ToTensor(),
                    273. transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
                    274. ])
                    275. transform_val = transforms.Compose([
                    276. transforms.ToTensor(),
                    277. transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
                    278. ])
                    279. dataset_cls = self.lib_dataset_params["class"]
                    280. self.trainset = dataset_cls(root=self.dataset_params.dataset_dir, train=True, download=True,
                    281. transform=transform_train)
                    282. self.valset = dataset_cls(root=self.dataset_params.dataset_dir, train=False, download=True,
                    283. transform=transform_val)
                    284. class Cifar10DatasetInterface(LibraryDatasetInterface):
                    285. def __init__(self, dataset_params={}):
                    286. super(Cifar10DatasetInterface, self).__init__(name="cifar10", dataset_params=dataset_params)
                    287. class Cifar100DatasetInterface(LibraryDatasetInterface):
                    288. def __init__(self, dataset_params={}):
                    289. super(Cifar100DatasetInterface, self).__init__(name="cifar100", dataset_params=dataset_params)
                    290. class TestDatasetInterface(DatasetInterface):
                    291. def __init__(self, trainset, dataset_params={}, classes=None):
                    292. super(TestDatasetInterface, self).__init__(dataset_params)
                    293. self.trainset = trainset
                    294. self.valset = self.trainset
                    295. self.testset = self.trainset
                    296. self.classes = classes
                    297. def get_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None,
                    298. distributed_sampler=False):
                    299. self.trainset.classes = [0, 1, 2, 3, 4] if self.classes is None else self.classes
                    300. return super().get_data_loaders(batch_size_factor=batch_size_factor,
                    301. num_workers=num_workers,
                    302. train_batch_size=train_batch_size,
                    303. val_batch_size=val_batch_size,
                    304. distributed_sampler=distributed_sampler)
                    305. class ClassificationTestDatasetInterface(TestDatasetInterface):
                    306. def __init__(self, dataset_params={}, image_size=32, batch_size=5, classes=None):
                    307. trainset = torch.utils.data.TensorDataset(torch.Tensor(np.zeros((batch_size, 3, image_size, image_size))),
                    308. torch.LongTensor(np.zeros((batch_size))))
                    309. super(ClassificationTestDatasetInterface, self).__init__(trainset=trainset, dataset_params=dataset_params,
                    310. classes=classes)
                    311. class SegmentationTestDatasetInterface(TestDatasetInterface):
                    312. def __init__(self, dataset_params={}, image_size=512, batch_size=4):
                    313. trainset = torch.utils.data.TensorDataset(torch.Tensor(np.zeros((batch_size, 3, image_size, image_size))),
                    314. torch.LongTensor(np.zeros((batch_size, image_size, image_size))))
                    315. super(SegmentationTestDatasetInterface, self).__init__(trainset=trainset, dataset_params=dataset_params)
                    316. class DetectionTestDatasetInterface(TestDatasetInterface):
                    317. def __init__(self, dataset_params={}, image_size=320, batch_size=4, classes=None):
                    318. trainset = torch.utils.data.TensorDataset(torch.Tensor(np.zeros((batch_size, 3, image_size, image_size))),
                    319. torch.Tensor(np.zeros((batch_size, 6))))
                    320. super(DetectionTestDatasetInterface, self).__init__(trainset=trainset, dataset_params=dataset_params,
                    321. classes=classes)
                    322. class TestYoloDetectionDatasetInterface(DatasetInterface):
                    323. """
                    324. note: the output size is (batch_size, 6) in the test while in real training
                    325. the size of axis 0 can vary (the number of bounding boxes)
                    326. """
                    327. def __init__(self, dataset_params={}, input_dims=(3, 32, 32), batch_size=5):
                    328. super().__init__(dataset_params)
                    329. self.trainset = torch.utils.data.TensorDataset(torch.ones((batch_size, *input_dims)),
                    330. torch.ones((batch_size, 6)))
                    331. self.trainset.classes = [0, 1, 2, 3, 4]
                    332. self.valset = self.trainset
                    333. class ImageNetDatasetInterface(DatasetInterface):
                    334. def __init__(self, dataset_params={}, data_dir="/data/Imagenet"):
                    335. super(ImageNetDatasetInterface, self).__init__(dataset_params)
                    336. data_dir = dataset_params['dataset_dir'] if 'dataset_dir' in dataset_params.keys() else data_dir
                    337. traindir = os.path.join(os.path.abspath(data_dir), 'train')
                    338. valdir = os.path.join(data_dir, 'val')
                    339. img_mean = core_utils.get_param(self.dataset_params, 'img_mean', default_val=[0.485, 0.456, 0.406])
                    340. img_std = core_utils.get_param(self.dataset_params, 'img_std', default_val=[0.229, 0.224, 0.225])
                    341. normalize = transforms.Normalize(mean=img_mean, std=img_std)
                    342. crop_size = core_utils.get_param(self.dataset_params, 'crop_size', default_val=224)
                    343. resize_size = core_utils.get_param(self.dataset_params, 'resize_size', default_val=256)
                    344. color_jitter = core_utils.get_param(self.dataset_params, 'color_jitter', default_val=0.0)
                    345. imagenet_pca_aug = core_utils.get_param(self.dataset_params, 'imagenet_pca_aug', default_val=0.0)
                    346. train_interpolation = core_utils.get_param(self.dataset_params, 'train_interpolation', default_val='default')
                    347. rand_augment_config_string = core_utils.get_param(self.dataset_params, 'rand_augment_config_string',
                    348. default_val=None)
                    349. color_jitter = (float(color_jitter),) * 3 if isinstance(color_jitter, float) else color_jitter
                    350. assert len(color_jitter) in (3, 4), "color_jitter must be a scalar or tuple of len 3 or 4"
                    351. color_augmentation = datasets_utils.get_color_augmentation(rand_augment_config_string, color_jitter,
                    352. crop_size=crop_size, img_mean=img_mean)
                    353. train_transformation_list = [
                    354. RandomResizedCropAndInterpolation(crop_size, interpolation=train_interpolation),
                    355. transforms.RandomHorizontalFlip(),
                    356. color_augmentation,
                    357. transforms.ToTensor(),
                    358. Lighting(imagenet_pca_aug),
                    359. normalize]
                    360. rndm_erase_prob = core_utils.get_param(self.dataset_params, 'random_erase_prob', default_val=0.)
                    361. if rndm_erase_prob:
                    362. train_transformation_list.append(RandomErase(rndm_erase_prob, self.dataset_params.random_erase_value))
                    363. self.trainset = datasets.ImageFolder(traindir, transforms.Compose(train_transformation_list))
                    364. self.valset = datasets.ImageFolder(valdir, transforms.Compose([
                    365. transforms.Resize(resize_size),
                    366. transforms.CenterCrop(crop_size),
                    367. transforms.ToTensor(),
                    368. normalize,
                    369. ]))
                    370. class TinyImageNetDatasetInterface(DatasetInterface):
                    371. def __init__(self, dataset_params={}, data_dir="/data/TinyImagenet"):
                    372. super(TinyImageNetDatasetInterface, self).__init__(dataset_params)
                    373. data_dir = dataset_params['dataset_dir'] if 'dataset_dir' in dataset_params.keys() else data_dir
                    374. traindir = os.path.join(os.path.abspath(data_dir), 'train')
                    375. valdir = os.path.join(data_dir, 'val')
                    376. img_mean = core_utils.get_param(self.dataset_params, 'img_mean', default_val=[0.4802, 0.4481, 0.3975])
                    377. img_std = core_utils.get_param(self.dataset_params, 'img_std', default_val=[0.2770, 0.2691, 0.2821])
                    378. normalize = transforms.Normalize(mean=img_mean,
                    379. std=img_std)
                    380. crop_size = core_utils.get_param(self.dataset_params, 'crop_size', default_val=56)
                    381. resize_size = core_utils.get_param(self.dataset_params, 'resize_size', default_val=64)
                    382. self.trainset = datasets.ImageFolder(
                    383. traindir,
                    384. transforms.Compose([
                    385. transforms.RandomResizedCrop(crop_size),
                    386. transforms.RandomHorizontalFlip(),
                    387. transforms.ToTensor(),
                    388. normalize,
                    389. ]))
                    390. self.valset = datasets.ImageFolder(valdir, transforms.Compose([
                    391. transforms.Resize(resize_size),
                    392. transforms.CenterCrop(crop_size),
                    393. transforms.ToTensor(),
                    394. normalize,
                    395. ]))
                    396. class ClassificationDatasetInterface(DatasetInterface):
                    397. def __init__(self, normalization_mean=(0, 0, 0), normalization_std=(1, 1, 1), resolution=64,
                    398. dataset_params={}):
                    399. super(ClassificationDatasetInterface, self).__init__(dataset_params)
                    400. data_dir = self.dataset_params.dataset_dir
                    401. traindir = os.path.join(os.path.abspath(data_dir), 'train')
                    402. valdir = os.path.join(data_dir, 'val')
                    403. normalize = transforms.Normalize(mean=normalization_mean,
                    404. std=normalization_std)
                    405. self.trainset = datasets.ImageFolder(
                    406. traindir,
                    407. transforms.Compose([
                    408. transforms.RandomResizedCrop(resolution),
                    409. transforms.RandomHorizontalFlip(),
                    410. transforms.ToTensor(),
                    411. normalize,
                    412. ]))
                    413. self.valset = datasets.ImageFolder(valdir, transforms.Compose([
                    414. transforms.Resize(int(resolution * 1.15)),
                    415. transforms.CenterCrop(resolution),
                    416. transforms.ToTensor(),
                    417. normalize,
                    418. ]))
                    419. self.data_dir = data_dir
                    420. self.normalization_mean = normalization_mean
                    421. self.normalization_std = normalization_std
                    422. class PascalVOC2012SegmentationDataSetInterface(DatasetInterface):
                    423. def __init__(self, dataset_params=None, cache_labels=False, cache_images=False):
                    424. if dataset_params is None:
                    425. dataset_params = dict()
                    426. super().__init__(dataset_params=dataset_params)
                    427. self.root_dir = dataset_params['dataset_dir'] if 'dataset_dir' in dataset_params.keys() \
                    428. else '/data/pascal_voc_2012/VOCdevkit/VOC2012/'
                    429. self.trainset = PascalVOC2012SegmentationDataSet(root=self.root_dir,
                    430. list_file='ImageSets/Segmentation/train.txt',
                    431. samples_sub_directory='JPEGImages',
                    432. targets_sub_directory='SegmentationClass', augment=True,
                    433. dataset_hyper_params=dataset_params, cache_labels=cache_labels,
                    434. cache_images=cache_images)
                    435. self.valset = PascalVOC2012SegmentationDataSet(root=self.root_dir,
                    436. list_file='ImageSets/Segmentation/val.txt',
                    437. samples_sub_directory='JPEGImages',
                    438. targets_sub_directory='SegmentationClass', augment=True,
                    439. dataset_hyper_params=dataset_params, cache_labels=cache_labels,
                    440. cache_images=cache_images)
                    441. self.classes = self.trainset.classes
                    442. class PascalAUG2012SegmentationDataSetInterface(DatasetInterface):
                    443. def __init__(self, dataset_params=None, cache_labels=False, cache_images=False):
                    444. if dataset_params is None:
                    445. dataset_params = dict()
                    446. super().__init__(dataset_params=dataset_params)
                    447. self.root_dir = dataset_params['dataset_dir'] if 'dataset_dir' in dataset_params.keys() \
                    448. else '/data/pascal_voc_2012/VOCaug/dataset/'
                    449. self.trainset = PascalAUG2012SegmentationDataSet(
                    450. root=self.root_dir,
                    451. list_file='trainval.txt',
                    452. samples_sub_directory='img',
                    453. targets_sub_directory='cls', augment=True,
                    454. dataset_hyper_params=dataset_params, cache_labels=cache_labels,
                    455. cache_images=cache_images)
                    456. self.valset = PascalAUG2012SegmentationDataSet(
                    457. root=self.root_dir,
                    458. list_file='val.txt',
                    459. samples_sub_directory='img',
                    460. targets_sub_directory='cls', augment=False,
                    461. dataset_hyper_params=dataset_params, cache_labels=cache_labels,
                    462. cache_images=cache_images)
                    463. self.classes = self.trainset.classes
                    464. class CoCoDataSetInterfaceBase(DatasetInterface):
                    465. def __init__(self, dataset_params=None):
                    466. if dataset_params is None:
                    467. dataset_params = dict()
                    468. super().__init__(dataset_params=dataset_params)
                    469. self.root_dir = dataset_params['dataset_dir'] if 'dataset_dir' in dataset_params.keys() else '/data/coco/'
                    470. class CoCoSegmentationDatasetInterface(CoCoDataSetInterfaceBase):
                    471. def __init__(self, dataset_params=None, cache_labels: bool = False, cache_images: bool = False,
                    472. dataset_classes_inclusion_tuples_list: list = None):
                    473. super().__init__(dataset_params=dataset_params)
                    474. # backwards compatability patch for legacy dataset params
                    475. img_size = core_utils.get_param(dataset_params, "img_size")
                    476. crop_size = core_utils.get_param(dataset_params, "crop_size")
                    477. train_transforms = [RandomFlip(),
                    478. Rescale(long_size=img_size),
                    479. RandomRescale(scales=(0.5, 2.0)),
                    480. PadShortToCropSize(crop_size=crop_size),
                    481. CropImageAndMask(crop_size=crop_size, mode="random")]
                    482. val_transforms = [Rescale(short_size=crop_size),
                    483. CropImageAndMask(crop_size=crop_size, mode="center")]
                    484. self.trainset = CoCoSegmentationDataSet(
                    485. root_dir=self.root_dir,
                    486. list_file='instances_train2017.json',
                    487. samples_sub_directory='images/train2017',
                    488. targets_sub_directory='annotations',
                    489. cache_labels=cache_labels,
                    490. cache_images=cache_images,
                    491. transforms=train_transforms,
                    492. dataset_classes_inclusion_tuples_list=dataset_classes_inclusion_tuples_list)
                    493. self.valset = CoCoSegmentationDataSet(
                    494. root_dir=self.root_dir,
                    495. list_file='instances_val2017.json',
                    496. samples_sub_directory='images/val2017',
                    497. targets_sub_directory='annotations',
                    498. cache_labels=cache_labels,
                    499. cache_images=cache_images,
                    500. transforms=val_transforms,
                    501. dataset_classes_inclusion_tuples_list=dataset_classes_inclusion_tuples_list)
                    502. self.coco_classes = self.trainset.classes
                    503. class CityscapesDatasetInterface(DatasetInterface):
                    504. def __init__(self, dataset_params=None, cache_labels: bool = False, cache_images: bool = False):
                    505. super().__init__(dataset_params=dataset_params)
                    506. root_dir = core_utils.get_param(dataset_params, "dataset_dir", "/data/cityscapes")
                    507. image_mask_transforms = core_utils.get_param(dataset_params, "image_mask_transforms")
                    508. image_mask_transforms_aug = core_utils.get_param(dataset_params, "image_mask_transforms_aug")
                    509. # Backwards compatability fix for SegmentationDataset refactor
                    510. train_transforms = image_mask_transforms_aug['Compose']['transforms']
                    511. val_transforms = image_mask_transforms['Compose']['transforms']
                    512. self.trainset = CityscapesDataset(
                    513. root_dir=root_dir,
                    514. list_file='lists/train.lst',
                    515. labels_csv_path="lists/labels.csv",
                    516. cache_labels=cache_labels,
                    517. cache_images=cache_images,
                    518. transforms=train_transforms)
                    519. self.valset = CityscapesDataset(
                    520. root_dir=root_dir,
                    521. list_file='lists/val.lst',
                    522. labels_csv_path="lists/labels.csv",
                    523. cache_labels=cache_labels,
                    524. cache_images=cache_images,
                    525. transforms=val_transforms)
                    526. self.classes = self.trainset.classes
                    527. class SuperviselyPersonsDatasetInterface(DatasetInterface):
                    528. def __init__(self, dataset_params=None, cache_labels: bool = False, cache_images: bool = False):
                    529. super().__init__(dataset_params=dataset_params)
                    530. root_dir = get_param(dataset_params, "dataset_dir", "/data/supervisely-persons")
                    531. self.trainset = SuperviselyPersonsDataset(
                    532. root_dir=root_dir,
                    533. list_file='train.csv',
                    534. dataset_hyper_params=dataset_params,
                    535. cache_labels=cache_labels,
                    536. cache_images=cache_images,
                    537. image_mask_transforms_aug=get_param(dataset_params, "image_mask_transforms_aug", transforms.Compose([])),
                    538. augment=True
                    539. )
                    540. self.valset = SuperviselyPersonsDataset(
                    541. root_dir=root_dir,
                    542. list_file='val.csv',
                    543. dataset_hyper_params=dataset_params,
                    544. cache_labels=cache_labels,
                    545. cache_images=cache_images,
                    546. image_mask_transforms=get_param(dataset_params, "image_mask_transforms", transforms.Compose([])),
                    547. augment=False
                    548. )
                    549. self.classes = self.trainset.classes
                    550. class DetectionDatasetInterface(DatasetInterface):
                    551. def build_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None,
                    552. test_batch_size=None, distributed_sampler: bool = False):
                    553. train_sampler = InfiniteSampler(self.trainset, seed=0)
                    554. train_batch_sampler = BatchSampler(
                    555. sampler=train_sampler,
                    556. batch_size=self.dataset_params.batch_size,
                    557. drop_last=False,
                    558. )
                    559. self.train_loader = DataLoader(self.trainset,
                    560. batch_sampler=train_batch_sampler,
                    561. num_workers=num_workers,
                    562. pin_memory=True,
                    563. worker_init_fn=worker_init_reset_seed,
                    564. collate_fn=self.dataset_params.train_collate_fn)
                    565. if distributed_sampler:
                    566. sampler = torch.utils.data.distributed.DistributedSampler(self.valset, shuffle=False)
                    567. else:
                    568. sampler = torch.utils.data.SequentialSampler(self.valset)
                    569. val_loader = torch.utils.data.DataLoader(self.valset,
                    570. num_workers=num_workers,
                    571. pin_memory=True,
                    572. sampler=sampler,
                    573. batch_size=self.dataset_params.val_batch_size,
                    574. collate_fn=self.dataset_params.val_collate_fn)
                    575. self.val_loader = val_loader
                    576. class PascalVOCUnifiedDetectionDatasetInterface(DetectionDatasetInterface):
                    577. def __init__(self, dataset_params=None):
                    578. if dataset_params is None:
                    579. dataset_params = dict()
                    580. super().__init__(dataset_params=dataset_params)
                    581. self.data_dir = self.dataset_params.data_dir
                    582. train_input_dim = (self.dataset_params.train_image_size, self.dataset_params.train_image_size)
                    583. val_input_dim = (self.dataset_params.val_image_size, self.dataset_params.val_image_size)
                    584. train_max_num_samples = get_param(self.dataset_params, "train_max_num_samples")
                    585. val_max_num_samples = get_param(self.dataset_params, "val_max_num_samples")
                    586. if self.dataset_params.download:
                    587. PascalVOCDetectionDataset.download(data_dir=self.data_dir)
                    588. train_dataset_names = ["train2007", "val2007", "train2012", "val2012"]
                    589. # We divide train_max_num_samples between the datasets
                    590. if train_max_num_samples:
                    591. max_num_samples_per_train_dataset = [len(segment) for segment in
                    592. np.array_split(range(train_max_num_samples), len(train_dataset_names))]
                    593. else:
                    594. max_num_samples_per_train_dataset = [None] * len(train_dataset_names)
                    595. train_sets = [PascalVOCDetectionDataset(data_dir=self.data_dir,
                    596. input_dim=train_input_dim,
                    597. cache=self.dataset_params.cache_train_images,
                    598. cache_dir=self.dataset_params.cache_dir,
                    599. transforms=self.dataset_params.train_transforms,
                    600. images_sub_directory='images/' + trainset_name + '/',
                    601. class_inclusion_list=self.dataset_params.class_inclusion_list,
                    602. max_num_samples=max_num_samples_per_train_dataset[i])
                    603. for i, trainset_name in enumerate(train_dataset_names)]
                    604. testset2007 = PascalVOCDetectionDataset(data_dir=self.data_dir,
                    605. input_dim=val_input_dim,
                    606. cache=self.dataset_params.cache_val_images,
                    607. cache_dir=self.dataset_params.cache_dir,
                    608. transforms=self.dataset_params.val_transforms,
                    609. images_sub_directory='images/test2007/',
                    610. class_inclusion_list=self.dataset_params.class_inclusion_list,
                    611. max_num_samples=val_max_num_samples)
                    612. self.classes = train_sets[1].classes
                    613. self.trainset = ConcatDataset(train_sets)
                    614. self.valset = testset2007
                    615. self.trainset.collate_fn = self.dataset_params.train_collate_fn
                    616. self.trainset.classes = self.classes
                    617. self.trainset.img_size = self.dataset_params.train_image_size
                    618. self.trainset.cache_labels = self.dataset_params.cache_train_images
                    619. class CoCoDetectionDatasetInterface(DetectionDatasetInterface):
                    620. def __init__(self, dataset_params={}):
                    621. super(CoCoDetectionDatasetInterface, self).__init__(dataset_params=dataset_params)
                    622. # IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.
                    623. local_rank = get_local_rank()
                    624. with wait_for_the_master(local_rank):
                    625. self.trainset = COCODetectionDataset(data_dir=self.dataset_params.data_dir,
                    626. subdir=self.dataset_params.train_subdir,
                    627. json_file=self.dataset_params.train_json_file,
                    628. input_dim=self.dataset_params.train_input_dim,
                    629. cache=self.dataset_params.cache_train_images,
                    630. cache_dir=self.dataset_params.cache_dir,
                    631. transforms=self.dataset_params.train_transforms,
                    632. tight_box_rotation=self.dataset_params.tight_box_rotation,
                    633. class_inclusion_list=self.dataset_params.class_inclusion_list,
                    634. max_num_samples=self.dataset_params.train_max_num_samples,
                    635. with_crowd=False)
                    636. # IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.
                    637. with wait_for_the_master(local_rank):
                    638. self.valset = COCODetectionDataset(
                    639. data_dir=self.dataset_params.data_dir,
                    640. json_file=self.dataset_params.val_json_file,
                    641. subdir=self.dataset_params.val_subdir,
                    642. cache_dir=self.dataset_params.cache_dir,
                    643. cache=self.dataset_params.cache_val_images,
                    644. input_dim=self.dataset_params.val_input_dim,
                    645. transforms=self.dataset_params.val_transforms,
                    646. class_inclusion_list=self.dataset_params.class_inclusion_list,
                    647. max_num_samples=self.dataset_params.val_max_num_samples,
                    648. with_crowd=self.dataset_params.with_crowd)
                    649. self.classes = COCO_DETECTION_CLASSES_LIST
                    Discard
                    @@ -4,6 +4,7 @@ from omegaconf import DictConfig
                     from torch.utils.data import DataLoader
                     from torch.utils.data import DataLoader
                     
                     
                     from super_gradients.common import MultiGPUMode
                     from super_gradients.common import MultiGPUMode
                    +from super_gradients.training.dataloaders import dataloaders
                     from super_gradients.training.models import SgModule
                     from super_gradients.training.models import SgModule
                     from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
                     from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
                     from super_gradients.training.models.kd_modules.kd_module import KDModule
                     from super_gradients.training.models.kd_modules.kd_module import KDModule
                    @@ -57,8 +58,14 @@ class KDTrainer(Trainer):
                     
                     
                             trainer = KDTrainer(**kwargs)
                             trainer = KDTrainer(**kwargs)
                     
                     
                    -        # CONNECT THE DATASET INTERFACE WITH DECI MODEL
                    -        trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
                    +        # INSTANTIATE DATA LOADERS
                    +        train_dataloader = dataloaders.get(name=cfg.train_dataloader,
                    +                                           dataset_params=cfg.dataset_params.train_dataset_params,
                    +                                           dataloader_params=cfg.dataset_params.train_dataloader_params)
                    +
                    +        val_dataloader = dataloaders.get(name=cfg.val_dataloader,
                    +                                         dataset_params=cfg.dataset_params.val_dataset_params,
                    +                                         dataloader_params=cfg.dataset_params.val_dataloader_params)
                     
                     
                             student = models.get(cfg.student_architecture, arch_params=cfg.student_arch_params,
                             student = models.get(cfg.student_architecture, arch_params=cfg.student_arch_params,
                                                  strict_load=cfg.student_checkpoint_params.strict_load,
                                                  strict_load=cfg.student_checkpoint_params.strict_load,
                    @@ -75,7 +82,8 @@ class KDTrainer(Trainer):
                             # TRAIN
                             # TRAIN
                             trainer.train(training_params=cfg.training_hyperparams, student=student, teacher=teacher,
                             trainer.train(training_params=cfg.training_hyperparams, student=student, teacher=teacher,
                                           kd_architecture=cfg.architecture, kd_arch_params=cfg.arch_params,
                                           kd_architecture=cfg.architecture, kd_arch_params=cfg.arch_params,
                    -                      run_teacher_on_eval=cfg.run_teacher_on_eval)
                    +                      run_teacher_on_eval=cfg.run_teacher_on_eval,
                    +                      train_loader=train_dataloader, valid_loader=val_dataloader)
                     
                     
                         def build_model(self,
                         def build_model(self,
                                         # noqa: C901 - too complex
                                         # noqa: C901 - too complex
                    @@ -303,7 +311,8 @@ class KDTrainer(Trainer):
                                                        })
                                                        })
                             return hyper_param_config
                             return hyper_param_config
                     
                     
                    -    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
                    +    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15,
                    +                               exp_activation: bool = True) -> KDModelEMA:
                             """Instantiate KD ema model for KDModule.
                             """Instantiate KD ema model for KDModule.
                     
                     
                             If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
                             If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
                    @@ -330,7 +339,8 @@ class KDTrainer(Trainer):
                     
                     
                         def train(self, model: KDModule = None, training_params: dict = dict(), student: SgModule = None,
                         def train(self, model: KDModule = None, training_params: dict = dict(), student: SgModule = None,
                                   teacher: torch.nn.Module = None, kd_architecture: Union[KDModule.__class__, str] = 'kd_module',
                                   teacher: torch.nn.Module = None, kd_architecture: Union[KDModule.__class__, str] = 'kd_module',
                    -              kd_arch_params: dict = dict(), run_teacher_on_eval=False, *args, **kwargs):
                    +              kd_arch_params: dict = dict(), run_teacher_on_eval=False, train_loader: DataLoader = None,
                    +              valid_loader: DataLoader = None, *args, **kwargs):
                             """
                             """
                             Trains the student network (wrapped in KDModule network).
                             Trains the student network (wrapped in KDModule network).
                     
                     
                    @@ -342,6 +352,8 @@ class KDTrainer(Trainer):
                             :param kd_architecture: KDModule architecture to use, currently only 'kd_module' is supported (default='kd_module').
                             :param kd_architecture: KDModule architecture to use, currently only 'kd_module' is supported (default='kd_module').
                             :param kd_arch_params: architecture params to pas to kd_architecture constructor.
                             :param kd_arch_params: architecture params to pas to kd_architecture constructor.
                             :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
                             :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
                    +        :param train_loader: Dataloader for train set.
                    +        :param valid_loader: Dataloader for validation.
                             """
                             """
                             kd_net = self.net or model
                             kd_net = self.net or model
                             if kd_net is None:
                             if kd_net is None:
                    @@ -352,4 +364,5 @@ class KDTrainer(Trainer):
                                                                   run_teacher_on_eval=run_teacher_on_eval,
                                                                   run_teacher_on_eval=run_teacher_on_eval,
                                                                   student=student,
                                                                   student=student,
                                                                   teacher=teacher)
                                                                   teacher=teacher)
                    -        super(KDTrainer, self).train(model=kd_net, training_params=training_params)
                    +        super(KDTrainer, self).train(model=kd_net, training_params=training_params,
                    +                                     train_loader=train_loader, valid_loader=valid_loader)
                    Discard
                    @@ -1,5 +1,5 @@
                     from super_gradients.training.utils import HpmStruct
                     from super_gradients.training.utils import HpmStruct
                    -
                    +from copy import deepcopy
                     DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                     DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                                                "lr_cooldown_epochs": 0,
                                                "lr_cooldown_epochs": 0,
                                                "warmup_initial_lr": None,
                                                "warmup_initial_lr": None,
                    @@ -99,7 +99,8 @@ class TrainingParams(HpmStruct):
                     
                     
                         def __init__(self, **entries):
                         def __init__(self, **entries):
                             # WE initialize by the default training params, overridden by the provided params
                             # WE initialize by the default training params, overridden by the provided params
                    -        super().__init__(**DEFAULT_TRAINING_PARAMS)
                    +        default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
                    +        super().__init__(**default_training_params)
                             self.set_schema(TRAINING_PARAM_SCHEMA)
                             self.set_schema(TRAINING_PARAM_SCHEMA)
                             if len(entries) > 0:
                             if len(entries) > 0:
                                 self.override(**entries)
                                 self.override(**entries)
                    Discard
                    @@ -23,14 +23,13 @@ from super_gradients.training.models.all_architectures import ARCHITECTURES
                     from super_gradients.common.decorators.factory_decorator import resolve_param
                     from super_gradients.common.decorators.factory_decorator import resolve_param
                     from super_gradients.common.environment import env_helpers
                     from super_gradients.common.environment import env_helpers
                     from super_gradients.common.abstractions.abstract_logger import get_logger
                     from super_gradients.common.abstractions.abstract_logger import get_logger
                    -from super_gradients.common.factories.datasets_factory import DatasetsFactory
                     from super_gradients.common.factories.list_factory import ListFactory
                     from super_gradients.common.factories.list_factory import ListFactory
                     from super_gradients.common.factories.losses_factory import LossesFactory
                     from super_gradients.common.factories.losses_factory import LossesFactory
                     from super_gradients.common.factories.metrics_factory import MetricsFactory
                     from super_gradients.common.factories.metrics_factory import MetricsFactory
                     from super_gradients.common.sg_loggers import SG_LOGGERS
                     from super_gradients.common.sg_loggers import SG_LOGGERS
                     from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
                     from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
                     from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
                     from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
                    -from super_gradients.training import utils as core_utils, models
                    +from super_gradients.training import utils as core_utils, models, dataloaders
                     from super_gradients.training.models import SgModule
                     from super_gradients.training.models import SgModule
                     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
                     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
                     from super_gradients.training.utils import sg_trainer_utils
                     from super_gradients.training.utils import sg_trainer_utils
                    @@ -38,7 +37,6 @@ from super_gradients.training.utils.quantization_utils import QATCallback
                     from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
                     from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
                     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
                     from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
                         IllegalDataloaderInitialization
                         IllegalDataloaderInitialization
                    -from super_gradients.training.datasets import DatasetInterface
                     from super_gradients.training.losses import LOSSES
                     from super_gradients.training.losses import LOSSES
                     from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
                     from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
                         get_logging_values, \
                         get_logging_values, \
                    @@ -203,8 +201,14 @@ class Trainer:
                     
                     
                             trainer = Trainer(**kwargs)
                             trainer = Trainer(**kwargs)
                     
                     
                    -        # CONNECT THE DATASET INTERFACE WITH DECI MODEL
                    -        trainer.connect_dataset_interface(cfg.dataset_interface, data_loader_num_workers=cfg.data_loader_num_workers)
                    +        # INSTANTIATE DATA LOADERS
                    +        train_dataloader = dataloaders.get(name=cfg.train_dataloader,
                    +                                           dataset_params=cfg.dataset_params.train_dataset_params,
                    +                                           dataloader_params=cfg.dataset_params.train_dataloader_params)
                    +
                    +        val_dataloader = dataloaders.get(name=cfg.val_dataloader,
                    +                                         dataset_params=cfg.dataset_params.val_dataset_params,
                    +                                         dataloader_params=cfg.dataset_params.val_dataloader_params)
                     
                     
                             # BUILD NETWORK
                             # BUILD NETWORK
                             model = models.get(name=cfg.architecture,
                             model = models.get(name=cfg.architecture,
                    @@ -217,7 +221,10 @@ class Trainer:
                                                )
                                                )
                     
                     
                             # TRAIN
                             # TRAIN
                    -        trainer.train(model=model, training_params=cfg.training_hyperparams)
                    +        trainer.train(model=model,
                    +                      train_loader=train_dataloader,
                    +                      valid_loader=val_dataloader,
                    +                      training_params=cfg.training_hyperparams)
                     
                     
                         def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
                         def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
                             if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
                             if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
                    @@ -239,23 +246,6 @@ class Trainer:
                             self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
                             self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
                                 HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
                                 HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
                     
                     
                    -    @resolve_param('dataset_interface', DatasetsFactory())
                    -    def connect_dataset_interface(self, dataset_interface: DatasetInterface, data_loader_num_workers: int = 8):
                    -        """
                    -        :param dataset_interface: DatasetInterface object
                    -        :param data_loader_num_workers: The number of threads to initialize the Data Loaders with
                    -            The dataset to be connected
                    -        """
                    -        if self.train_loader:
                    -            logger.warning("Overriding the dataloaders that Trainer was initialized with")
                    -        self.dataset_interface = dataset_interface
                    -        self.train_loader, self.valid_loader, self.test_loader, self.classes = \
                    -            self.dataset_interface.get_data_loaders(batch_size_factor=self.num_devices,
                    -                                                    num_workers=data_loader_num_workers,
                    -                                                    distributed_sampler=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
                    -
                    -        self.dataset_params = self.dataset_interface.get_dataset_params()
                    -
                         # FIXME - we need to resolve flake8's 'function is too complex' for this function
                         # FIXME - we need to resolve flake8's 'function is too complex' for this function
                         @deprecated(target=None, deprecated_in='2.3.0', remove_in='3.0.0')
                         @deprecated(target=None, deprecated_in='2.3.0', remove_in='3.0.0')
                         def build_model(self,  # noqa: C901 - too complex
                         def build_model(self,  # noqa: C901 - too complex
                    @@ -533,11 +523,7 @@ class Trainer:
                     
                     
                         def _prep_net_for_train(self):
                         def _prep_net_for_train(self):
                             if self.arch_params is None:
                             if self.arch_params is None:
                    -            default_arch_params = HpmStruct(sync_bn=False)
                    -            arch_params = getattr(self.net, "arch_params", default_arch_params)
                    -            self.arch_params = default_arch_params
                    -            if arch_params is not None:
                    -                self.arch_params.override(**arch_params.to_dict())
                    +            self._init_arch_params()
                     
                     
                             # TODO: REMOVE THE BELOW LINE (FOR BACKWARD COMPATIBILITY)
                             # TODO: REMOVE THE BELOW LINE (FOR BACKWARD COMPATIBILITY)
                             if self.checkpoint_params is None:
                             if self.checkpoint_params is None:
                    @@ -555,8 +541,16 @@ class Trainer:
                             self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
                             self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
                             self._load_checkpoint_to_model()
                             self._load_checkpoint_to_model()
                     
                     
                    +    def _init_arch_params(self):
                    +        default_arch_params = HpmStruct(sync_bn=False)
                    +        arch_params = getattr(self.net, "arch_params", default_arch_params)
                    +        self.arch_params = default_arch_params
                    +        if arch_params is not None:
                    +            self.arch_params.override(**arch_params.to_dict())
                    +
                         # FIXME - we need to resolve flake8's 'function is too complex' for this function
                         # FIXME - we need to resolve flake8's 'function is too complex' for this function
                    -    def train(self, model: nn.Module = None, training_params: dict = dict(), train_loader: DataLoader = None, valid_loader: DataLoader = None):  # noqa: C901
                    +    def train(self, model: nn.Module = None, training_params: dict = None, train_loader: DataLoader = None,
                    +              valid_loader: DataLoader = None):  # noqa: C901
                             """
                             """
                     
                     
                             train - Trains the Model
                             train - Trains the Model
                    @@ -835,6 +829,8 @@ class Trainer:
                             :return:
                             :return:
                             """
                             """
                             global logger
                             global logger
                    +        if training_params is None:
                    +            training_params = dict()
                     
                     
                             self.train_loader = train_loader or self.train_loader
                             self.train_loader = train_loader or self.train_loader
                             self.valid_loader = valid_loader or self.valid_loader
                             self.valid_loader = valid_loader or self.valid_loader
                    @@ -1427,13 +1423,16 @@ class Trainer:
                                                  "calling test or through training_params when calling train(...)")
                                                  "calling test or through training_params when calling train(...)")
                             if self.test_loader is None:
                             if self.test_loader is None:
                                 raise ValueError("Test dataloader is required to perform test. Make sure to either pass it through "
                                 raise ValueError("Test dataloader is required to perform test. Make sure to either pass it through "
                    -                             "test_loader arg or calling connect_dataset_interface upon a DatasetInterface instance "
                    -                             "with a non empty testset attribute.")
                    +                             "test_loader arg.")
                     
                     
                             # RESET METRIC RUNNERS
                             # RESET METRIC RUNNERS
                             self._reset_metrics()
                             self._reset_metrics()
                             self.test_metrics.to(self.device)
                             self.test_metrics.to(self.device)
                     
                     
                    +        if self.arch_params is None:
                    +            self._init_arch_params()
                    +        self._net_to_device()
                    +
                         def _add_metrics_update_callback(self, phase: Phase):
                         def _add_metrics_update_callback(self, phase: Phase):
                             """
                             """
                             Adds MetricsUpdateCallback to be fired at phase
                             Adds MetricsUpdateCallback to be fired at phase
                    Discard
                    @@ -5,8 +5,7 @@ from super_gradients.training import models
                     import super_gradients
                     import super_gradients
                     
                     
                     from super_gradients import Trainer
                     from super_gradients import Trainer
                    -from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface
                    -from super_gradients.training.dataloaders.dataloader_factory import (
                    +from super_gradients.training.dataloaders.dataloaders import (
                         cifar10_train,
                         cifar10_train,
                         cifar10_val,
                         cifar10_val,
                         cifar100_train,
                         cifar100_train,
                    @@ -15,24 +14,6 @@ from super_gradients.training.dataloaders.dataloader_factory import (
                     
                     
                     
                     
                     class TestCifarTrainer(unittest.TestCase):
                     class TestCifarTrainer(unittest.TestCase):
                    -    def test_train_cifar10(self):
                    -        super_gradients.init_trainer()
                    -        trainer = Trainer("test", model_checkpoints_location="local")
                    -        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
                    -        trainer.connect_dataset_interface(cifar_10_dataset_interface)
                    -        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
                    -        trainer.train(
                    -            model=model,
                    -            training_params={
                    -                "max_epochs": 1,
                    -                "initial_lr": 0.1,
                    -                "loss": "cross_entropy",
                    -                "train_metrics_list": ["Accuracy"],
                    -                "valid_metrics_list": ["Accuracy"],
                    -                "metric_to_watch": "Accuracy",
                    -            },
                    -        )
                    -
                         def test_train_cifar10_dataloader(self):
                         def test_train_cifar10_dataloader(self):
                             super_gradients.init_trainer()
                             super_gradients.init_trainer()
                             trainer = Trainer("test", model_checkpoints_location="local")
                             trainer = Trainer("test", model_checkpoints_location="local")
                    @@ -52,24 +33,6 @@ class TestCifarTrainer(unittest.TestCase):
                                 valid_loader=cifar10_val_dl,
                                 valid_loader=cifar10_val_dl,
                             )
                             )
                     
                     
                    -    def test_train_cifar100(self):
                    -        super_gradients.init_trainer()
                    -        trainer = Trainer("test", model_checkpoints_location="local")
                    -        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar100")
                    -        trainer.connect_dataset_interface(cifar_10_dataset_interface)
                    -        model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
                    -        trainer.train(
                    -            model=model,
                    -            training_params={
                    -                "max_epochs": 1,
                    -                "initial_lr": 0.1,
                    -                "loss": "cross_entropy",
                    -                "train_metrics_list": ["Accuracy"],
                    -                "valid_metrics_list": ["Accuracy"],
                    -                "metric_to_watch": "Accuracy",
                    -            },
                    -        )
                    -
                         def test_train_cifar100_dataloader(self):
                         def test_train_cifar100_dataloader(self):
                             super_gradients.init_trainer()
                             super_gradients.init_trainer()
                             trainer = Trainer("test", model_checkpoints_location="local")
                             trainer = Trainer("test", model_checkpoints_location="local")
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    78
                    79
                    80
                    81
                    82
                    83
                    84
                    85
                    86
                    87
                    88
                    89
                    90
                    91
                    92
                    93
                    94
                    95
                    96
                    97
                    98
                    99
                    100
                    101
                    102
                    103
                    104
                    105
                    106
                    107
                    108
                    109
                    110
                    111
                    112
                    113
                    114
                    115
                    116
                    117
                    118
                    119
                    120
                    121
                    122
                    123
                    124
                    125
                    126
                    127
                    128
                    129
                    130
                    131
                    132
                    133
                    134
                    135
                    1. import super_gradients
                    2. import torch
                    3. import unittest
                    4. import numpy as np
                    5. from PIL import Image
                    6. import tensorflow.keras as keras
                    7. from super_gradients.training import MultiGPUMode, models
                    8. from super_gradients.training import Trainer
                    9. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface, \
                    10. ImageNetDatasetInterface
                    11. from super_gradients.training.metrics import Accuracy, Top5
                    12. class DataGenerator(keras.utils.Sequence):
                    13. def __init__(self, samples, batch_size=1, dims=(320, 320), n_channels=3,
                    14. n_classes=1000, shuffle=True):
                    15. self.dims = dims
                    16. self.batch_size = batch_size
                    17. self.samples = samples
                    18. self.n_channels = n_channels
                    19. self.n_classes = n_classes
                    20. self.shuffle = shuffle
                    21. self.on_epoch_end()
                    22. def __len__(self):
                    23. # Fraction of dataset to be used - for faster testing
                    24. fraction_of_dataset = 0.01
                    25. return int(np.floor(len(self.samples) / self.batch_size) * fraction_of_dataset)
                    26. def __getitem__(self, index):
                    27. indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
                    28. list_IDs_temp = [self.samples[k] for k in indices]
                    29. X, y = self.__data_generation(list_IDs_temp)
                    30. return X, y
                    31. def on_epoch_end(self):
                    32. self.indices = np.arange(len(self.samples))
                    33. if self.shuffle:
                    34. np.random.shuffle(self.indices)
                    35. def __data_generation(self, list_IDs_temp):
                    36. X = np.empty((self.batch_size, *self.dims, self.n_channels), dtype=np.float32)
                    37. y = np.empty((self.batch_size), dtype=int)
                    38. for i, ID in enumerate(list_IDs_temp):
                    39. image = Image.open(ID[0])
                    40. image = image.resize((self.dims))
                    41. rgb_image = Image.new("RGB", image.size)
                    42. rgb_image.paste(image)
                    43. X[i, ] = np.array(rgb_image)
                    44. y[i] = ID[1]
                    45. return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
                    46. def create_imagenet_dataset():
                    47. dataset_params = {"batch_size": 1}
                    48. dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params=dataset_params)
                    49. return dataset
                    50. class TransposeCollateFn(object):
                    51. def __init__(self, new_shape):
                    52. self.new_shape = new_shape
                    53. def __call__(self, batch):
                    54. new_inputs = []
                    55. new_targets = []
                    56. for img in batch:
                    57. squeezed_input = img[0].squeeze(axis=0)
                    58. transposed_data = np.transpose(squeezed_input, self.new_shape)
                    59. new_inputs.append(torch.from_numpy(transposed_data))
                    60. argmax_target = np.argmax(img[1], 1)
                    61. new_targets.append(torch.from_numpy(argmax_target))
                    62. return torch.stack(new_inputs, 0), torch.cat(new_targets, 0)
                    63. class TestExternalDatasetInterface(unittest.TestCase):
                    64. def setUp(self):
                    65. super_gradients.init_trainer()
                    66. dataset = create_imagenet_dataset()
                    67. data_samples_train = dataset.trainset.samples
                    68. data_samples_val = dataset.valset.samples
                    69. # batch size: 1 is only for the creation of the external keras loader
                    70. self.keras_params = {'dims': (256, 256),
                    71. 'batch_size': 1,
                    72. 'n_classes': 1000,
                    73. 'n_channels': 3,
                    74. 'shuffle': True}
                    75. training_generator = DataGenerator(data_samples_train, **self.keras_params)
                    76. testing_generator = DataGenerator(data_samples_val, **self.keras_params)
                    77. external_num_classes = 1000
                    78. collate_fn = TransposeCollateFn((2, 0, 1))
                    79. self.external_dataset_params = {'batch_size': 16,
                    80. 'test_batch_size': 16,
                    81. 'train_collate_fn': collate_fn,
                    82. 'val_collate_fn': collate_fn}
                    83. self.test_external_dataset_interface = ExternalDatasetInterface(train_loader=training_generator,
                    84. val_loader=testing_generator,
                    85. num_classes=external_num_classes,
                    86. dataset_params=self.external_dataset_params)
                    87. def test_transpose_collate_fn(self):
                    88. collate_fn = TransposeCollateFn((2, 0, 1))
                    89. dims = self.keras_params['dims']
                    90. n_channels = self.keras_params['n_channels']
                    91. batch_size = self.external_dataset_params['batch_size']
                    92. dummy_batch = []
                    93. dummy_input = np.expand_dims(np.random.rand(dims[0], dims[1], n_channels), axis=0)
                    94. dummy_target = np.expand_dims(np.random.rand(1), axis=0)
                    95. for i in range(batch_size):
                    96. dummy_batch.append((dummy_input, dummy_target))
                    97. collate_fn_output = collate_fn.__call__(dummy_batch)
                    98. dummy_tensor = torch.rand(batch_size, n_channels, dims[0], dims[1])
                    99. self.assertEqual(dummy_tensor.shape, collate_fn_output[0].shape)
                    100. def test_model_train(self):
                    101. train_params = {"max_epochs": 2, "lr_decay_factor": 0.1, "initial_lr": 0.025,
                    102. "loss": "cross_entropy",
                    103. "train_metrics_list": [Accuracy(), Top5()],
                    104. "valid_metrics_list": [Accuracy(), Top5()],
                    105. "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                    106. "greater_metric_to_watch_is_better": True}
                    107. arch_params = {'num_classes': 1000}
                    108. trainer = Trainer("test", model_checkpoints_location='local',
                    109. multi_gpu=MultiGPUMode.OFF)
                    110. trainer.connect_dataset_interface(dataset_interface=self.test_external_dataset_interface,
                    111. data_loader_num_workers=8)
                    112. model = models.get("resnet50", arch_params)
                    113. trainer.train(model=model, training_params=train_params)
                    114. if __name__ == '__main__':
                    115. unittest.main()
                    Discard
                    @@ -6,7 +6,8 @@ from super_gradients.training import models
                     import super_gradients
                     import super_gradients
                     import torch
                     import torch
                     import os
                     import os
                    -from super_gradients import Trainer, ClassificationTestDatasetInterface
                    +from super_gradients import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     
                     
                     
                     
                    @@ -38,32 +39,32 @@ class TestTrainer(unittest.TestCase):
                         @staticmethod
                         @staticmethod
                         def get_classification_trainer(name=''):
                         def get_classification_trainer(name=''):
                             trainer = Trainer(name, model_checkpoints_location='local')
                             trainer = Trainer(name, model_checkpoints_location='local')
                    -        dataset_params = {"batch_size": 4}
                    -        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params, image_size=224)
                    -        trainer.connect_dataset_interface(dataset)
                    -        model = models.get("resnet18", arch_params={"num_classes": 5})
                    +        model = models.get("resnet18", num_classes=5)
                             return trainer, model
                             return trainer, model
                     
                     
                         def test_train(self):
                         def test_train(self):
                             trainer, model = self.get_classification_trainer(self.folder_names[0])
                             trainer, model = self.get_classification_trainer(self.folder_names[0])
                    -        trainer.train(model=model, training_params=self.training_params)
                    +        trainer.train(model=model, training_params=self.training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                         def test_save_load(self):
                         def test_save_load(self):
                             trainer, model = self.get_classification_trainer(self.folder_names[1])
                             trainer, model = self.get_classification_trainer(self.folder_names[1])
                    -        trainer.train(model=model, training_params=self.training_params)
                    -
                    +        trainer.train(model=model, training_params=self.training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             resume_training_params = self.training_params.copy()
                             resume_training_params = self.training_params.copy()
                             resume_training_params["resume"] = True
                             resume_training_params["resume"] = True
                             resume_training_params["max_epochs"] = 2
                             resume_training_params["max_epochs"] = 2
                             trainer, model = self.get_classification_trainer(self.folder_names[1])
                             trainer, model = self.get_classification_trainer(self.folder_names[1])
                    -        trainer.train(model=model, training_params=resume_training_params)
                    +        trainer.train(model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                         def test_checkpoint_content(self):
                         def test_checkpoint_content(self):
                             """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
                             """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
                             trainer, model = self.get_classification_trainer(self.folder_names[5])
                             trainer, model = self.get_classification_trainer(self.folder_names[5])
                             params = self.training_params.copy()
                             params = self.training_params.copy()
                             params["save_ckpt_epoch_list"] = [1]
                             params["save_ckpt_epoch_list"] = [1]
                    -        trainer.train(model=model, training_params=params)
                    +        trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
                             ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
                             ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
                             ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
                             for ckpt_path in ckpt_paths:
                             for ckpt_path in ckpt_paths:
                    Discard
                    @@ -1,7 +1,6 @@
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                     
                     
                    -from tests.integration_tests.s3_dataset_test import TestDataset
                     from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
                     from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
                     from tests.integration_tests.lr_test import LRTest
                     from tests.integration_tests.lr_test import LRTest
                     
                     
                    -_all__ = [TestDataset, EMAIntegrationTest, LRTest]
                    +__all__ = ["EMAIntegrationTest", "LRTest"]
                    Discard
                    @@ -4,11 +4,9 @@ import re
                     
                     
                     from super_gradients.training import models
                     from super_gradients.training import models
                     
                     
                    -from super_gradients import (
                    -    Trainer,
                    -    ClassificationTestDatasetInterface,
                    -    SegmentationTestDatasetInterface,
                    -)
                    +from super_gradients import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, \
                    +    classification_test_dataloader
                     from super_gradients.training.utils.callbacks import ModelConversionCheckCallback
                     from super_gradients.training.utils.callbacks import ModelConversionCheckCallback
                     from super_gradients.training.metrics import Accuracy, Top5, IoU
                     from super_gradients.training.metrics import Accuracy, Top5, IoU
                     from super_gradients.training.losses.stdc_loss import STDCLoss
                     from super_gradients.training.losses.stdc_loss import STDCLoss
                    @@ -16,7 +14,6 @@ from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
                     
                     
                     from deci_lab_client.models import ModelMetadata, HardwareType, FrameworkType
                     from deci_lab_client.models import ModelMetadata, HardwareType, FrameworkType
                     
                     
                    -
                     checkpoint_dir = "/Users/daniel/Documents/LALA"
                     checkpoint_dir = "/Users/daniel/Documents/LALA"
                     
                     
                     
                     
                    @@ -44,6 +41,8 @@ def generate_model_metadata(architecture: str, task: Task):
                     
                     
                     CLASSIFICATION = ["efficientnet_b0", "regnetY200", "regnetY400", "regnetY600", "regnetY800", "mobilenet_v3_large"]
                     CLASSIFICATION = ["efficientnet_b0", "regnetY200", "regnetY400", "regnetY600", "regnetY800", "mobilenet_v3_large"]
                     SEMANTIC_SEGMENTATION = ["ddrnet_23", "stdc1_seg", "stdc2_seg", "regseg48"]
                     SEMANTIC_SEGMENTATION = ["ddrnet_23", "stdc1_seg", "stdc2_seg", "regseg48"]
                    +
                    +
                     # TODO: ADD YOLOX ARCHITECTURES AND TESTS
                     # TODO: ADD YOLOX ARCHITECTURES AND TESTS
                     
                     
                     
                     
                    @@ -70,13 +69,12 @@ class ConversionCallbackTest(unittest.TestCase):
                                     "phase_callbacks": phase_callbacks,
                                     "phase_callbacks": phase_callbacks,
                                 }
                                 }
                     
                     
                    -            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
                    -            dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
                    -
                    -            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
                    +            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
                    +                              ckpt_root_dir=checkpoint_dir)
                                 model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
                                 model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
                                 try:
                                 try:
                    -                trainer.train(model=model, training_params=train_params)
                    +                trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                              valid_loader=classification_test_dataloader())
                                 except Exception as e:
                                 except Exception as e:
                                     self.fail(f"Model training didn't succeed due to {e}")
                                     self.fail(f"Model training didn't succeed due to {e}")
                                 else:
                                 else:
                    @@ -104,10 +102,9 @@ class ConversionCallbackTest(unittest.TestCase):
                     
                     
                             for architecture in SEMANTIC_SEGMENTATION:
                             for architecture in SEMANTIC_SEGMENTATION:
                                 model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
                                 model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
                    -            dataset = SegmentationTestDatasetInterface(dataset_params={"batch_size": 10})
                    -            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
                    -            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
                    -            model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
                    +            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
                    +                              ckpt_root_dir=checkpoint_dir)
                    +            model = models.get(name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
                     
                     
                                 phase_callbacks = [
                                 phase_callbacks = [
                                     ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
                                     ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
                    @@ -131,7 +128,8 @@ class ConversionCallbackTest(unittest.TestCase):
                                 train_params.update(custom_config)
                                 train_params.update(custom_config)
                     
                     
                                 try:
                                 try:
                    -                trainer.train(model=model, training_params=train_params)
                    +                trainer.train(model=model, training_params=train_params, train_loader=segmentation_test_dataloader(image_size=512),
                    +                              valid_loader=segmentation_test_dataloader(image_size=512))
                                 except Exception as e:
                                 except Exception as e:
                                     self.fail(f"Model training didn't succeed for {architecture} due to {e}")
                                     self.fail(f"Model training didn't succeed for {architecture} due to {e}")
                                 else:
                                 else:
                    Discard
                    @@ -1,6 +1,6 @@
                     import unittest
                     import unittest
                    -from super_gradients import Trainer, \
                    -    ClassificationTestDatasetInterface
                    +from super_gradients import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.models import ResNet18
                     from super_gradients.training.models import ResNet18
                     from torch.optim import SGD
                     from torch.optim import SGD
                    @@ -11,8 +11,6 @@ from deci_lab_client.models import Metric, QuantizationLevel, ModelMetadata, Opt
                     class DeciLabUploadTest(unittest.TestCase):
                     class DeciLabUploadTest(unittest.TestCase):
                         def setUp(self) -> None:
                         def setUp(self) -> None:
                             self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
                             self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
                    -        dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
                    -        self.trainer.connect_dataset_interface(dataset)
                     
                     
                         def test_train_with_deci_lab_integration(self):
                         def test_train_with_deci_lab_integration(self):
                             model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
                             model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
                    @@ -49,7 +47,8 @@ class DeciLabUploadTest(unittest.TestCase):
                                             "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
                                             "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
                             self.optimizer = SGD(params=net.parameters(), lr=0.1)
                             self.optimizer = SGD(params=net.parameters(), lr=0.1)
                     
                     
                    -        self.trainer.train(model=net, training_params=train_params)
                    +        self.trainer.train(model=net, training_params=train_params,
                    +                           train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
                     
                     
                             # CLEANUP
                             # CLEANUP
                     
                     
                    Discard
                    @@ -1,6 +1,6 @@
                    -from super_gradients import ClassificationTestDatasetInterface
                     from super_gradients.training import MultiGPUMode, models
                     from super_gradients.training import MultiGPUMode, models
                     from super_gradients.training import Trainer
                     from super_gradients.training import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     import unittest
                     import unittest
                     
                     
                    @@ -25,8 +25,6 @@ class EMAIntegrationTest(unittest.TestCase):
                         def _init_model(self) -> None:
                         def _init_model(self) -> None:
                             self.trainer = Trainer("resnet18_cifar_ema_test", model_checkpoints_location='local',
                             self.trainer = Trainer("resnet18_cifar_ema_test", model_checkpoints_location='local',
                                                    device='cpu', multi_gpu=MultiGPUMode.OFF)
                                                    device='cpu', multi_gpu=MultiGPUMode.OFF)
                    -        dataset_interface = ClassificationTestDatasetInterface({"batch_size": 32})
                    -        self.trainer.connect_dataset_interface(dataset_interface, 8)
                             self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
                             self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
                     
                     
                         @classmethod
                         @classmethod
                    @@ -65,7 +63,9 @@ class EMAIntegrationTest(unittest.TestCase):
                             self.trainer.test = CallWrapper(self.trainer.test, check_before=before_test)
                             self.trainer.test = CallWrapper(self.trainer.test, check_before=before_test)
                             self.trainer._train_epoch = CallWrapper(self.trainer._train_epoch, check_before=before_train_epoch)
                             self.trainer._train_epoch = CallWrapper(self.trainer._train_epoch, check_before=before_train_epoch)
                     
                     
                    -        self.trainer.train(model=self.model, training_params=training_params)
                    +        self.trainer.train(model=self.model, training_params=training_params,
                    +                           train_loader=classification_test_dataloader(),
                    +                           valid_loader=classification_test_dataloader())
                     
                     
                             self.assertIsNotNone(self.trainer.ema_model)
                             self.assertIsNotNone(self.trainer.ema_model)
                     
                     
                    Discard
                    @@ -4,7 +4,8 @@ import os
                     
                     
                     from super_gradients.training import models
                     from super_gradients.training import models
                     
                     
                    -from super_gradients import Trainer, ClassificationTestDatasetInterface
                    +from super_gradients import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     
                     
                     
                     
                    @@ -30,10 +31,7 @@ class LRTest(unittest.TestCase):
                         @staticmethod
                         @staticmethod
                         def get_trainer(name=''):
                         def get_trainer(name=''):
                             trainer = Trainer(name, model_checkpoints_location='local')
                             trainer = Trainer(name, model_checkpoints_location='local')
                    -        dataset_params = {"batch_size": 4}
                    -        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
                    -        trainer.connect_dataset_interface(dataset)
                    -        model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
                    +        model = models.get("resnet18_cifar", num_classes=5)
                             return trainer, model
                             return trainer, model
                     
                     
                         def test_function_lr(self):
                         def test_function_lr(self):
                    @@ -44,22 +42,25 @@ class LRTest(unittest.TestCase):
                     
                     
                             # test if we are able that lr_function supports functions with this structure
                             # test if we are able that lr_function supports functions with this structure
                             training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
                             training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
                    -        trainer.train(model=model, training_params=training_params)
                    -
                    +        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             # test that we assert lr_function is callable
                             # test that we assert lr_function is callable
                             training_params = {**self.training_params, "lr_mode": "function"}
                             training_params = {**self.training_params, "lr_mode": "function"}
                             with self.assertRaises(AssertionError):
                             with self.assertRaises(AssertionError):
                    -            trainer.train(model=model, training_params=training_params)
                    +            trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
                    +                          valid_loader=classification_test_dataloader())
                     
                     
                         def test_cosine_lr(self):
                         def test_cosine_lr(self):
                             trainer, model = self.get_trainer(self.folder_name)
                             trainer, model = self.get_trainer(self.folder_name)
                             training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
                             training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
                    -        trainer.train(model=model, training_params=training_params)
                    +        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                         def test_step_lr(self):
                         def test_step_lr(self):
                             trainer, model = self.get_trainer(self.folder_name)
                             trainer, model = self.get_trainer(self.folder_name)
                             training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
                             training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
                    -        trainer.train(model=model, training_params=training_params)
                    +        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                     
                     
                     if __name__ == '__main__':
                     if __name__ == '__main__':
                    Discard
                    Only showing up to 1000 lines per file, please use a local Git client to see the full diff.
                    @@ -1,11 +1,13 @@
                     import unittest
                     import unittest
                    -import super_gradients
                    +
                     from super_gradients.training import MultiGPUMode
                     from super_gradients.training import MultiGPUMode
                     from super_gradients.training import Trainer
                     from super_gradients.training import Trainer
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ImageNetDatasetInterface, \
                    -    ClassificationTestDatasetInterface, CityscapesDatasetInterface, SegmentationTestDatasetInterface, \
                    -    CoCoSegmentationDatasetInterface, DetectionTestDatasetInterface
                    -from super_gradients.training.utils.segmentation_utils import coco_sub_classes_inclusion_tuples_list
                    +from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader, coco2017_val_yolox, \
                    +    coco2017_val_ssd_lite_mobilenet_v2, detection_test_dataloader, coco_segmentation_val, cityscapes_val, \
                    +    cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_val, segmentation_test_dataloader
                    +from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN
                    +
                     from super_gradients.training.metrics import Accuracy, IoU
                     from super_gradients.training.metrics import Accuracy, IoU
                     import os
                     import os
                     import shutil
                     import shutil
                    @@ -13,13 +15,10 @@ from super_gradients.training.utils.ssd_utils import SSDPostPredictCallback
                     from super_gradients.training.models.detection_models.ssd import DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS
                     from super_gradients.training.models.detection_models.ssd import DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS
                     from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
                     from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
                     from super_gradients.training.metrics import DetectionMetrics
                     from super_gradients.training.metrics import DetectionMetrics
                    -from super_gradients.training.transforms.transforms import Rescale
                     from super_gradients.training.losses.stdc_loss import STDCLoss
                     from super_gradients.training.losses.stdc_loss import STDCLoss
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                    -from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
                    -from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
                     from super_gradients.training import models
                     from super_gradients.training import models
                    +import super_gradients
                     
                     
                     
                     
                     class PretrainedModelsTest(unittest.TestCase):
                     class PretrainedModelsTest(unittest.TestCase):
                    @@ -34,6 +33,15 @@ class PretrainedModelsTest(unittest.TestCase):
                                                                         {"image_size": (224, 224),
                                                                         {"image_size": (224, 224),
                                                                          "patch_size": (16, 16)}}
                                                                          "patch_size": (16, 16)}}
                     
                     
                    +        self.imagenet_pretrained_trainsfer_learning_arch_params = {"resnet": {},
                    +                                                                   "regnet": {},
                    +                                                                   "repvgg_a0": {"build_residual_branches": True},
                    +                                                                   "efficientnet_b0": {},
                    +                                                                   "mobilenet": {},
                    +                                                                   "vit_base":
                    +                                                                       {"image_size": (224, 224),
                    +                                                                        "patch_size": (16, 16)}}
                    +
                             self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
                             self.imagenet_pretrained_ckpt_params = {"pretrained_weights": "imagenet"}
                     
                     
                             self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
                             self.imagenet21k_pretrained_ckpt_params = {"pretrained_weights": "imagenet21k"}
                    @@ -54,16 +62,11 @@ class PretrainedModelsTest(unittest.TestCase):
                                                                    "vit_large": 0.8564,
                                                                    "vit_large": 0.8564,
                                                                    "beit_base_patch16_224": 0.85
                                                                    "beit_base_patch16_224": 0.85
                                                                    }
                                                                    }
                    -        self.imagenet_dataset = ImageNetDatasetInterface(data_dir="/data/Imagenet", dataset_params={"batch_size": 128})
                    +        self.imagenet_dataset = imagenet_val(dataloader_params={"batch_size": 128})
                     
                     
                    -        self.imagenet_dataset_05_mean_std = ImageNetDatasetInterface(data_dir="/data/Imagenet",
                    -                                                                     dataset_params={"batch_size": 128,
                    -                                                                                     "img_mean": [0.5, 0.5, 0.5],
                    -                                                                                     "img_std": [0.5, 0.5, 0.5],
                    -                                                                                     "resize_size": 248
                    -                                                                                     })
                    +        self.imagenet_dataset_05_mean_std = imagenet_vit_base_val(dataloader_params={"batch_size": 128})
                     
                     
                    -        self.transfer_classification_dataset = ClassificationTestDatasetInterface(image_size=224)
                    +        self.transfer_classification_dataloader = classification_test_dataloader(image_size=224)
                     
                     
                             self.transfer_classification_train_params = {"max_epochs": 3,
                             self.transfer_classification_train_params = {"max_epochs": 3,
                                                                          "lr_updates": [1],
                                                                          "lr_updates": [1],
                    @@ -83,76 +86,13 @@ class PretrainedModelsTest(unittest.TestCase):
                                                                 'coco_ssd_mobilenet_v1': {'num_classes': 80}}
                                                                 'coco_ssd_mobilenet_v1': {'num_classes': 80}}
                             self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
                             self.coco_pretrained_ckpt_params = {"pretrained_weights": "coco"}
                     
                     
                    -        from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionMixup, \
                    -            DetectionRandomAffine, \
                    -            DetectionTargetsFormatTransform, DetectionPaddedRescale, DetectionHSV, DetectionHorizontalFlip
                    -
                    -        yolox_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
                    -                                  DetectionRandomAffine(degrees=10., translate=0.1, scales=[0.1, 2], shear=2.0,
                    -                                                        target_size=(640, 640),
                    -                                                        filter_box_candidates=False, wh_thr=0, area_thr=0, ar_thr=0),
                    -                                  DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=1.0, flip_prob=0.5),
                    -                                  DetectionHSV(prob=1.0, hgain=5, sgain=30, vgain=30),
                    -                                  DetectionHorizontalFlip(prob=0.5),
                    -                                  DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                    -                                  DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                    -        yolox_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
                    -                                DetectionTargetsFormatTransform(max_targets=50,
                    -                                                                output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                    -
                    -        ssd_train_transforms = [DetectionMosaic(input_dim=(640, 640), prob=1.0),
                    -                                DetectionRandomAffine(degrees=0., translate=0.1, scales=[0.5, 1.5], shear=.0,
                    -                                                      target_size=(640, 640),
                    -                                                      filter_box_candidates=True, wh_thr=2, area_thr=0.1, ar_thr=20),
                    -                                DetectionMixup(input_dim=(640, 640), mixup_scale=[0.5, 1.5], prob=0., flip_prob=0.),
                    -                                DetectionHSV(prob=.0, hgain=5, sgain=30, vgain=30),
                    -                                DetectionHorizontalFlip(prob=0.),
                    -                                DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                    -                                DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                    -        ssd_val_transforms = [DetectionPaddedRescale(input_dim=(640, 640)),
                    -                              DetectionTargetsFormatTransform(max_targets=50,
                    -                                                              output_format=DetectionTargetsFormat.LABEL_CXCYWH)]
                    -
                             self.coco_dataset = {
                             self.coco_dataset = {
                    -            'yolox': CoCoDetectionDatasetInterface(
                    -                dataset_params={"data_dir": "/data/coco",
                    -                                "train_subdir": "images/train2017",
                    -                                "val_subdir": "images/val2017",
                    -                                "train_json_file": "instances_train2017.json",
                    -                                "val_json_file": "instances_val2017.json",
                    -                                "batch_size": 16,
                    -                                "val_batch_size": 128,
                    -                                "val_image_size": 640,
                    -                                "train_image_size": 640,
                    -                                "train_transforms": yolox_train_transforms,
                    -                                "val_transforms": yolox_val_transforms,
                    -
                    -                                "val_collate_fn": CrowdDetectionCollateFN(),
                    -                                "train_collate_fn": DetectionCollateFN(),
                    -                                "cache_dir_path": None,
                    -                                "cache_train_images": False,
                    -                                "cache_val_images": False,
                    -                                "with_crowd": True}),
                    -
                    -            'ssd_mobilenet': CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
                    -                                                                           "train_subdir": "images/train2017",
                    -                                                                           "val_subdir": "images/val2017",
                    -                                                                           "train_json_file": "instances_train2017.json",
                    -                                                                           "val_json_file": "instances_val2017.json",
                    -                                                                           "batch_size": 16,
                    -                                                                           "val_batch_size": 128,
                    -                                                                           "val_image_size": 320,
                    -                                                                           "train_image_size": 320,
                    -                                                                           "train_transforms": ssd_train_transforms,
                    -                                                                           "val_transforms": ssd_val_transforms,
                    -
                    -                                                                           "val_collate_fn": CrowdDetectionCollateFN(),
                    -                                                                           "train_collate_fn": DetectionCollateFN(),
                    -                                                                           "cache_dir_path": None,
                    -                                                                           "cache_train_images": False,
                    -                                                                           "cache_val_images": False,
                    -                                                                           "with_crowd": True})
                    -        }
                    +            'yolox': coco2017_val_yolox(dataloader_params={"collate_fn": CrowdDetectionCollateFN()},
                    +                                        dataset_params={"with_crowd": True}),
                    +
                    +            'ssd_mobilenet': coco2017_val_ssd_lite_mobilenet_v2(
                    +                dataloader_params={"collate_fn": CrowdDetectionCollateFN()},
                    +                dataset_params={"with_crowd": True})}
                     
                     
                             self.coco_pretrained_maps = {'ssd_lite_mobilenet_v2': 0.2052,
                             self.coco_pretrained_maps = {'ssd_lite_mobilenet_v2': 0.2052,
                                                          'coco_ssd_mobilenet_v1': 0.243,
                                                          'coco_ssd_mobilenet_v1': 0.243,
                    @@ -162,70 +102,60 @@ class PretrainedModelsTest(unittest.TestCase):
                                                          "yolox_n": 0.2677,
                                                          "yolox_n": 0.2677,
                                                          "yolox_t": 0.3718}
                                                          "yolox_t": 0.3718}
                     
                     
                    -        self.transfer_detection_dataset = DetectionTestDatasetInterface(image_size=320, classes=['class1', 'class2'])
                    +        self.transfer_detection_dataset = detection_test_dataloader()
                     
                     
                             ssd_dboxes = DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS['anchors']
                             ssd_dboxes = DEFAULT_SSD_LITE_MOBILENET_V2_ARCH_PARAMS['anchors']
                    -        self.transfer_detection_train_params = {
                    -            'ssd_lite_mobilenet_v2':
                    -                {
                    -                    "max_epochs": 3,
                    -                    "lr_mode": "cosine",
                    -                    "initial_lr": 0.01,
                    -                    "cosine_final_lr_ratio": 0.01,
                    -                    "lr_warmup_epochs": 3,
                    -                    "batch_accumulate": 1,
                    -                    "loss": "ssd_loss",
                    -                    "criterion_params": {"dboxes": ssd_dboxes},
                    -                    "optimizer": "SGD",
                    -                    "warmup_momentum": 0.8,
                    -                    "optimizer_params": {"momentum": 0.937,
                    -                                         "weight_decay": 0.0005,
                    -                                         "nesterov": True},
                    -                    "train_metrics_list": [],
                    -                    "valid_metrics_list": [
                    -                        DetectionMetrics(
                    -                            post_prediction_callback=SSDPostPredictCallback(),
                    -                            num_cls=len(self.transfer_detection_dataset.classes))],
                    -                    "loss_logging_items_names": ['smooth_l1', 'closs', 'Loss'],
                    -                    "metric_to_watch": "mAP@0.50:0.95",
                    -                    "greater_metric_to_watch_is_better": True
                    -                },
                    -            "yolox":
                    -                {"max_epochs": 3,
                    -                 "lr_mode": "cosine",
                    -                 "cosine_final_lr_ratio": 0.05,
                    -                 "warmup_bias_lr": 0.0,
                    -                 "warmup_momentum": 0.9,
                    -                 "initial_lr": 0.02,
                    -                 "loss": "yolox_loss",
                    -                 "criterion_params": {
                    -                     "strides": [8, 16, 32],  # output strides of all yolo outputs
                    -                     "num_classes": len(self.transfer_detection_dataset.classes)},
                    -
                    -                 "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                    -
                    -                 "train_metrics_list": [],
                    -                 "valid_metrics_list": [
                    -                     DetectionMetrics(
                    -                         post_prediction_callback=YoloPostPredictionCallback(),
                    -                         normalize_targets=True,
                    -                         num_cls=len(self.transfer_detection_dataset.classes))],
                    -                 "metric_to_watch": 'mAP@0.50:0.95',
                    -                 "greater_metric_to_watch_is_better": True}
                    +        self.transfer_detection_train_params_ssd = {
                    +            "max_epochs": 3,
                    +            "lr_mode": "cosine",
                    +            "initial_lr": 0.01,
                    +            "cosine_final_lr_ratio": 0.01,
                    +            "lr_warmup_epochs": 3,
                    +            "batch_accumulate": 1,
                    +            "loss": "ssd_loss",
                    +            "criterion_params": {"dboxes": ssd_dboxes},
                    +            "optimizer": "SGD",
                    +            "warmup_momentum": 0.8,
                    +            "optimizer_params": {"momentum": 0.937,
                    +                                 "weight_decay": 0.0005,
                    +                                 "nesterov": True},
                    +            "train_metrics_list": [],
                    +            "valid_metrics_list": [
                    +                DetectionMetrics(
                    +                    post_prediction_callback=SSDPostPredictCallback(),
                    +                    num_cls=5)],
                    +            "loss_logging_items_names": ['smooth_l1', 'closs', 'Loss'],
                    +            "metric_to_watch": "mAP@0.50:0.95",
                    +            "greater_metric_to_watch_is_better": True
                             }
                             }
                    +        self.transfer_detection_train_params_yolox = {"max_epochs": 3,
                    +                                                      "lr_mode": "cosine",
                    +                                                      "cosine_final_lr_ratio": 0.05,
                    +                                                      "warmup_bias_lr": 0.0,
                    +                                                      "warmup_momentum": 0.9,
                    +                                                      "initial_lr": 0.02,
                    +                                                      "loss": "yolox_loss",
                    +                                                      "criterion_params": {
                    +                                                          "strides": [8, 16, 32],  # output strides of all yolo outputs
                    +                                                          "num_classes": 5},
                    +
                    +                                                      "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg",
                    +                                                                                   "Loss"],
                    +
                    +                                                      "train_metrics_list": [],
                    +                                                      "valid_metrics_list": [
                    +                                                          DetectionMetrics(
                    +                                                              post_prediction_callback=YoloPostPredictionCallback(),
                    +                                                              normalize_targets=True,
                    +                                                              num_cls=5)],
                    +                                                      "metric_to_watch": 'mAP@0.50:0.95',
                    +                                                      "greater_metric_to_watch_is_better": True}
                     
                     
                             self.coco_segmentation_subclass_pretrained_arch_params = {
                             self.coco_segmentation_subclass_pretrained_arch_params = {
                                 "shelfnet34_lw": {"num_classes": 21, "image_size": 512}}
                                 "shelfnet34_lw": {"num_classes": 21, "image_size": 512}}
                             self.coco_segmentation_subclass_pretrained_ckpt_params = {"pretrained_weights": "coco_segmentation_subclass"}
                             self.coco_segmentation_subclass_pretrained_ckpt_params = {"pretrained_weights": "coco_segmentation_subclass"}
                             self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
                             self.coco_segmentation_subclass_pretrained_mious = {"shelfnet34_lw": 0.651}
                    -        self.coco_segmentation_dataset = CoCoSegmentationDatasetInterface(dataset_params={
                    -            "batch_size": 24,
                    -            "val_batch_size": 24,
                    -            "dataset_dir": "/data/coco/",
                    -            "img_size": 608,
                    -            "crop_size": 512
                    -        }, dataset_classes_inclusion_tuples_list=coco_sub_classes_inclusion_tuples_list()
                    -        )
                    +        self.coco_segmentation_dataset = coco_segmentation_val()
                     
                     
                             self.cityscapes_pretrained_models = ["ddrnet_23", "ddrnet_23_slim", "stdc1_seg50", "regseg48"]
                             self.cityscapes_pretrained_models = ["ddrnet_23", "ddrnet_23_slim", "stdc1_seg50", "regseg48"]
                             self.cityscapes_pretrained_arch_params = {
                             self.cityscapes_pretrained_arch_params = {
                    @@ -248,31 +178,12 @@ class PretrainedModelsTest(unittest.TestCase):
                                                                 "pp_lite_b_seg50": 0.7648,
                                                                 "pp_lite_b_seg50": 0.7648,
                                                                 "pp_lite_b_seg75": 0.7852}
                                                                 "pp_lite_b_seg75": 0.7852}
                     
                     
                    -        self.cityscapes_dataset = CityscapesDatasetInterface(dataset_params={
                    -            "batch_size": 3,
                    -            "val_batch_size": 3,
                    -            "dataset_dir": "/data/cityscapes/",
                    -            "crop_size": 1024,
                    -            "img_size": 1024,
                    -            "image_mask_transforms_aug": [],
                    -            "image_mask_transforms": []  # no transform for evaluation
                    -        }, cache_labels=False)
                    -
                    -        self.cityscapes_dataset_rescaled50 = CityscapesDatasetInterface(dataset_params={
                    -            "batch_size": 3,
                    -            "val_batch_size": 3,
                    -            "image_mask_transforms_aug": [],
                    -            "image_mask_transforms": [Rescale(scale_factor=0.5)]  # no transform for evaluation
                    -        }, cache_labels=False)
                    -
                    -        self.cityscapes_dataset_rescaled75 = CityscapesDatasetInterface(dataset_params={
                    -            "batch_size": 3,
                    -            "val_batch_size": 3,
                    -            "image_mask_transforms_aug": [],
                    -            "image_mask_transforms": [Rescale(scale_factor=0.75)]  # no transform for evaluation
                    -        }, cache_labels=False)
                    -
                    -        self.transfer_segmentation_dataset = SegmentationTestDatasetInterface(image_size=1024)
                    +        self.cityscapes_dataset = cityscapes_val()
                    +
                    +        self.cityscapes_dataset_rescaled50 = cityscapes_stdc_seg50_val()
                    +        self.cityscapes_dataset_rescaled75 = cityscapes_stdc_seg75_val()
                    +
                    +        self.transfer_segmentation_dataset = segmentation_test_dataloader(image_size=1024)
                             self.ddrnet_transfer_segmentation_train_params = {"max_epochs": 3,
                             self.ddrnet_transfer_segmentation_train_params = {"max_epochs": 3,
                                                                               "initial_lr": 1e-2,
                                                                               "initial_lr": 1e-2,
                                                                               "loss": DDRNetLoss(),
                                                                               "loss": DDRNetLoss(),
                    @@ -327,6 +238,7 @@ class PretrainedModelsTest(unittest.TestCase):
                                 "train_metrics_list": [IoU(5)],
                                 "train_metrics_list": [IoU(5)],
                                 "valid_metrics_list": [IoU(5)],
                                 "valid_metrics_list": [IoU(5)],
                                 "loss_logging_items_names": ["loss"],
                                 "loss_logging_items_names": ["loss"],
                    +
                                 "metric_to_watch": "IoU",
                                 "metric_to_watch": "IoU",
                                 "greater_metric_to_watch_is_better": True
                                 "greater_metric_to_watch_is_better": True
                             }
                             }
                    @@ -334,154 +246,160 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_resnet50_imagenet(self):
                         def test_pretrained_resnet50_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                             model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
                     
                     
                         def test_transfer_learning_resnet50_imagenet(self):
                         def test_transfer_learning_resnet50_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_resnet34_imagenet(self):
                         def test_pretrained_resnet34_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
                     
                     
                         def test_transfer_learning_resnet34_imagenet(self):
                         def test_transfer_learning_resnet34_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_resnet18_imagenet(self):
                         def test_pretrained_resnet18_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
                     
                     
                         def test_transfer_learning_resnet18_imagenet(self):
                         def test_transfer_learning_resnet18_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                             model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_regnetY800_imagenet(self):
                         def test_pretrained_regnetY800_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
                     
                     
                         def test_transfer_learning_regnetY800_imagenet(self):
                         def test_transfer_learning_regnetY800_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_regnetY600_imagenet(self):
                         def test_pretrained_regnetY600_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
                     
                     
                         def test_transfer_learning_regnetY600_imagenet(self):
                         def test_transfer_learning_regnetY600_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_regnetY400_imagenet(self):
                         def test_pretrained_regnetY400_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
                     
                     
                         def test_transfer_learning_regnetY400_imagenet(self):
                         def test_transfer_learning_regnetY400_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_regnetY200_imagenet(self):
                         def test_pretrained_regnetY200_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
                     
                     
                         def test_transfer_learning_regnetY200_imagenet(self):
                         def test_transfer_learning_regnetY200_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                             model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_repvgg_a0_imagenet(self):
                         def test_pretrained_repvgg_a0_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                             model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
                     
                     
                         def test_transfer_learning_repvgg_a0_imagenet(self):
                         def test_transfer_learning_repvgg_a0_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                             model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                             model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_regseg48_cityscapes(self):
                         def test_pretrained_regseg48_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
                             model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                             model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
                    @@ -489,18 +407,18 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_transfer_learning_regseg48_cityscapes(self):
                         def test_transfer_learning_regseg48_cityscapes(self):
                             trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                             model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.regseg_transfer_segmentation_train_params)
                    +        trainer.train(model=model, train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset,
                    +                      training_params=self.regseg_transfer_segmentation_train_params)
                     
                     
                         def test_pretrained_ddrnet23_cityscapes(self):
                         def test_pretrained_ddrnet23_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
                             model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                             model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
                    @@ -508,10 +426,9 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_ddrnet23_slim_cityscapes(self):
                         def test_pretrained_ddrnet23_slim_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset, data_loader_num_workers=8)
                             model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                             model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
                    @@ -519,94 +436,94 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_transfer_learning_ddrnet23_cityscapes(self):
                         def test_transfer_learning_ddrnet23_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                             model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params)
                    +        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_transfer_learning_ddrnet23_slim_cityscapes(self):
                         def test_transfer_learning_ddrnet23_slim_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                             model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params)
                    +        trainer.train(model=model, training_params=self.ddrnet_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
                         def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
                             trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
                             trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("shelfnet34_lw",
                             model = models.get("shelfnet34_lw",
                                                arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
                                                arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
                                                **self.coco_segmentation_subclass_pretrained_ckpt_params)
                                                **self.coco_segmentation_subclass_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_segmentation_dataset.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_segmentation_dataset,
                                                test_metrics_list=[IoU(21)], metrics_progress_verbose=True)[0].cpu().item()
                                                test_metrics_list=[IoU(21)], metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
                             self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
                     
                     
                         def test_pretrained_efficientnet_b0_imagenet(self):
                         def test_pretrained_efficientnet_b0_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
                             model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
                     
                     
                         def test_transfer_learning_efficientnet_b0_imagenet(self):
                         def test_transfer_learning_efficientnet_b0_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
                             model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
                         def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
                             trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
                             trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
                             model = models.get("ssd_lite_mobilenet_v2",
                             model = models.get("ssd_lite_mobilenet_v2",
                                                arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
                                                arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                             ssd_post_prediction_callback = SSDPostPredictCallback()
                             ssd_post_prediction_callback = SSDPostPredictCallback()
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'].val_loader, test_metrics_list=[
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'], test_metrics_list=[
                                 DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)], metrics_progress_verbose=True)[2]
                                 DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback, num_cls=80)], metrics_progress_verbose=True)[2]
                             self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
                             self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
                     
                     
                         def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
                         def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
                             trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
                             trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
                                               model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
                                               model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_detection_dataset,
                    -                                          data_loader_num_workers=8)
                             transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
                             transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
                    -        transfer_arch_params['num_classes'] = len(self.transfer_detection_dataset.classes)
                    +        transfer_arch_params['num_classes'] = 5
                             model = models.get("ssd_lite_mobilenet_v2",
                             model = models.get("ssd_lite_mobilenet_v2",
                                                arch_params=transfer_arch_params,
                                                arch_params=transfer_arch_params,
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_detection_train_params['ssd_lite_mobilenet_v2'])
                    +        trainer.train(model=model, training_params=self.transfer_detection_train_params_ssd,
                    +                      train_loader=self.transfer_detection_dataset,
                    +                      valid_loader=self.transfer_detection_dataset)
                     
                     
                         def test_pretrained_ssd_mobilenet_v1_coco(self):
                         def test_pretrained_ssd_mobilenet_v1_coco(self):
                             trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
                             trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['ssd_mobilenet'], data_loader_num_workers=8)
                             model = models.get("ssd_mobilenet_v1",
                             model = models.get("ssd_mobilenet_v1",
                                                arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
                                                arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                             ssd_post_prediction_callback = SSDPostPredictCallback()
                             ssd_post_prediction_callback = SSDPostPredictCallback()
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
                    -                                                               num_cls=len(
                    -                                                                   self.coco_dataset['ssd_mobilenet'].coco_classes))],
                    +                                                               num_cls=80)],
                                                metrics_progress_verbose=True)[2]
                                                metrics_progress_verbose=True)[2]
                             self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
                             self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
                     
                     
                         def test_pretrained_yolox_s_coco(self):
                         def test_pretrained_yolox_s_coco(self):
                             trainer = Trainer('yolox_s', model_checkpoints_location='local',
                             trainer = Trainer('yolox_s', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
                    +
                             model = models.get("yolox_s",
                             model = models.get("yolox_s",
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                                    num_cls=80,
                                                                                    num_cls=80,
                                                                                    normalize_targets=True)])[2]
                                                                                    normalize_targets=True)])[2]
                    @@ -615,10 +532,9 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_yolox_m_coco(self):
                         def test_pretrained_yolox_m_coco(self):
                             trainer = Trainer('yolox_m', model_checkpoints_location='local',
                             trainer = Trainer('yolox_m', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
                             model = models.get("yolox_m",
                             model = models.get("yolox_m",
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                                    num_cls=80,
                                                                                    num_cls=80,
                                                                                    normalize_targets=True)])[2]
                                                                                    normalize_targets=True)])[2]
                    @@ -627,10 +543,9 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_yolox_l_coco(self):
                         def test_pretrained_yolox_l_coco(self):
                             trainer = Trainer('yolox_l', model_checkpoints_location='local',
                             trainer = Trainer('yolox_l', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
                             model = models.get("yolox_l",
                             model = models.get("yolox_l",
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                                    num_cls=80,
                                                                                    num_cls=80,
                                                                                    normalize_targets=True)])[2]
                                                                                    normalize_targets=True)])[2]
                    @@ -639,10 +554,10 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_yolox_n_coco(self):
                         def test_pretrained_yolox_n_coco(self):
                             trainer = Trainer('yolox_n', model_checkpoints_location='local',
                             trainer = Trainer('yolox_n', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
                    +
                             model = models.get("yolox_n",
                             model = models.get("yolox_n",
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                                    num_cls=80,
                                                                                    num_cls=80,
                                                                                    normalize_targets=True)])[2]
                                                                                    normalize_targets=True)])[2]
                    @@ -651,10 +566,9 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_pretrained_yolox_t_coco(self):
                         def test_pretrained_yolox_t_coco(self):
                             trainer = Trainer('yolox_t', model_checkpoints_location='local',
                             trainer = Trainer('yolox_t', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.coco_dataset['yolox'], data_loader_num_workers=8)
                             model = models.get("yolox_t",
                             model = models.get("yolox_t",
                                                **self.coco_pretrained_ckpt_params)
                                                **self.coco_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'].val_loader,
                    +        res = trainer.test(model=model, test_loader=self.coco_dataset['yolox'],
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                test_metrics_list=[DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
                                                                                    num_cls=80,
                                                                                    num_cls=80,
                                                                                    normalize_targets=True)])[2]
                                                                                    normalize_targets=True)])[2]
                    @@ -664,26 +578,29 @@ class PretrainedModelsTest(unittest.TestCase):
                             trainer = Trainer('test_transfer_learning_yolox_n_coco',
                             trainer = Trainer('test_transfer_learning_yolox_n_coco',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_detection_dataset, data_loader_num_workers=8)
                    -        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_detection_train_params["yolox"])
                    +        model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_detection_train_params_yolox,
                    +                      train_loader=self.transfer_detection_dataset,
                    +                      valid_loader=self.transfer_detection_dataset)
                     
                     
                         def test_transfer_learning_mobilenet_v3_large_imagenet(self):
                         def test_transfer_learning_mobilenet_v3_large_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
                             trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_mobilenet_v3_large_imagenet(self):
                         def test_pretrained_mobilenet_v3_large_imagenet(self):
                             trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
                     
                     
                    @@ -691,18 +608,20 @@ class PretrainedModelsTest(unittest.TestCase):
                             trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
                             trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_mobilenet_v3_small_imagenet(self):
                         def test_pretrained_mobilenet_v3_small_imagenet(self):
                             trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
                     
                     
                    @@ -710,28 +629,29 @@ class PretrainedModelsTest(unittest.TestCase):
                             trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
                             trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                    -                           **self.imagenet_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_mobilenet_v2_imagenet(self):
                         def test_pretrained_mobilenet_v2_imagenet(self):
                             trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                             model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.imagenet_dataset.val_loader, test_metrics_list=[Accuracy()],
                    +        res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
                     
                     
                         def test_pretrained_stdc1_seg50_cityscapes(self):
                         def test_pretrained_stdc1_seg50_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
                             model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
                    @@ -739,18 +659,18 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_transfer_learning_stdc1_seg50_cityscapes(self):
                         def test_transfer_learning_stdc1_seg50_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                    -                           **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
                    +                           **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_pretrained_stdc1_seg75_cityscapes(self):
                         def test_pretrained_stdc1_seg75_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
                             model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
                    @@ -758,18 +678,18 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_transfer_learning_stdc1_seg75_cityscapes(self):
                         def test_transfer_learning_stdc1_seg75_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                    -                           **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
                    +                           **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_pretrained_stdc2_seg50_cityscapes(self):
                         def test_pretrained_stdc2_seg50_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
                             model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled50,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
                    @@ -777,122 +697,80 @@ class PretrainedModelsTest(unittest.TestCase):
                         def test_transfer_learning_stdc2_seg50_cityscapes(self):
                         def test_transfer_learning_stdc2_seg50_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                    -                           **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
                    +                           **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_pretrained_stdc2_seg75_cityscapes(self):
                         def test_pretrained_stdc2_seg75_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
                             model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                                                **self.cityscapes_pretrained_ckpt_params)
                                                **self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                    +        res = trainer.test(model=model, test_loader=self.cityscapes_dataset_rescaled75,
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                                                metrics_progress_verbose=True)[0].cpu().item()
                                                metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
                             self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
                     
                     
                    -    def test_pretrained_pplite_t_seg50_cityscapes(self):
                    -        trainer = Trainer('cityscapes_pretrained_pplite_t_seg50', model_checkpoints_location='local',
                    -                          multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
                    -        trainer.build_model("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                    -                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                    -                           test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                    -                           metrics_progress_verbose=True)[0].cpu().item()
                    -        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
                    -
                    -    def test_pretrained_pplite_t_seg75_cityscapes(self):
                    -        trainer = Trainer('cityscapes_pretrained_pplite_t_seg75', model_checkpoints_location='local',
                    -                          multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
                    -        trainer.build_model("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                    -                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                    -                           test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                    -                           metrics_progress_verbose=True)[0].cpu().item()
                    -        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
                    -
                    -    def test_pretrained_pplite_b_seg50_cityscapes(self):
                    -        trainer = Trainer('cityscapes_pretrained_pplite_b_seg50', model_checkpoints_location='local',
                    -                          multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled50, data_loader_num_workers=8)
                    -        trainer.build_model("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                    -                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled50.val_loader,
                    -                           test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                    -                           metrics_progress_verbose=True)[0].cpu().item()
                    -        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
                    -
                    -    def test_pretrained_pplite_b_seg75_cityscapes(self):
                    -        trainer = Trainer('cityscapes_pretrained_pplite_b_seg75', model_checkpoints_location='local',
                    -                          multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.cityscapes_dataset_rescaled75, data_loader_num_workers=8)
                    -        trainer.build_model("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                    -                            checkpoint_params=self.cityscapes_pretrained_ckpt_params)
                    -        res = trainer.test(test_loader=self.cityscapes_dataset_rescaled75.val_loader,
                    -                           test_metrics_list=[IoU(num_classes=20, ignore_index=19)],
                    -                           metrics_progress_verbose=True)[0].cpu().item()
                    -        self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg75"], delta=0.001)
                    -
                         def test_transfer_learning_stdc2_seg75_cityscapes(self):
                         def test_transfer_learning_stdc2_seg75_cityscapes(self):
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
                             trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_segmentation_dataset, data_loader_num_workers=8)
                             model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                             model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                    -                           **self.cityscapes_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params)
                    +                           **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.stdc_transfer_segmentation_train_params,
                    +                      train_loader=self.transfer_segmentation_dataset,
                    +                      valid_loader=self.transfer_segmentation_dataset)
                     
                     
                         def test_transfer_learning_vit_base_imagenet21k(self):
                         def test_transfer_learning_vit_base_imagenet21k(self):
                             trainer = Trainer('imagenet21k_pretrained_vit_base',
                             trainer = Trainer('imagenet21k_pretrained_vit_base',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                             model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                    -                           **self.imagenet21k_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_transfer_learning_vit_large_imagenet21k(self):
                         def test_transfer_learning_vit_large_imagenet21k(self):
                             trainer = Trainer('imagenet21k_pretrained_vit_large',
                             trainer = Trainer('imagenet21k_pretrained_vit_large',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.transfer_classification_dataset, data_loader_num_workers=8)
                    +
                             model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                             model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                    -                           **self.imagenet21k_pretrained_ckpt_params)
                    -        trainer.train(model=model, training_params=self.transfer_classification_train_params)
                    +                           **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
                    +        trainer.train(model=model, training_params=self.transfer_classification_train_params,
                    +                      train_loader=self.transfer_classification_dataloader,
                    +                      valid_loader=self.transfer_classification_dataloader)
                     
                     
                         def test_pretrained_vit_base_imagenet(self):
                         def test_pretrained_vit_base_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
                             model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                             model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                             res = \
                             res = \
                    -            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std.val_loader,
                    +            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std,
                                              test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
                                              test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
                             self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
                     
                     
                         def test_pretrained_vit_large_imagenet(self):
                         def test_pretrained_vit_large_imagenet(self):
                             trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
                             trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(self.imagenet_dataset_05_mean_std, data_loader_num_workers=8)
                             model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                             model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                                                **self.imagenet_pretrained_ckpt_params)
                                                **self.imagenet_pretrained_ckpt_params)
                             res = \
                             res = \
                    -            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std.val_loader,
                    +            trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std,
                    Discard
                    @@ -1,6 +1,6 @@
                     import unittest
                     import unittest
                     
                     
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training import Trainer, MultiGPUMode, models
                     from super_gradients.training import Trainer, MultiGPUMode, models
                     from super_gradients.training.metrics.classification_metrics import Accuracy
                     from super_gradients.training.metrics.classification_metrics import Accuracy
                     import os
                     import os
                    @@ -9,12 +9,9 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
                     
                     
                     class QATIntegrationTest(unittest.TestCase):
                     class QATIntegrationTest(unittest.TestCase):
                         def _get_trainer(self, experiment_name):
                         def _get_trainer(self, experiment_name):
                    -        dataset_params = {"batch_size": 10}
                    -        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
                             trainer = Trainer(experiment_name,
                             trainer = Trainer(experiment_name,
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               multi_gpu=MultiGPUMode.OFF)
                                               multi_gpu=MultiGPUMode.OFF)
                    -        trainer.connect_dataset_interface(dataset)
                             model = models.get("resnet18", pretrained_weights="imagenet")
                             model = models.get("resnet18", pretrained_weights="imagenet")
                             return trainer, model
                             return trainer, model
                     
                     
                    @@ -47,7 +44,8 @@ class QATIntegrationTest(unittest.TestCase):
                                 "percentile": 99.99
                                 "percentile": 99.99
                             })
                             })
                     
                     
                    -        model.train(model=net, training_params=train_params)
                    +        model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                    valid_loader=classification_test_dataloader())
                     
                     
                         def test_qat_transition(self):
                         def test_qat_transition(self):
                             model, net = self._get_trainer("test_qat_transition")
                             model, net = self._get_trainer("test_qat_transition")
                    @@ -59,7 +57,8 @@ class QATIntegrationTest(unittest.TestCase):
                                 "percentile": 99.99
                                 "percentile": 99.99
                             })
                             })
                     
                     
                    -        model.train(model=net, training_params=train_params)
                    +        model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                    valid_loader=classification_test_dataloader())
                     
                     
                         def test_qat_from_calibrated_ckpt(self):
                         def test_qat_from_calibrated_ckpt(self):
                             model, net = self._get_trainer("generate_calibrated_model")
                             model, net = self._get_trainer("generate_calibrated_model")
                    @@ -71,7 +70,8 @@ class QATIntegrationTest(unittest.TestCase):
                                 "percentile": 99.99
                                 "percentile": 99.99
                             })
                             })
                     
                     
                    -        model.train(model=net, training_params=train_params)
                    +        model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                    valid_loader=classification_test_dataloader())
                     
                     
                             calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
                             calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
                     
                     
                    @@ -85,7 +85,8 @@ class QATIntegrationTest(unittest.TestCase):
                                 "percentile": 99.99
                                 "percentile": 99.99
                             })
                             })
                     
                     
                    -        model.train(model=net, training_params=train_params)
                    +        model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                    valid_loader=classification_test_dataloader())
                     
                     
                     
                     
                     if __name__ == '__main__':
                     if __name__ == '__main__':
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    1. import unittest
                    2. import os
                    3. import shutil
                    4. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationDatasetInterface
                    5. class TestDataset(unittest.TestCase):
                    6. def test_donwload_dataset(self):
                    7. default_dataset_params = {"dataset_dir": os.path.expanduser("~/test_data/"),
                    8. "s3_link": "s3://research-data1/data.zip"}
                    9. dataset = ClassificationDatasetInterface(dataset_params=default_dataset_params)
                    10. test_sample = dataset.get_test_sample()
                    11. self.assertListEqual([3, 64, 64], list(test_sample[0].shape))
                    12. shutil.rmtree(default_dataset_params["dataset_dir"])
                    13. if __name__ == '__main__':
                    14. unittest.main()
                    Discard
                    @@ -1,5 +1,4 @@
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                     # PACKAGE IMPORTS FOR EXTERNAL USAGE
                    -from tests.unit_tests.dataset_interface_test import TestDatasetInterface
                     from tests.unit_tests.factories_test import FactoriesTest
                     from tests.unit_tests.factories_test import FactoriesTest
                     from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
                     from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
                     from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
                     from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
                    @@ -18,7 +17,7 @@ from tests.unit_tests.conv_bn_relu_test import TestConvBnRelu
                     from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
                     from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
                     from tests.unit_tests.training_params_factory_test import TrainingParamsTest
                     from tests.unit_tests.training_params_factory_test import TrainingParamsTest
                     
                     
                    -__all__ = ['TestDatasetInterface', 'ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
                    +__all__ = ['ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
                                'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
                                'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
                                'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
                                'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
                                'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
                                'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
                    Discard
                    @@ -4,7 +4,7 @@ import pkg_resources
                     import yaml
                     import yaml
                     from torch.utils.data import DataLoader
                     from torch.utils.data import DataLoader
                     
                     
                    -from super_gradients.training.dataloaders.dataloader_factory import cityscapes_train, cityscapes_val, \
                    +from super_gradients.training.dataloaders.dataloaders import cityscapes_train, cityscapes_val, \
                         cityscapes_stdc_seg50_train, cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_val, cityscapes_ddrnet_train, \
                         cityscapes_stdc_seg50_train, cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_val, cityscapes_ddrnet_train, \
                         cityscapes_regseg48_val, cityscapes_regseg48_train, cityscapes_ddrnet_val, cityscapes_stdc_seg75_train
                         cityscapes_regseg48_val, cityscapes_regseg48_train, cityscapes_ddrnet_val, cityscapes_stdc_seg75_train
                     from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
                     from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
                    Discard
                    @@ -5,7 +5,7 @@ import pkg_resources
                     import yaml
                     import yaml
                     from torch.utils.data import DataLoader
                     from torch.utils.data import DataLoader
                     
                     
                    -from super_gradients.training.dataloaders.dataloader_factory import coco_segmentation_train, coco_segmentation_val
                    +from super_gradients.training.dataloaders.dataloaders import coco_segmentation_train, coco_segmentation_val
                     from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
                     from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
                     
                     
                     
                     
                    Discard
                    @@ -2,7 +2,7 @@ import unittest
                     
                     
                     from torch.utils.data import DataLoader, TensorDataset
                     from torch.utils.data import DataLoader, TensorDataset
                     
                     
                    -from super_gradients.training.dataloaders.dataloader_factory import (
                    +from super_gradients.training.dataloaders.dataloaders import (
                         classification_test_dataloader,
                         classification_test_dataloader,
                         detection_test_dataloader,
                         detection_test_dataloader,
                         segmentation_test_dataloader,
                         segmentation_test_dataloader,
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    78
                    79
                    80
                    81
                    82
                    83
                    84
                    85
                    86
                    87
                    88
                    89
                    90
                    91
                    92
                    93
                    94
                    95
                    96
                    97
                    98
                    99
                    100
                    101
                    102
                    103
                    104
                    105
                    106
                    107
                    108
                    109
                    110
                    111
                    112
                    113
                    114
                    115
                    116
                    117
                    118
                    119
                    120
                    121
                    122
                    123
                    124
                    125
                    126
                    127
                    128
                    129
                    130
                    131
                    132
                    133
                    134
                    135
                    136
                    137
                    138
                    139
                    1. import unittest
                    2. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import PascalVOCUnifiedDetectionDatasetInterface,\
                    3. CoCoDetectionDatasetInterface
                    4. from super_gradients.training.transforms.transforms import DetectionPaddedRescale, DetectionTargetsFormatTransform, DetectionMosaic, DetectionRandomAffine,\
                    5. DetectionHSV
                    6. from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
                    7. from super_gradients.training.utils.detection_utils import DetectionCollateFN
                    8. from super_gradients.training.utils import sg_trainer_utils
                    9. from super_gradients.training import utils as core_utils
                    10. class TestDatasetInterface(unittest.TestCase):
                    11. def setUp(self) -> None:
                    12. self.root_dir = "/home/louis.dupont/data/"
                    13. self.train_batch_size, self.val_batch_size = 16, 32
                    14. self.train_image_size, self.val_image_size = 640, 640
                    15. self.train_input_dim = (self.train_image_size, self.train_image_size)
                    16. self.val_input_dim = (self.val_image_size, self.val_image_size)
                    17. self.train_max_num_samples = 100
                    18. self.val_max_num_samples = 90
                    19. def setup_pascal_voc_interface(self):
                    20. """setup PascalVOCUnifiedDetectionDatasetInterface and return dataloaders"""
                    21. dataset_params = {
                    22. "data_dir": self.root_dir + "pascal_unified_coco_format/",
                    23. "cache_dir": self.root_dir + "pascal_unified_coco_format/",
                    24. "batch_size": self.train_batch_size,
                    25. "val_batch_size": self.val_batch_size,
                    26. "train_image_size": self.train_image_size,
                    27. "val_image_size": self.val_image_size,
                    28. "train_max_num_samples": self.train_max_num_samples,
                    29. "val_max_num_samples": self.val_max_num_samples,
                    30. "train_transforms": [
                    31. DetectionMosaic(input_dim=self.train_input_dim, prob=1),
                    32. DetectionRandomAffine(degrees=0.373, translate=0.245, scales=0.898, shear=0.602, target_size=self.train_input_dim),
                    33. DetectionHSV(prob=1, hgain=0.0138, sgain=0.664, vgain=0.464),
                    34. DetectionPaddedRescale(input_dim=self.train_input_dim, max_targets=100),
                    35. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
                    36. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
                    37. "val_transforms": [
                    38. DetectionPaddedRescale(input_dim=self.val_input_dim),
                    39. DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
                    40. output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
                    41. "train_collate_fn": DetectionCollateFN(),
                    42. "val_collate_fn": DetectionCollateFN(),
                    43. "download": False,
                    44. "cache_train_images": False,
                    45. "cache_val_images": False,
                    46. "class_inclusion_list": ["person"]
                    47. }
                    48. dataset_interface = PascalVOCUnifiedDetectionDatasetInterface(dataset_params=dataset_params)
                    49. train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
                    50. return train_loader, valid_loader
                    51. def setup_coco_detection_interface(self):
                    52. """setup CoCoDetectionDatasetInterface and return dataloaders"""
                    53. dataset_params = {
                    54. "data_dir": "/data/coco",
                    55. "train_subdir": "images/train2017", # sub directory path of data_dir containing the train data.
                    56. "val_subdir": "images/val2017", # sub directory path of data_dir containing the validation data.
                    57. "train_json_file": "instances_train2017.json", # path to coco train json file, data_dir/annotations/train_json_file.
                    58. "val_json_file": "instances_val2017.json", # path to coco validation json file, data_dir/annotations/val_json_file.
                    59. "batch_size": self.train_batch_size,
                    60. "val_batch_size": self.val_batch_size,
                    61. "train_image_size": self.train_image_size,
                    62. "val_image_size": self.val_image_size,
                    63. "train_max_num_samples": self.train_max_num_samples,
                    64. "val_max_num_samples": self.val_max_num_samples,
                    65. "mixup_prob": 1.0, # probability to apply per-sample mixup
                    66. "degrees": 10., # rotation degrees, randomly sampled from [-degrees, degrees]
                    67. "shear": 2.0, # shear degrees, randomly sampled from [-degrees, degrees]
                    68. "flip_prob": 0.5, # probability to apply horizontal flip
                    69. "hsv_prob": 1.0, # probability to apply HSV transform
                    70. "hgain": 5, # HSV transform hue gain (randomly sampled from [-hgain, hgain])
                    71. "sgain": 30, # HSV transform saturation gain (randomly sampled from [-sgain, sgain])
                    72. "vgain": 30, # HSV transform value gain (randomly sampled from [-vgain, vgain])
                    73. "mosaic_scale": [0.1, 2], # random rescale range (keeps size by padding/cropping) after mosaic transform.
                    74. "mixup_scale": [0.5, 1.5], # random rescale range for the additional sample in mixup
                    75. "mosaic_prob": 1., # probability to apply mosaic
                    76. "translate": 0.1, # image translation fraction
                    77. "filter_box_candidates": False, # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
                    78. "wh_thr": 2, # edge size threshold when filter_box_candidates = True (pixels)
                    79. "ar_thr": 20, # aspect ratio threshold when filter_box_candidates = True
                    80. "area_thr": 0.1, # threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True
                    81. "tight_box_rotation": False,
                    82. "download": False,
                    83. "train_collate_fn": DetectionCollateFN(),
                    84. "val_collate_fn": DetectionCollateFN(),
                    85. "cache_train_images": False,
                    86. "cache_val_images": False,
                    87. "cache_dir": "/home/data/cache", # Depends on the user
                    88. "class_inclusion_list": None
                    89. # "with_crowd": True
                    90. }
                    91. dataset_interface = CoCoDetectionDatasetInterface(dataset_params=dataset_params)
                    92. train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
                    93. return train_loader, valid_loader
                    94. def test_coco_detection(self):
                    95. """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
                    96. train_loader, valid_loader = self.setup_coco_detection_interface()
                    97. for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
                    98. (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
                    99. # The dataset is at most of length max_num_samples, but can be smaller if not enough samples
                    100. self.assertGreaterEqual(max_num_samples, len(loader.dataset))
                    101. batch_items = next(iter(loader))
                    102. batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
                    103. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
                    104. self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))
                    105. def test_pascal_voc(self):
                    106. """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
                    107. train_loader, valid_loader = self.setup_pascal_voc_interface()
                    108. for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
                    109. (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
                    110. # The dataset is at most of length max_num_samples, but can be smaller if not enough samples
                    111. self.assertGreaterEqual(max_num_samples, len(loader.dataset))
                    112. batch_items = next(iter(loader))
                    113. batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
                    114. inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
                    115. self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))
                    116. if __name__ == '__main__':
                    117. unittest.main()
                    Discard
                    @@ -1,12 +1,10 @@
                     import unittest
                     import unittest
                     
                     
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
                    +from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
                     from super_gradients.training.metrics.detection_metrics import DetectionMetrics
                     from super_gradients.training.metrics.detection_metrics import DetectionMetrics
                     
                     
                     from super_gradients.training import Trainer, models
                     from super_gradients.training import Trainer, models
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                    -from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
                    -    DetectionTargetsFormat
                     
                     
                     
                     
                     class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                     class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                    @@ -19,44 +17,11 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                             browser and make sure the text and plots in the tensorboard are as expected.
                             browser and make sure the text and plots in the tensorboard are as expected.
                             """
                             """
                             # Create dataset
                             # Create dataset
                    -        dataset = CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
                    -                                                                "train_subdir": "images/train2017",
                    -                                                                "val_subdir": "images/val2017",
                    -                                                                "train_json_file": "instances_train2017.json",
                    -                                                                "val_json_file": "instances_val2017.json",
                    -                                                                "batch_size": 16,
                    -                                                                "val_batch_size": 128,
                    -                                                                "val_image_size": 640,
                    -                                                                "train_image_size": 640,
                    -                                                                "hgain": 5,
                    -                                                                "sgain": 30,
                    -                                                                "vgain": 30,
                    -                                                                "mixup_prob": 1.0,
                    -                                                                "degrees": 10.,
                    -                                                                "shear": 2.0,
                    -                                                                "flip_prob": 0.5,
                    -                                                                "hsv_prob": 1.0,
                    -                                                                "mosaic_scale": [0.1, 2],
                    -                                                                "mixup_scale": [0.5, 1.5],
                    -                                                                "mosaic_prob": 1.,
                    -                                                                "translate": 0.1,
                    -                                                                "val_collate_fn": CrowdDetectionCollateFN(),
                    -                                                                "train_collate_fn": DetectionCollateFN(),
                    -                                                                "cache_dir_path": None,
                    -                                                                "cache_train_images": False,
                    -                                                                "cache_val_images": False,
                    -                                                                "targets_format": DetectionTargetsFormat.LABEL_CXCYWH,
                    -                                                                "with_crowd": True,
                    -                                                                "filter_box_candidates": False,
                    -                                                                "wh_thr": 0,
                    -                                                                "ar_thr": 0,
                    -                                                                "area_thr": 0
                    -                                                                })
                     
                     
                             trainer = Trainer('dataset_statistics_visual_test',
                             trainer = Trainer('dataset_statistics_visual_test',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               post_prediction_callback=YoloPostPredictionCallback())
                                               post_prediction_callback=YoloPostPredictionCallback())
                    -        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
                    +
                             model = models.get("yolox_s")
                             model = models.get("yolox_s")
                     
                     
                             training_params = {"max_epochs": 1,  # we dont really need the actual training to run
                             training_params = {"max_epochs": 1,  # we dont really need the actual training to run
                    @@ -74,7 +39,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                                                "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                                                "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                                                "metric_to_watch": "mAP@0.50:0.95",
                                                "metric_to_watch": "mAP@0.50:0.95",
                                                }
                                                }
                    -        trainer.train(model=model, training_params=training_params)
                    +        trainer.train(model=model, training_params=training_params, train_loader=coco2017_train(), valid_loader=coco2017_val())
                     
                     
                     
                     
                     if __name__ == '__main__':
                     if __name__ == '__main__':
                    Discard
                    @@ -2,58 +2,23 @@ import os
                     import unittest
                     import unittest
                     
                     
                     from super_gradients.training import Trainer, utils as core_utils, models
                     from super_gradients.training import Trainer, utils as core_utils, models
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
                    +from super_gradients.training.dataloaders.dataloaders import coco2017_val
                     from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
                     from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                     from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
                    -from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionCollateFN, DetectionTargetsFormat
                    +from super_gradients.training.utils.detection_utils import DetectionVisualization
                     
                     
                     
                     
                     class TestDetectionUtils(unittest.TestCase):
                     class TestDetectionUtils(unittest.TestCase):
                         def test_visualization(self):
                         def test_visualization(self):
                    -        # Create dataset
                    -        dataset = CoCoDetectionDatasetInterface(dataset_params={"data_dir": "/data/coco",
                    -                                                                "train_subdir": "images/train2017",
                    -                                                                "val_subdir": "images/val2017",
                    -                                                                "train_json_file": "instances_train2017.json",
                    -                                                                "val_json_file": "instances_val2017.json",
                    -                                                                "batch_size": 16,
                    -                                                                "val_batch_size": 4,
                    -                                                                "val_image_size": 640,
                    -                                                                "train_image_size": 640,
                    -                                                                "hgain": 5,
                    -                                                                "sgain": 30,
                    -                                                                "vgain": 30,
                    -                                                                "mixup_prob": 1.0,
                    -                                                                "degrees": 10.,
                    -                                                                "shear": 2.0,
                    -                                                                "flip_prob": 0.5,
                    -                                                                "hsv_prob": 1.0,
                    -                                                                "mosaic_scale": [0.1, 2],
                    -                                                                "mixup_scale": [0.5, 1.5],
                    -                                                                "mosaic_prob": 1.,
                    -                                                                "translate": 0.1,
                    -                                                                "val_collate_fn": DetectionCollateFN(),
                    -                                                                "train_collate_fn": DetectionCollateFN(),
                    -                                                                "cache_dir_path": None,
                    -                                                                "cache_train_images": False,
                    -                                                                "cache_val_images": False,
                    -                                                                "targets_format": DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH,
                    -                                                                "with_crowd": False,
                    -                                                                "filter_box_candidates": False,
                    -                                                                "wh_thr": 0,
                    -                                                                "ar_thr": 0,
                    -                                                                "area_thr": 0
                    -                                                                })
                     
                     
                             # Create Yolo model
                             # Create Yolo model
                             trainer = Trainer('visualization_test',
                             trainer = Trainer('visualization_test',
                                               model_checkpoints_location='local',
                                               model_checkpoints_location='local',
                                               post_prediction_callback=YoloPostPredictionCallback())
                                               post_prediction_callback=YoloPostPredictionCallback())
                    -        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
                             model = models.get("yolox_n", pretrained_weights="coco")
                             model = models.get("yolox_n", pretrained_weights="coco")
                     
                     
                             # Simulate one iteration of validation subset
                             # Simulate one iteration of validation subset
                    -        valid_loader = trainer.valid_loader
                    +        valid_loader = coco2017_val()
                             batch_i, (imgs, targets) = 0, next(iter(valid_loader))
                             batch_i, (imgs, targets) = 0, next(iter(valid_loader))
                             imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
                             imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
                             targets = core_utils.tensor_container_to_device(targets, trainer.device)
                             targets = core_utils.tensor_container_to_device(targets, trainer.device)
                    Discard
                    @@ -2,10 +2,10 @@ import torch
                     import torch.nn as nn
                     import torch.nn as nn
                     import unittest
                     import unittest
                     
                     
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.utils.early_stopping import EarlyStop
                     from super_gradients.training.utils.early_stopping import EarlyStop
                     from super_gradients.training.utils.callbacks import Phase
                     from super_gradients.training.utils.callbacks import Phase
                     from super_gradients.training.sg_trainer import Trainer
                     from super_gradients.training.sg_trainer import Trainer
                    -from super_gradients.training.datasets.dataset_interfaces import ClassificationTestDatasetInterface
                     from super_gradients.training.models.classification_models.resnet import ResNet18
                     from super_gradients.training.models.classification_models.resnet import ResNet18
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     from torchmetrics.metric import Metric
                     from torchmetrics.metric import Metric
                    @@ -43,8 +43,6 @@ class LossTest(nn.Module):
                     class EarlyStopTest(unittest.TestCase):
                     class EarlyStopTest(unittest.TestCase):
                         def setUp(self) -> None:
                         def setUp(self) -> None:
                             # batch_size is equal to length of dataset, to have only one step per epoch, to ease the test.
                             # batch_size is equal to length of dataset, to have only one step per epoch, to ease the test.
                    -        dataset_params = {"batch_size": 10}
                    -        self.dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params, batch_size=10)
                             self.net = ResNet18(num_classes=5, arch_params={})
                             self.net = ResNet18(num_classes=5, arch_params={})
                             self.max_epochs = 10
                             self.max_epochs = 10
                             self.train_params = {"max_epochs": self.max_epochs, "lr_updates": [1], "lr_decay_factor": 0.1,
                             self.train_params = {"max_epochs": self.max_epochs, "lr_updates": [1], "lr_decay_factor": 0.1,
                    @@ -61,7 +59,6 @@ class EarlyStopTest(unittest.TestCase):
                             epochs.
                             epochs.
                             """
                             """
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                     
                     
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
                             phase_callbacks = [early_stop_loss]
                             phase_callbacks = [early_stop_loss]
                    @@ -71,8 +68,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params = self.train_params.copy()
                             train_params = self.train_params.copy()
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 5
                             excepted_end_epoch = 5
                     
                     
                             # count divided by 2, because loss counter used for both train and eval.
                             # count divided by 2, because loss counter used for both train and eval.
                    @@ -84,8 +81,6 @@ class EarlyStopTest(unittest.TestCase):
                             epochs.
                             epochs.
                             """
                             """
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                    -
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                                        verbose=True)
                                                        verbose=True)
                             phase_callbacks = [early_stop_acc]
                             phase_callbacks = [early_stop_acc]
                    @@ -96,8 +91,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params.update(
                             train_params.update(
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 6
                             excepted_end_epoch = 6
                     
                     
                             self.assertEqual(excepted_end_epoch, fake_metric.count)
                             self.assertEqual(excepted_end_epoch, fake_metric.count)
                    @@ -107,7 +102,6 @@ class EarlyStopTest(unittest.TestCase):
                             Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
                             Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
                             """
                             """
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                     
                     
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", threshold=0.1, verbose=True)
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", threshold=0.1, verbose=True)
                             phase_callbacks = [early_stop_loss]
                             phase_callbacks = [early_stop_loss]
                    @@ -117,8 +111,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params = self.train_params.copy()
                             train_params = self.train_params.copy()
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 5
                             excepted_end_epoch = 5
                             # count divided by 2, because loss counter used for both train and eval.
                             # count divided by 2, because loss counter used for both train and eval.
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                    @@ -128,7 +122,6 @@ class EarlyStopTest(unittest.TestCase):
                             Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
                             Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
                             """
                             """
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                     
                     
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                                        verbose=True)
                                                        verbose=True)
                    @@ -140,8 +133,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params.update(
                             train_params.update(
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 7
                             excepted_end_epoch = 7
                     
                     
                             self.assertEqual(excepted_end_epoch, fake_metric.count)
                             self.assertEqual(excepted_end_epoch, fake_metric.count)
                    @@ -152,7 +145,6 @@ class EarlyStopTest(unittest.TestCase):
                             """
                             """
                             # test Nan value
                             # test Nan value
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                     
                     
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", check_finite=True,
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", check_finite=True,
                                                         verbose=True)
                                                         verbose=True)
                    @@ -163,16 +155,14 @@ class EarlyStopTest(unittest.TestCase):
                             train_params = self.train_params.copy()
                             train_params = self.train_params.copy()
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 2
                             excepted_end_epoch = 2
                     
                     
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                     
                     
                             # test Inf value
                             # test Inf value
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                    -        trainer.build_model(self.net)
                     
                     
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
                             early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="Loss", mode="min", patience=3, verbose=True)
                             phase_callbacks = [early_stop_loss]
                             phase_callbacks = [early_stop_loss]
                    @@ -182,8 +172,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params = self.train_params.copy()
                             train_params = self.train_params.copy()
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                             train_params.update({"loss": fake_loss, "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    -
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                             excepted_end_epoch = 3
                             excepted_end_epoch = 3
                             # count divided by 2, because loss counter used for both train and eval.
                             # count divided by 2, because loss counter used for both train and eval.
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                             self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
                    @@ -194,7 +184,6 @@ class EarlyStopTest(unittest.TestCase):
                             current_value - min_delta > best_value
                             current_value - min_delta > best_value
                             """
                             """
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                             trainer = Trainer("early_stop_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                     
                     
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                             early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                                        min_delta=0.1, verbose=True)
                                                        min_delta=0.1, verbose=True)
                    @@ -206,7 +195,8 @@ class EarlyStopTest(unittest.TestCase):
                             train_params.update(
                             train_params.update(
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                                 {"valid_metrics_list": [fake_metric], "metric_to_watch": "MetricTest", "phase_callbacks": phase_callbacks})
                     
                     
                    -        trainer.train(model=self.net, training_params=train_params)
                    +        trainer.train(model=self.net, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                             excepted_end_epoch = 5
                             excepted_end_epoch = 5
                     
                     
                    Discard
                    1
                    2
                    3
                    4
                    5
                    6
                    7
                    8
                    9
                    10
                    11
                    12
                    13
                    14
                    15
                    16
                    17
                    18
                    19
                    20
                    21
                    22
                    23
                    24
                    25
                    26
                    27
                    28
                    29
                    30
                    31
                    32
                    33
                    34
                    35
                    36
                    37
                    38
                    39
                    40
                    41
                    42
                    43
                    44
                    45
                    46
                    47
                    48
                    49
                    50
                    51
                    52
                    53
                    54
                    55
                    56
                    57
                    58
                    59
                    60
                    61
                    62
                    63
                    64
                    65
                    66
                    67
                    68
                    69
                    70
                    71
                    72
                    73
                    74
                    75
                    76
                    77
                    1. import torch
                    2. import unittest
                    3. import numpy as np
                    4. import tensorflow.keras as keras
                    5. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface
                    6. class DataGenerator(keras.utils.Sequence):
                    7. def __init__(self, batch_size=1, dim=(320, 320), n_channels=3,
                    8. n_classes=1000, shuffle=True):
                    9. self.dim = dim
                    10. self.batch_size = batch_size
                    11. self.list_IDs = np.ones(1000)
                    12. self.n_channels = n_channels
                    13. self.n_classes = n_classes
                    14. self.shuffle = shuffle
                    15. self.on_epoch_end()
                    16. def __len__(self):
                    17. dataset_len = 32
                    18. return dataset_len
                    19. def __getitem__(self, index):
                    20. indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
                    21. list_IDs_temp = [self.list_IDs[k] for k in indices]
                    22. X, y = self.__data_generation(list_IDs_temp)
                    23. return X.squeeze(axis=0), y.squeeze(axis=0)
                    24. def on_epoch_end(self):
                    25. self.indices = np.arange(len(self.list_IDs))
                    26. if self.shuffle:
                    27. np.random.shuffle(self.indices)
                    28. def __data_generation(self, list_IDs_temp):
                    29. X = np.ones((self.batch_size, self.n_channels, *self.dim), dtype=np.float32)
                    30. y = np.ones((self.batch_size, 1), dtype=np.float32)
                    31. return X, y
                    32. class TestExternalDatasetInterface(unittest.TestCase):
                    33. def setUp(self):
                    34. params = {'dim': (256, 256),
                    35. 'batch_size': 1,
                    36. 'n_classes': 1000,
                    37. 'n_channels': 3,
                    38. 'shuffle': True}
                    39. training_generator = DataGenerator(**params)
                    40. testing_generator = DataGenerator(**params)
                    41. external_num_classes = 1000
                    42. external_dataset_params = {'batch_size': 16,
                    43. "val_batch_size": 16}
                    44. self.dim = params['dim'][0]
                    45. self.n_channels = params['n_channels']
                    46. self.batch_size = external_dataset_params['batch_size']
                    47. self.val_batch_size = external_dataset_params['val_batch_size']
                    48. self.test_external_dataset_interface = ExternalDatasetInterface(train_loader=training_generator,
                    49. val_loader=testing_generator,
                    50. num_classes=external_num_classes,
                    51. dataset_params=external_dataset_params)
                    52. def test_get_data_loaders(self):
                    53. train_loader, val_loader, _, num_classes = self.test_external_dataset_interface.get_data_loaders()
                    54. for batch_idx, (inputs, targets) in enumerate(train_loader):
                    55. self.assertListEqual([self.batch_size, self.n_channels, self.dim, self.dim], list(inputs.shape))
                    56. self.assertListEqual([self.batch_size, 1], list(targets.shape))
                    57. self.assertEqual(torch.Tensor, type(inputs))
                    58. self.assertEqual(torch.Tensor, type(targets))
                    59. for batch_idx, (inputs, targets) in enumerate(val_loader):
                    60. self.assertListEqual([self.val_batch_size, self.n_channels, self.dim, self.dim], list(inputs.shape))
                    61. self.assertListEqual([self.val_batch_size, 1], list(targets.shape))
                    62. self.assertEqual(torch.Tensor, type(inputs))
                    63. self.assertEqual(torch.Tensor, type(targets))
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
                    Discard
                    @@ -2,21 +2,17 @@ import unittest
                     
                     
                     import torch
                     import torch
                     
                     
                    -from super_gradients import ClassificationTestDatasetInterface, Trainer
                    +from super_gradients import Trainer
                    +from super_gradients.training import models
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                    -from super_gradients.training.models import ResNet18
                     
                     
                     
                     
                     class FactoriesTest(unittest.TestCase):
                     class FactoriesTest(unittest.TestCase):
                     
                     
                         def test_training_with_factories(self):
                         def test_training_with_factories(self):
                             trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
                             trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
                    -        dataset_params = {"batch_size": 10}
                    -        dataset = {"classification_test_dataset": {"dataset_params": dataset_params}}
                    -        trainer.connect_dataset_interface(dataset)
                    -
                    -        net = ResNet18(num_classes=5, arch_params={})
                    -        trainer.build_model(net)
                    +        net = models.get("resnet18", num_classes=5)
                             train_params = {"max_epochs": 2,
                             train_params = {"max_epochs": 2,
                                             "lr_updates": [1],
                                             "lr_updates": [1],
                                             "lr_decay_factor": 0.1,
                                             "lr_decay_factor": 0.1,
                    @@ -32,11 +28,12 @@ class FactoriesTest(unittest.TestCase):
                                             "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                                             "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                                             "greater_metric_to_watch_is_better": True}
                                             "greater_metric_to_watch_is_better": True}
                     
                     
                    -        trainer.train(model=net, training_params=train_params)
                    +        trainer.train(model=net, training_params=train_params,
                    +                      train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                             self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
                             self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
                             self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
                             self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
                    -        self.assertIsInstance(trainer.dataset_interface, ClassificationTestDatasetInterface)
                             self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)
                             self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)
                     
                     
                     
                     
                    Discard
                    @@ -1,7 +1,7 @@
                     import unittest
                     import unittest
                     from super_gradients.training import Trainer, models
                     from super_gradients.training import Trainer, models
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.metrics import Accuracy
                    -from super_gradients.training.datasets import ClassificationTestDatasetInterface
                     from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
                     from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
                     import torch
                     import torch
                     
                     
                    @@ -27,16 +27,11 @@ def test_forward_pass_prep_fn(inputs, targets, *args, **kwargs):
                     
                     
                     
                     
                     class ForwardpassPrepFNTest(unittest.TestCase):
                     class ForwardpassPrepFNTest(unittest.TestCase):
                    -    def setUp(self) -> None:
                    -        self.dataset_params = {"batch_size": 4}
                    -        self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
                    -        self.arch_params = {'num_classes': 10}
                     
                     
                         def test_resizing_with_forward_pass_prep_fn(self):
                         def test_resizing_with_forward_pass_prep_fn(self):
                             # Define Model
                             # Define Model
                             trainer = Trainer("ForwardpassPrepFNTest")
                             trainer = Trainer("ForwardpassPrepFNTest")
                    -        trainer.connect_dataset_interface(self.dataset)
                    -        model = models.get("resnet18", arch_params=self.arch_params)
                    +        model = models.get("resnet18", num_classes=5)
                     
                     
                             sizes = []
                             sizes = []
                             phase_callbacks = [TestInputSizesCallback(sizes)]
                             phase_callbacks = [TestInputSizesCallback(sizes)]
                    @@ -49,7 +44,8 @@ class ForwardpassPrepFNTest(unittest.TestCase):
                                             "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                                             "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                                             "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                                             "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                                             "pre_prediction_callback": test_forward_pass_prep_fn}
                                             "pre_prediction_callback": test_forward_pass_prep_fn}
                    -        trainer.train(model=model, training_params=train_params)
                    +        trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                             # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
                             # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
                             # THE LRS AFTER THE UPDATE
                             # THE LRS AFTER THE UPDATE
                    Discard
                    @@ -2,7 +2,7 @@ import unittest
                     
                     
                     from super_gradients.training import models
                     from super_gradients.training import models
                     
                     
                    -from super_gradients import Trainer, ClassificationTestDatasetInterface
                    +from super_gradients import Trainer
                     import torch
                     import torch
                     from torch.utils.data import TensorDataset, DataLoader
                     from torch.utils.data import TensorDataset, DataLoader
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.metrics import Accuracy
                    @@ -26,22 +26,6 @@ class InitializeWithDataloadersTest(unittest.TestCase):
                             label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
                             label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
                             self.testcase_testloader = DataLoader(TensorDataset(inp, label))
                             self.testcase_testloader = DataLoader(TensorDataset(inp, label))
                     
                     
                    -    def test_interface_was_not_broken(self):
                    -        trainer = Trainer("test_interface", model_checkpoints_location='local')
                    -        dataset_params = {"batch_size": 10}
                    -        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
                    -        trainer.connect_dataset_interface(dataset)
                    -
                    -        model = models.get("efficientnet_b0", arch_params={"num_classes": 5})
                    -        train_params = {"max_epochs": 1, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                    -                        "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": torch.nn.CrossEntropyLoss(),
                    -                        "optimizer": "SGD",
                    -                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                    -                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                    -                        "metric_to_watch": "Accuracy",
                    -                        "greater_metric_to_watch_is_better": True}
                    -        trainer.train(model=model, training_params=train_params)
                    -
                         def test_initialization_rules(self):
                         def test_initialization_rules(self):
                             self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                             self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
                                               train_loader=self.testcase_trainloader)
                                               train_loader=self.testcase_trainloader)
                    @@ -63,9 +47,7 @@ class InitializeWithDataloadersTest(unittest.TestCase):
                     
                     
                         def test_train_with_dataloaders(self):
                         def test_train_with_dataloaders(self):
                             trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
                             trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
                    -
                    -        trainer.build_model("resnet18")
                    -        model = models.get("resnet18", arch_params={"num_classes": 5})
                    +        model = models.get("resnet18", num_classes=5)
                             trainer.train(model=model,
                             trainer.train(model=model,
                                           training_params={"max_epochs": 2,
                                           training_params={"max_epochs": 2,
                                                            "lr_updates": [5, 6, 12],
                                                            "lr_updates": [5, 6, 12],
                    Discard
                    @@ -2,10 +2,10 @@ import unittest
                     
                     
                     from super_gradients.training import models
                     from super_gradients.training import models
                     from super_gradients.training import Trainer
                     from super_gradients.training import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.kd_trainer import KDTrainer
                     from super_gradients.training.kd_trainer import KDTrainer
                     import torch
                     import torch
                     from super_gradients.training.utils.utils import check_models_have_same_weights
                     from super_gradients.training.utils.utils import check_models_have_same_weights
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.losses.kd_losses import KDLogitsLoss
                     from super_gradients.training.losses.kd_losses import KDLogitsLoss
                     
                     
                    @@ -14,8 +14,6 @@ class KDEMATest(unittest.TestCase):
                         @classmethod
                         @classmethod
                         def setUp(cls):
                         def setUp(cls):
                             cls.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
                             cls.sg_trained_teacher = Trainer("sg_trained_teacher", device='cpu')
                    -        cls.dataset_params = {"batch_size": 5}
                    -        cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
                     
                     
                             cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                             cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                                                    "lr_warmup_epochs": 0, "initial_lr": 0.1,
                                                    "lr_warmup_epochs": 0, "initial_lr": 0.1,
                    @@ -32,12 +30,13 @@ class KDEMATest(unittest.TestCase):
                             """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
                             """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
                     
                     
                             kd_model = KDTrainer("test_teacher_ema_not_duplicated", device='cpu')
                             kd_model = KDTrainer("test_teacher_ema_not_duplicated", device='cpu')
                    -        kd_model.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                     
                     
                    -        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher)
                    +        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher,
                    +                       train_loader=classification_test_dataloader(),
                    +                       valid_loader=classification_test_dataloader())
                     
                     
                             self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
                             self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
                             self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
                             self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
                    @@ -48,24 +47,26 @@ class KDEMATest(unittest.TestCase):
                             # Create a KD trainer and train it
                             # Create a KD trainer and train it
                             train_params = self.kd_train_params.copy()
                             train_params = self.kd_train_params.copy()
                             kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
                             kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
                    -        kd_model.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                     
                     
                    -        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher)
                    +        kd_model.train(training_params=self.kd_train_params, student=student, teacher=teacher,
                    +                       train_loader=classification_test_dataloader(),
                    +                       valid_loader=classification_test_dataloader())
                             ema_model = kd_model.ema_model.ema
                             ema_model = kd_model.ema_model.ema
                             net = kd_model.net
                             net = kd_model.net
                     
                     
                             # Load the trained KD trainer
                             # Load the trained KD trainer
                             kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
                             kd_model = KDTrainer("test_kd_ema_ckpt_reload", device='cpu')
                    -        kd_model.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             student = models.get('resnet18', arch_params={'num_classes': 1000})
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                             teacher = models.get('resnet50', arch_params={'num_classes': 1000},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                     
                     
                             train_params["resume"] = True
                             train_params["resume"] = True
                    -        kd_model.train(training_params=train_params, student=student, teacher=teacher)
                    +        kd_model.train(training_params=train_params, student=student, teacher=teacher,
                    +                       train_loader=classification_test_dataloader(),
                    +                       valid_loader=classification_test_dataloader())
                             reloaded_ema_model = kd_model.ema_model.ema
                             reloaded_ema_model = kd_model.ema_model.ema
                             reloaded_net = kd_model.net
                             reloaded_net = kd_model.net
                     
                     
                    @@ -79,7 +80,8 @@ class KDEMATest(unittest.TestCase):
                             self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))
                             self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))
                     
                     
                             # loaded student ema == loaded  student net (since load_ema_as_net = False)
                             # loaded student ema == loaded  student net (since load_ema_as_net = False)
                    -        self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
                    +        self.assertTrue(
                    +            not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
                     
                     
                             # loaded teacher ema == loaded teacher net (teacher always loads ema)
                             # loaded teacher ema == loaded teacher net (teacher always loads ema)
                             self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
                             self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
                    Discard
                    @@ -1,11 +1,12 @@
                     import os
                     import os
                     import unittest
                     import unittest
                     from copy import deepcopy
                     from copy import deepcopy
                    +
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
                     from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
                     import torch
                     import torch
                     
                     
                     from super_gradients.training import models
                     from super_gradients.training import models
                    -from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
                     from super_gradients.training.losses.kd_losses import KDLogitsLoss
                     from super_gradients.training.losses.kd_losses import KDLogitsLoss
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.metrics import Accuracy
                     from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
                     from super_gradients.training.models.classification_models.resnet import ResNet50, ResNet18
                    @@ -36,9 +37,6 @@ class PreTrainingEMANetCollector(PhaseCallback):
                     class KDTrainerTest(unittest.TestCase):
                     class KDTrainerTest(unittest.TestCase):
                         @classmethod
                         @classmethod
                         def setUp(cls):
                         def setUp(cls):
                    -        cls.dataset_params = {"batch_size": 5}
                    -        cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
                    -
                             cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                             cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                                                    "lr_warmup_epochs": 0, "initial_lr": 0.1,
                                                    "lr_warmup_epochs": 0, "initial_lr": 0.1,
                                                    "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()),
                                                    "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()),
                    @@ -69,9 +67,10 @@ class KDTrainerTest(unittest.TestCase):
                             sg_model = KDTrainer("test_train_kd_module_external_models", device='cpu')
                             sg_model = KDTrainer("test_train_kd_module_external_models", device='cpu')
                             teacher_model = ResNet50(arch_params={}, num_classes=5)
                             teacher_model = ResNet50(arch_params={}, num_classes=5)
                             student_model = ResNet18(arch_params={}, num_classes=5)
                             student_model = ResNet18(arch_params={}, num_classes=5)
                    -        sg_model.connect_dataset_interface(self.dataset)
                     
                     
                    -        sg_model.train(training_params=self.kd_train_params, student=deepcopy(student_model), teacher=teacher_model)
                    +        sg_model.train(training_params=self.kd_train_params, student=deepcopy(student_model), teacher=teacher_model,
                    +                       train_loader=classification_test_dataloader(),
                    +                       valid_loader=classification_test_dataloader())
                     
                     
                             # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
                             # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
                             self.assertTrue(
                             self.assertTrue(
                    @@ -83,7 +82,6 @@ class KDTrainerTest(unittest.TestCase):
                     
                     
                         def test_train_model_with_input_adapter(self):
                         def test_train_model_with_input_adapter(self):
                             kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter", device='cpu')
                             kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter", device='cpu')
                    -        kd_trainer.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                    @@ -96,19 +94,21 @@ class KDTrainerTest(unittest.TestCase):
                             kd_arch_params = {
                             kd_arch_params = {
                                 "teacher_input_adapter": adapter}
                                 "teacher_input_adapter": adapter}
                             kd_trainer.train(training_params=self.kd_train_params, student=student, teacher=teacher,
                             kd_trainer.train(training_params=self.kd_train_params, student=student, teacher=teacher,
                    -                         kd_arch_params=kd_arch_params)
                    +                         kd_arch_params=kd_arch_params, train_loader=classification_test_dataloader(),
                    +                         valid_loader=classification_test_dataloader())
                     
                     
                             self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
                             self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
                     
                     
                         def test_load_ckpt_best_for_student(self):
                         def test_load_ckpt_best_for_student(self):
                             kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
                             kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
                    -        kd_trainer.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                             train_params = self.kd_train_params.copy()
                             train_params = self.kd_train_params.copy()
                             train_params["max_epochs"] = 1
                             train_params["max_epochs"] = 1
                    -        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
                    +        kd_trainer.train(training_params=train_params, student=student, teacher=teacher,
                    +                         train_loader=classification_test_dataloader(),
                    +                         valid_loader=classification_test_dataloader())
                             best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
                             best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
                     
                     
                             student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
                             student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
                    @@ -119,14 +119,15 @@ class KDTrainerTest(unittest.TestCase):
                     
                     
                         def test_load_ckpt_best_for_student_with_ema(self):
                         def test_load_ckpt_best_for_student_with_ema(self):
                             kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
                             kd_trainer = KDTrainer("test_load_ckpt_best", device='cpu')
                    -        kd_trainer.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                             train_params = self.kd_train_params.copy()
                             train_params = self.kd_train_params.copy()
                             train_params["max_epochs"] = 1
                             train_params["max_epochs"] = 1
                             train_params["ema"] = True
                             train_params["ema"] = True
                    -        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
                    +        kd_trainer.train(training_params=train_params, student=student, teacher=teacher,
                    +                         train_loader=classification_test_dataloader(),
                    +                         valid_loader=classification_test_dataloader())
                             best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
                             best_student_ckpt = os.path.join(kd_trainer.checkpoints_dir_path, "ckpt_best.pth")
                     
                     
                             student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
                             student_reloaded = models.get('resnet18', arch_params={'num_classes': 5},
                    @@ -137,17 +138,17 @@ class KDTrainerTest(unittest.TestCase):
                     
                     
                         def test_resume_kd_training(self):
                         def test_resume_kd_training(self):
                             kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
                             kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
                    -        kd_trainer.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                             train_params = self.kd_train_params.copy()
                             train_params = self.kd_train_params.copy()
                             train_params["max_epochs"] = 1
                             train_params["max_epochs"] = 1
                    -        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
                    +        kd_trainer.train(training_params=train_params, student=student, teacher=teacher,
                    +                         train_loader=classification_test_dataloader(),
                    +                         valid_loader=classification_test_dataloader())
                             latest_net = deepcopy(kd_trainer.net)
                             latest_net = deepcopy(kd_trainer.net)
                     
                     
                             kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
                             kd_trainer = KDTrainer("test_resume_training_start", device='cpu')
                    -        kd_trainer.connect_dataset_interface(self.dataset)
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             student = models.get('resnet18', arch_params={'num_classes': 5})
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                             teacher = models.get('resnet50', arch_params={'num_classes': 5},
                                                  pretrained_weights="imagenet")
                                                  pretrained_weights="imagenet")
                    @@ -156,7 +157,9 @@ class KDTrainerTest(unittest.TestCase):
                             train_params["resume"] = True
                             train_params["resume"] = True
                             collector = PreTrainingNetCollector()
                             collector = PreTrainingNetCollector()
                             train_params["phase_callbacks"] = [collector]
                             train_params["phase_callbacks"] = [collector]
                    -        kd_trainer.train(training_params=train_params, student=student, teacher=teacher)
                    +        kd_trainer.train(training_params=train_params, student=student, teacher=teacher,
                    +                         train_loader=classification_test_dataloader(),
                    +                         valid_loader=classification_test_dataloader())
                     
                     
                             self.assertTrue(
                             self.assertTrue(
                                 check_models_have_same_weights(collector.net, latest_net))
                                 check_models_have_same_weights(collector.net, latest_net))
                    Discard
                    @@ -1,9 +1,9 @@
                     import unittest
                     import unittest
                     from super_gradients.training import Trainer
                     from super_gradients.training import Trainer
                    +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.metrics import Accuracy, Top5
                     from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
                     from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
                     from super_gradients.training.utils.utils import check_models_have_same_weights
                     from super_gradients.training.utils.utils import check_models_have_same_weights
                    -from super_gradients.training.datasets import ClassificationTestDatasetInterface
                     from super_gradients.training.models import LeNet
                     from super_gradients.training.models import LeNet
                     from copy import deepcopy
                     from copy import deepcopy
                     
                     
                    @@ -19,8 +19,6 @@ class PreTrainingEMANetCollector(PhaseCallback):
                     
                     
                     class LoadCheckpointWithEmaTest(unittest.TestCase):
                         def setUp(self) -> None:
                    -        self.dataset_params = {"batch_size": 4}
                    -        self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
                             self.train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                                                  "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": 'SGD',
                                                  "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                    @@ -32,22 +30,23 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
                             # Define Model
                             net = LeNet()
                             trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
                    -
                    -        trainer.connect_dataset_interface(self.dataset)
                    -
                    -        trainer.train(model=net, training_params=self.train_params)
                    +        trainer.train(model=net, training_params=self.train_params,
                    +                      train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                             ema_model = trainer.ema_model.ema

                             # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
                             net = LeNet()
                             trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
                    -        trainer.connect_dataset_interface(self.dataset)
                    +
                             net_collector = PreTrainingEMANetCollector()
                             self.train_params["resume"] = True
                             self.train_params["max_epochs"] = 3
                             self.train_params["phase_callbacks"] = [net_collector]
                    -        trainer.train(model=net, training_params=self.train_params)
                    +        trainer.train(model=net, training_params=self.train_params,
                    +                      train_loader=classification_test_dataloader(),
                    +                      valid_loader=classification_test_dataloader())
                     
                     
                             reloaded_ema_model = net_collector.net.ema
                     
                     
                    Discard

                    Some files were not shown because too many files changed in this diff