#1001 Bug/sg 861 decouple qat from train from config

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-861_decouple_qat_from_train_from_config
@@ -465,6 +465,8 @@ jobs:
             python3.8 -m pip install -r requirements.txt
             python3.8 -m pip install .
             python3.8 -m pip install torch torchvision torchaudio
+            python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
+
             python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
             python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
             python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
@@ -3,6 +3,7 @@ from super_gradients.training import losses, utils, datasets_utils, DataAugmenta
 from super_gradients.common.registry.registry import ARCHITECTURES
 from super_gradients.sanity_check import env_sanity_check
 from super_gradients.training.utils.distributed_training_utils import setup_device
+from super_gradients.training.pre_launch_callbacks import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback
 
 __all__ = [
     "ARCHITECTURES",
@@ -18,6 +19,8 @@ __all__ = [
     "is_distributed",
     "env_sanity_check",
     "setup_device",
+    "QATRecipeModificationCallback",
+    "AutoTrainBatchSizeSelectionCallback",
 ]
 
 __version__ = "3.1.1"
@@ -8,13 +8,12 @@ For recipe's specific instructions and details refer to the recipe's configurati
 import hydra
 from omegaconf import DictConfig
 
-from super_gradients import init_trainer
-from super_gradients.training.qat_trainer.qat_trainer import QATTrainer
+from super_gradients import init_trainer, Trainer
 
 
 @hydra.main(config_path="recipes", version_base="1.2")
 def _main(cfg: DictConfig) -> None:
-    QATTrainer.train_from_config(cfg)
+    Trainer.quantize_from_config(cfg)
 
 
 def main():
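With this change, the recipe launcher above simply delegates to the Trainer class. For reference, a minimal sketch of the equivalent programmatic launch; "my_qat_recipe" is a placeholder recipe name and must define quantization_params plus a starting checkpoint (checkpoint_params.checkpoint_path or checkpoint_params.pretrained_weights):

from super_gradients import Trainer, init_trainer
from super_gradients.common.environment.cfg_utils import load_recipe

init_trainer()
# Placeholder recipe name; swap in a real recipe that contains quantization_params.
cfg = load_recipe("my_qat_recipe")
quantized_model, results = Trainer.quantize_from_config(cfg)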
@@ -5,6 +5,7 @@ from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.kd_trainer import KDTrainer
 from super_gradients.training.qat_trainer import QATTrainer
 from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
+from super_gradients.training.pre_launch_callbacks import modify_params_for_qat
 
 __all__ = [
     "distributed_training_utils",
@@ -16,4 +17,5 @@ __all__ = [
     "MultiGPUMode",
     "StrictLoad",
     "EvaluationType",
+    "modify_params_for_qat",
 ]
@@ -2,7 +2,8 @@ from super_gradients.training.pre_launch_callbacks.pre_launch_callbacks import (
     PreLaunchCallback,
     AutoTrainBatchSizeSelectionCallback,
     QATRecipeModificationCallback,
+    modify_params_for_qat,
 )
 from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS
 
-__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
+__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS", "modify_params_for_qat"]
@@ -5,6 +5,7 @@ from typing import Union
 from omegaconf import DictConfig
 import torch
 
+from super_gradients.common.environment.cfg_utils import load_recipe
 from super_gradients.common.registry.registry import register_pre_launch_callback
 from super_gradients import is_distributed
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -13,6 +14,8 @@ from torch.distributed import barrier
 import cv2
 import numpy as np
 
+from super_gradients.training.utils import get_param
+
 logger = get_logger(__name__)
 
 
@@ -70,7 +73,7 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
     :param max_batch_size: int, optional, upper limit of the batch sizes to try. When None, the search will continue until
      the maximal batch size that does not raise CUDA OUT OF MEMORY is found (default=None).
 
-    :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
+    :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
     FOUND_BATCH_SIZE/cfg.dataset_params.train_dataloader_params.batch_size (default=True)
     :param mode: str, one of ["fastest","largest"], whether to select the largest batch size that fits memory or the one
      that the resulted in overall fastest execution.
@@ -103,14 +106,14 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
             load_backbone=cfg.checkpoint_params.load_backbone,
         )
         tmp_cfg = deepcopy(cfg)
-        tmp_cfg.training_hyperparams.batch_accumulate = 1
-        tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes
-        tmp_cfg.training_hyperparams.run_validation_freq = 2
-        tmp_cfg.training_hyperparams.silent_mode = True
-        tmp_cfg.training_hyperparams.save_model = False
-        tmp_cfg.training_hyperparams.max_epochs = 1
-        tmp_cfg.training_hyperparams.average_best_models = False
-        tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
+        tmp_cfg.training_hyperparams.batch_accumulate = 1
+        tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes
+        tmp_cfg.training_hyperparams.run_validation_freq = 2
+        tmp_cfg.training_hyperparams.silent_mode = True
+        tmp_cfg.training_hyperparams.save_model = False
+        tmp_cfg.training_hyperparams.max_epochs = 1
+        tmp_cfg.training_hyperparams.average_best_models = False
+        tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
         tmp_cfg.pre_launch_callbacks_list = []
 
         fastest_batch_time = np.inf
@@ -166,7 +169,7 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
     def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
         if self.scale_lr:
             scale_factor = found_batch_size / cfg.dataset_params.train_dataloader_params.batch_size
-            cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor
+            cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor
         return cfg
 
     @classmethod
@@ -180,6 +183,151 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
             barrier()
 
 
+def modify_params_for_qat(
+    training_hyperparams,
+    train_dataset_params,
+    val_dataset_params,
+    train_dataloader_params,
+    val_dataloader_params,
+    quantization_params=None,
+    batch_size_divisor: int = 2,
+    max_epochs_divisor: int = 10,
+    lr_decay_factor: float = 0.01,
+    warmup_epochs_divisor: int = 10,
+    cosine_final_lr_ratio: float = 0.01,
+    disable_phase_callbacks: bool = True,
+    disable_augmentations: bool = False,
+):
+    """
+
+    This method modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe.
+    It does so by manipulating the training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params.
+    Usage:
+        trainer = Trainer("test_launch_qat_with_minimal_changes")
+        net = ResNet18(num_classes=10, arch_params={})
+        train_params = {...}
+
+        train_dataset_params = {
+            "transforms": [...
+            ]
+        }
+
+        train_dataloader_params = {"batch_size": 256}
+
+        val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]}
+
+        val_dataloader_params = {"batch_size": 256}
+
+        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)
+
+        trainer.train(
+            model=net,
+            training_params=train_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+        )
+
+        train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat(
+            train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
+        )
+
+        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)
+
+        trainer.qat(
+            model=net,
+            training_params=train_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            calib_loader=train_loader,
+        )
+
+    :param val_dataset_params: Dict, validation dataset_params to be passed to dataloaders.get(...) when instantiating the validation dataloader.
+    :param train_dataset_params: Dict, train dataset_params to be passed to dataloaders.get(...) when instantiating the train dataloader.
+    :param val_dataloader_params: Dict, validation dataloader_params to be passed to dataloaders.get(...) when instantiating the validation dataloader.
+    :param train_dataloader_params: Dict, train dataloader_params to be passed to dataloaders.get(...) when instantiating the train dataloader.
+    :param training_hyperparams: Dict, train parameters passed to Trainer.qat(...)
+    :param quantization_params: Dict, quantization parameters as passed to Trainer.qat(...). When None, will use the
+     default parameters in super_gradients/recipes/quantization_params/default_quantization_params.yaml
+    :param int batch_size_divisor: Divisor used to calculate the batch size. Default value is 2.
+    :param int max_epochs_divisor: Divisor used to calculate the maximum number of epochs. Default value is 10.
+    :param float lr_decay_factor: Factor used to decay the learning rate, warmup learning rate and weight decay. Default value is 0.01.
+    :param int warmup_epochs_divisor: Divisor used to calculate the number of warm-up epochs. Default value is 10.
+    :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01.
+    :param bool disable_phase_callbacks: Whether to disable phase callbacks, which can interfere with QAT. Default value is True.
+    :param bool disable_augmentations: Whether to disable augmentations, which can interfere with QAT. Default value is False.
+    :return: modified (copy) training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
+    """
+    if quantization_params is None:
+        quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
+
+    quantization_params = deepcopy(quantization_params)
+    training_hyperparams = deepcopy(training_hyperparams)
+    train_dataloader_params = deepcopy(train_dataloader_params)
+    val_dataloader_params = deepcopy(val_dataloader_params)
+    train_dataset_params = deepcopy(train_dataset_params)
+    val_dataset_params = deepcopy(val_dataset_params)
+
+    if "max_epochs" not in training_hyperparams.keys():
+        raise ValueError("max_epochs is a required field in training_hyperparams for QAT modification.")
+
+    if "initial_lr" not in training_hyperparams.keys():
+        raise ValueError("initial_lr is a required field in training_hyperparams for QAT modification.")
+
+    if "optimizer_params" not in training_hyperparams.keys():
+        raise ValueError("optimizer_params is a required field in training_hyperparams for QAT modification.")
+
+    if "weight_decay" not in training_hyperparams["optimizer_params"].keys():
+        raise ValueError("weight_decay is a required field in training_hyperparams['optimizer_params'] for QAT modification.")
+
+    # Q/DQ Layers take a lot of space for activations in training mode
+    if get_param(quantization_params, "selective_quantizer_params") and get_param(quantization_params["selective_quantizer_params"], "learn_amax"):
+        train_dataloader_params["batch_size"] //= batch_size_divisor
+        val_dataloader_params["batch_size"] //= batch_size_divisor
+
+        logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {train_dataloader_params['batch_size']}")
+        logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {val_dataloader_params['batch_size']}")
+    training_hyperparams["max_epochs"] //= max_epochs_divisor
+    logger.warning(f"New number of epochs: {training_hyperparams['max_epochs']}")
+    training_hyperparams["initial_lr"] *= lr_decay_factor
+    if get_param(training_hyperparams, "warmup_initial_lr") is not None:
+        training_hyperparams["warmup_initial_lr"] *= lr_decay_factor
+    else:
+        training_hyperparams["warmup_initial_lr"] = training_hyperparams["initial_lr"] * 0.01
+    training_hyperparams["optimizer_params"]["weight_decay"] *= lr_decay_factor
+    logger.warning(f"New learning rate: {training_hyperparams['initial_lr']}")
+    logger.warning(f"New weight decay: {training_hyperparams['optimizer_params']['weight_decay']}")
+    # as recommended by pytorch-quantization docs
+    original_lr_mode = get_param(training_hyperparams, "lr_mode")
+    if original_lr_mode != "cosine":
+        training_hyperparams["lr_mode"] = "cosine"
+    training_hyperparams["cosine_final_lr_ratio"] = cosine_final_lr_ratio
+    logger.warning(
+        f"lr_mode will be set to cosine for QAT run instead of {original_lr_mode} with "
+        f"cosine_final_lr_ratio={cosine_final_lr_ratio}"
+    )
+
+    training_hyperparams["lr_warmup_epochs"] = (training_hyperparams["max_epochs"] // warmup_epochs_divisor) or 1
+    logger.warning(f"New lr_warmup_epochs: {training_hyperparams['lr_warmup_epochs']}")
+
+    # do mess with Q/DQ
+    if get_param(training_hyperparams, "ema"):
+        logger.warning("EMA will be disabled for QAT run.")
+        training_hyperparams["ema"] = False
+    if get_param(training_hyperparams, "sync_bn"):
+        logger.warning("SyncBatchNorm will be disabled for QAT run.")
+        training_hyperparams["sync_bn"] = False
+    if disable_phase_callbacks and get_param(training_hyperparams, "phase_callbacks") is not None and len(training_hyperparams["phase_callbacks"]) > 0:
+        logger.warning(f"Recipe contains {len(training_hyperparams['phase_callbacks'])} phase callbacks. All of them will be disabled.")
+        training_hyperparams["phase_callbacks"] = []
+    # no augmentations
+    if disable_augmentations and "transforms" in val_dataset_params:
+        logger.warning("Augmentations will be disabled for QAT run. Using validation transforms instead.")
+        train_dataset_params["transforms"] = val_dataset_params["transforms"]
+
+    return training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
+
+
 @register_pre_launch_callback()
 class QATRecipeModificationCallback(PreLaunchCallback):
     """
@@ -209,7 +357,7 @@ class QATRecipeModificationCallback(PreLaunchCallback):
             disable_phase_callbacks: True
             disable_augmentations: False
 
-    USE THIS CALLBACK ONLY WITH QATTrainer!
+    USE THIS CALLBACK ONLY WITH Trainer.quantize_from_config
     """
 
     def __init__(
@@ -235,54 +383,31 @@ class QATRecipeModificationCallback(PreLaunchCallback):
 
         cfg = copy.deepcopy(cfg)
 
-        # Q/DQ Layers take a lot of space for activations in training mode
-        if cfg.quantization_params.selective_quantizer_params.learn_amax:
-            cfg.dataset_params.train_dataloader_params.batch_size //= self.batch_size_divisor
-            cfg.dataset_params.val_dataloader_params.batch_size //= self.batch_size_divisor
-
-            logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {cfg.dataset_params.train_dataloader_params.batch_size}")
-            logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {cfg.dataset_params.val_dataloader_params.batch_size}")
-
-        cfg.training_hyperparams.max_epochs //= self.max_epochs_divisor
-        logger.warning(f"New number of epochs: {cfg.training_hyperparams.max_epochs}")
-
-        cfg.training_hyperparams.initial_lr *= self.lr_decay_factor
-        if cfg.training_hyperparams.warmup_initial_lr is not None:
-            cfg.training_hyperparams.warmup_initial_lr *= self.lr_decay_factor
-        else:
-            cfg.training_hyperparams.warmup_initial_lr = cfg.training_hyperparams.initial_lr * 0.01
-
-        cfg.training_hyperparams.optimizer_params.weight_decay *= self.lr_decay_factor
-
-        logger.warning(f"New learning rate: {cfg.training_hyperparams.initial_lr}")
-        logger.warning(f"New weight decay: {cfg.training_hyperparams.optimizer_params.weight_decay}")
-
-        # as recommended by pytorch-quantization docs
-        cfg.training_hyperparams.lr_mode = "cosine"
-        cfg.training_hyperparams.lr_warmup_epochs = (cfg.training_hyperparams.max_epochs // self.warmup_epochs_divisor) or 1
-        cfg.training_hyperparams.cosine_final_lr_ratio = self.cosine_final_lr_ratio
-
-        # do mess with Q/DQ
-        if cfg.training_hyperparams.ema:
-            logger.warning("EMA will be disabled for QAT run.")
-            cfg.training_hyperparams.ema = False
-
-        if cfg.training_hyperparams.sync_bn:
-            logger.warning("SyncBatchNorm will be disabled for QAT run.")
-            cfg.training_hyperparams.sync_bn = False
-
-        if self.disable_phase_callbacks and len(cfg.training_hyperparams.phase_callbacks) > 0:
-            logger.warning(f"Recipe contains {len(cfg.training_hyperparams.phase_callbacks)} phase callbacks. All of them will be disabled.")
-            cfg.training_hyperparams.phase_callbacks = []
+        (
+            cfg.training_hyperparams,
+            cfg.dataset_params.train_dataset_params,
+            cfg.dataset_params.val_dataset_params,
+            cfg.dataset_params.train_dataloader_params,
+            cfg.dataset_params.val_dataloader_params,
+        ) = modify_params_for_qat(
+            training_hyperparams=cfg.training_hyperparams,
+            train_dataset_params=cfg.dataset_params.train_dataset_params,
+            train_dataloader_params=cfg.dataset_params.train_dataloader_params,
+            val_dataset_params=cfg.dataset_params.val_dataset_params,
+            val_dataloader_params=cfg.dataset_params.val_dataloader_params,
+            quantization_params=cfg.quantization_params,
+            batch_size_divisor=self.batch_size_divisor,
+            disable_phase_callbacks=self.disable_phase_callbacks,
+            cosine_final_lr_ratio=self.cosine_final_lr_ratio,
+            warmup_epochs_divisor=self.warmup_epochs_divisor,
+            lr_decay_factor=self.lr_decay_factor,
+            max_epochs_divisor=self.max_epochs_divisor,
+            disable_augmentations=self.disable_augmentations,
+        )
 
         if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1:
             logger.warning(f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. Changing to multi_gpu=OFF and num_gpus=1")
             cfg.multi_gpu = "OFF"
             cfg.num_gpus = 1
 
-        # no augmentations
-        if self.disable_augmentations and "transforms" in cfg.dataset_params.val_dataset_params:
-            logger.warning("Augmentations will be disabled for QAT run.")
-            cfg.dataset_params.train_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
-
         return cfg
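For reviewers, a minimal sketch of calling the new modify_params_for_qat helper on plain dictionaries, outside of any recipe. The hyperparameter values below are illustrative only; the import path follows the __init__ changes above, and default quantization params are loaded when none are passed:

from super_gradients.training import modify_params_for_qat

# Illustrative values only; any recipe-like dicts containing the required keys
# (max_epochs, initial_lr, optimizer_params.weight_decay) will work.
training_hyperparams = {
    "max_epochs": 300,
    "initial_lr": 0.1,
    "lr_mode": "step",
    "ema": True,
    "optimizer_params": {"weight_decay": 1e-4},
}
train_dataset_params = {"transforms": []}
val_dataset_params = {"transforms": []}
train_dataloader_params = {"batch_size": 64}
val_dataloader_params = {"batch_size": 64}

# Returns modified copies: max_epochs // 10, initial_lr * 0.01, cosine schedule,
# EMA disabled, and warmup derived from the shortened schedule.
(
    training_hyperparams,
    train_dataset_params,
    val_dataset_params,
    train_dataloader_params,
    val_dataloader_params,
) = modify_params_for_qat(
    training_hyperparams,
    train_dataset_params,
    val_dataset_params,
    train_dataloader_params,
    val_dataloader_params,
)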
@@ -1,201 +1,17 @@
-import os
 from typing import Union, Tuple
 
-import copy
-import hydra
-import torch.cuda
+from deprecated import deprecated
 from omegaconf import DictConfig
-from omegaconf import OmegaConf
 from torch import nn
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment.device_utils import device_config
-from super_gradients.training import utils as core_utils, models, dataloaders
 from super_gradients.training.sg_trainer import Trainer
-from super_gradients.training.utils import get_param
-from super_gradients.training.utils.distributed_training_utils import setup_device
-from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
 
 logger = get_logger(__name__)
-try:
-    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
-    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
-    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
-
-    _imported_pytorch_quantization_failure = None
-
-except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.debug("Failed to import pytorch_quantization:")
-    logger.debug(import_err)
-    _imported_pytorch_quantization_failure = import_err
 
 
 class QATTrainer(Trainer):
     @classmethod
-    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
-        """
-        Perform quantization aware training (QAT) according to a recipe configuration.
-
-        This method will instantiate all the objects specified in the recipe, build and quantize the model,
-        and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
-        method will be returned.
-
-        The quantized model will be exported to ONNX along with other checkpoints.
-
-        :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
-        :return: A tuple containing the quantized model and the output of trainer.train() method.
-        :rtype: Tuple[nn.Module, Tuple]
-
-        :raises ValueError: If the recipe does not have the required key `quantization_params` or
-        `checkpoint_params.checkpoint_path` in it.
-        :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
-        :raises ImportError: If pytorch-quantization import was unsuccessful
-
-        """
-        if _imported_pytorch_quantization_failure is not None:
-            raise _imported_pytorch_quantization_failure
-
-        # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
-
-        # TRIGGER CFG MODIFYING CALLBACKS
-        cfg = cls._trigger_cfg_modifying_callbacks(cfg)
-
-        if "quantization_params" not in cfg:
-            raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")
-
-        if "checkpoint_path" not in cfg.checkpoint_params:
-            raise ValueError("Starting checkpoint is a must for QAT finetuning.")
-
-        num_gpus = core_utils.get_param(cfg, "num_gpus")
-        multi_gpu = core_utils.get_param(cfg, "multi_gpu")
-        device = core_utils.get_param(cfg, "device")
-        if num_gpus != 1:
-            raise NotImplementedError(
-                f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
-            )
-
-        setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)
-
-        # INSTANTIATE DATA LOADERS
-        train_dataloader = dataloaders.get(
-            name=get_param(cfg, "train_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
-        )
-
-        val_dataloader = dataloaders.get(
-            name=get_param(cfg, "val_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
-        )
-
-        if "calib_dataloader" in cfg:
-            calib_dataloader_name = get_param(cfg, "calib_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
-        else:
-            calib_dataloader_name = get_param(cfg, "train_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)
-
-            # if we use whole dataloader for calibration, don't shuffle it
-            # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
-            # if we use several batches, we don't want it to be from one class if it's sequential in dataloader
-            # model is in eval mode, so BNs will not be affected
-            calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
-            # we don't need training transforms during calibration, distribution of activations will be skewed
-            calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
-
-        calib_dataloader = dataloaders.get(
-            name=calib_dataloader_name,
-            dataset_params=calib_dataset_params,
-            dataloader_params=calib_dataloader_params,
-        )
-
-        # BUILD MODEL
-        model = models.get(
-            model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
-            num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
-            arch_params=cfg.arch_params,
-            strict_load=cfg.checkpoint_params.strict_load,
-            pretrained_weights=cfg.checkpoint_params.pretrained_weights,
-            checkpoint_path=cfg.checkpoint_params.checkpoint_path,
-            load_backbone=False,
-        )
-        model.to(device_config.device)
-
-        # QUANTIZE MODEL
-        model.eval()
-        fuse_repvgg_blocks_residual_branches(model)
-
-        q_util = SelectiveQuantizer(
-            default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
-            default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
-            default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
-            default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
-            verbose=cfg.quantization_params.calib_params.verbose,
-        )
-        q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
-        q_util.quantize_module(model)
-
-        # CALIBRATE MODEL
-        logger.info("Calibrating model...")
-        calibrator = QuantizationCalibrator(
-            verbose=cfg.quantization_params.calib_params.verbose,
-            torch_hist=True,
-        )
-        calibrator.calibrate_model(
-            model,
-            method=cfg.quantization_params.calib_params.histogram_calib_method,
-            calib_data_loader=calib_dataloader,
-            num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
-            percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
-        )
-        calibrator.reset_calibrators(model)  # release memory taken by calibrators
-
-        # VALIDATE PTQ MODEL AND PRINT SUMMARY
-        logger.info("Validating PTQ model...")
-        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-        valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
-        results = ["PTQ Model Validation Results"]
-        results += [f"   - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
-        logger.info("\n".join(results))
-
-        # TRAIN
-        if cfg.quantization_params.ptq_only:
-            logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
-            suffix = "ptq"
-            res = None
-        else:
-            model.train()
-            recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
-            trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-            torch.cuda.empty_cache()
-
-            res = trainer.train(
-                model=model,
-                train_loader=train_dataloader,
-                valid_loader=val_dataloader,
-                training_params=cfg.training_hyperparams,
-                additional_configs_to_log=recipe_logged_cfg,
-            )
-            suffix = "qat"
-
-        # EXPORT QUANTIZED MODEL TO ONNX
-        input_shape = next(iter(val_dataloader))[0].shape
-        os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)
-
-        qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
-        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
-        export_quantized_module_to_onnx(
-            model=model.cpu(),
-            onnx_filename=qdq_onnx_path,
-            input_shape=input_shape,
-            input_size=input_shape,
-            train=False,
-        )
-
-        logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")
-
-        return model, res
+    @deprecated(version="3.2.0", reason="QATTrainer is deprecated and will be removed in future release, use Trainer class instead.")
+    def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
+        return Trainer.quantize_from_config(cfg)
@@ -1,24 +1,28 @@
+import copy
 import inspect
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Union, Tuple, Mapping, Dict, Any
+from typing import Union, Tuple, Mapping, Dict, Any, List
 
 import hydra
 import numpy as np
 import torch
-from omegaconf import DictConfig
-from omegaconf import OmegaConf
+import torch.cuda
+import torch.nn
+import torchmetrics
+from omegaconf import DictConfig, OmegaConf
 from piptools.scripts.sync import _get_installed_distributions
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.utils.data import DataLoader, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
-from torchmetrics import MetricCollection
+from torchmetrics import MetricCollection, Metric
 from tqdm import tqdm
 
 from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_ckpt_local_path
 from super_gradients.module_interfaces import HasPreprocessingParams, HasPredict
+from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
 
 from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -83,13 +87,26 @@ from super_gradients.training.utils.callbacks import (
 from super_gradients.common.registry.registry import LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT
 from super_gradients.common.environment.device_utils import device_config
 from super_gradients.training.utils import HpmStruct
-from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg
+from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg, load_recipe
 from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
 from super_gradients.training.params import TrainingParams
 
 logger = get_logger(__name__)
 
 
+try:
+    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
+    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
+    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
+
+    _imported_pytorch_quantization_failure = None
+
+except (ImportError, NameError, ModuleNotFoundError) as import_err:
+    logger.debug("Failed to import pytorch_quantization:")
+    logger.debug(import_err)
+    _imported_pytorch_quantization_failure = import_err
+
+
 class Trainer:
     """
     SuperGradient Model - Base Class for Sg Models
@@ -2018,3 +2035,337 @@ class Trainer:
                 self.metric_to_watch = criterion_name + "/" + self.metric_to_watch
         else:
             self.loss_logging_items_names = [criterion_name]
+
+    @classmethod
+    def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
+        """
+        Perform quantization aware training (QAT) according to a recipe configuration.
+
+        This method will instantiate all the objects specified in the recipe, build and quantize the model,
+        and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
+        method will be returned.
+
+        The quantized model will be exported to ONNX along with other checkpoints.
+
+        The call to self.qat or self.ptq (see docs in the methods below) is done with the created
+         train_loader and valid_loader. If no calibration data loader is specified through cfg.calib_dataloader,
+         a train data loader with the validation transforms is used for calibration.
+
+        :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
+        :return: A tuple containing the quantized model and the output of trainer.train() method.
+
+        :raises ValueError: If the recipe does not have the required key `quantization_params` or
+        `checkpoint_params.checkpoint_path` in it.
+        :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
+        :raises ImportError: If pytorch-quantization import was unsuccessful
+
+        """
+        if _imported_pytorch_quantization_failure is not None:
+            raise _imported_pytorch_quantization_failure
+
+        # INSTANTIATE ALL OBJECTS IN CFG
+        cfg = hydra.utils.instantiate(cfg)
+
+        # TRIGGER CFG MODIFYING CALLBACKS
+        cfg = cls._trigger_cfg_modifying_callbacks(cfg)
+
+        quantization_params = get_param(cfg, "quantization_params")
+
+        if quantization_params is None:
+            logger.warning("Your recipe does not include quantization_params. Using default quantization params.")
+            quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
+
+        if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None:
+            raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.")
+
+        num_gpus = core_utils.get_param(cfg, "num_gpus")
+        multi_gpu = core_utils.get_param(cfg, "multi_gpu")
+        device = core_utils.get_param(cfg, "device")
+        if num_gpus != 1:
+            raise NotImplementedError(
+                f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
+            )
+
+        setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)
+
+        # INSTANTIATE DATA LOADERS
+        train_dataloader = dataloaders.get(
+            name=get_param(cfg, "train_dataloader"),
+            dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
+            dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
+        )
+
+        val_dataloader = dataloaders.get(
+            name=get_param(cfg, "val_dataloader"),
+            dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
+            dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
+        )
+
+        if "calib_dataloader" in cfg:
+            calib_dataloader_name = get_param(cfg, "calib_dataloader")
+            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
+            calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
+        else:
+            calib_dataloader_name = get_param(cfg, "train_dataloader")
+            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
+            calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)
+
+            # if we use whole dataloader for calibration, don't shuffle it
+            # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
+            # if we use several batches, we don't want it to be from one class if it's sequential in dataloader
+            # model is in eval mode, so BNs will not be affected
+            calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
+            # we don't need training transforms during calibration, distribution of activations will be skewed
+            calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
+
+        calib_dataloader = dataloaders.get(
+            name=calib_dataloader_name,
+            dataset_params=calib_dataset_params,
+            dataloader_params=calib_dataloader_params,
+        )
+
+        # BUILD MODEL
+        model = models.get(
+            model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
+            num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
+            arch_params=cfg.arch_params,
+            strict_load=cfg.checkpoint_params.strict_load,
+            pretrained_weights=cfg.checkpoint_params.pretrained_weights,
+            checkpoint_path=cfg.checkpoint_params.checkpoint_path,
+            load_backbone=False,
+        )
+
+        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
+        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir"))
+
+        if quantization_params.ptq_only:
+            res = trainer.ptq(
+                calib_loader=calib_dataloader,
+                model=model,
+                quantization_params=quantization_params,
+                valid_loader=val_dataloader,
+                valid_metrics_list=cfg.training_hyperparams.valid_metrics_list,
+            )
+        else:
+            res = trainer.qat(
+                model=model,
+                quantization_params=quantization_params,
+                calib_loader=calib_dataloader,
+                valid_loader=val_dataloader,
+                train_loader=train_dataloader,
+                training_params=cfg.training_hyperparams,
+                additional_qat_configs_to_log=recipe_logged_cfg,
+            )
+
+        return model, res
+
+    def qat(
+        self,
+        calib_loader: DataLoader,
+        model: torch.nn.Module,
+        valid_loader: DataLoader,
+        train_loader: DataLoader,
+        training_params: Mapping = None,
+        quantization_params: Mapping = None,
+        additional_qat_configs_to_log: Dict = None,
+        valid_metrics_list: List[Metric] = None,
+    ):
+        """
+        Performs post-training quantization (PTQ), and then quantization-aware training (QAT).
+        Exports the ONNX models (ckpt_best.pth of QAT and the calibrated model) to the checkpoints directory.
+
+        :param calib_loader: DataLoader, data loader for calibration.
+
+        :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from
+        previous self.train call (that is, if there was one, will try to use self.ema_model.ema if EMA was used,
+        otherwise self.net)
+
+
+        :param valid_loader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT.
+            When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None).
+
+        :param train_loader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None).
+
+        :param quantization_params: Mapping with the following entries (defaults shown below):
+            selective_quantizer_params:
+              calibrator_w: "max"        # calibrator type for weights, acceptable types are ["max", "histogram"]
+              calibrator_i: "histogram"  # calibrator type for inputs, acceptable types are ["max", "histogram"]
+              per_channel: True          # per-channel quantization of weights, activations stay per-tensor by default
+              learn_amax: False          # enable learnable amax in all TensorQuantizers using straight-through estimator
+              skip_modules:              # optional list of module names (strings) to skip from quantization
+
+            calib_params:
+              histogram_calib_method: "percentile"  # calibration method for all "histogram" calibrators,
+                                                                # acceptable types are ["percentile", "entropy", "mse"],
+                                                                # "max" calibrators always use "max"
+
+              percentile: 99.99                     # percentile for all histogram calibrators with method "percentile",
+                                                    # other calibrators are not affected
+
+              num_calib_batches:                    # number of batches to use for calibration, if None, 512 / batch_size will be used
+              verbose: False                        # if calibrator should be verbose
+
+
+              When None, the above default config is used (default=None)
+
+
+        :param training_params: Mapping, training hyperparameters for QAT, same as in Trainer.train(...). When None, will try to use self.training_params
+         which is set in previous self.train(..) call (default=None).
+
+        :param additional_qat_configs_to_log: Dict, Optional dictionary containing configs that will be added to the QA training's
+         sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}.
+
+        :param valid_metrics_list:  (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model.
+        When None, the validation metrics from training_params are used (default=None).
+
+        :return: Validation results of the QAT model in case quantization_params['ptq_only']=False and of the PTQ
+        model otherwise.
+        """
+
+        if quantization_params is None:
+            quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
+            logger.info(f"Using default quantization params: {quantization_params}")
+        valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list")
+
+        _ = self.ptq(
+            calib_loader=calib_loader,
+            model=model,
+            quantization_params=quantization_params,
+            valid_loader=valid_loader,
+            valid_metrics_list=valid_metrics_list,
+            deepcopy_model_for_export=True,
+        )
+        # TRAIN
+        model.train()
+        torch.cuda.empty_cache()
+
+        res = self.train(
+            model=model,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            training_params=training_params,
+            additional_configs_to_log=additional_qat_configs_to_log,
+        )
+
+        # EXPORT QUANTIZED MODEL TO ONNX
+        input_shape = next(iter(valid_loader))[0].shape
+        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
+        qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_qat.onnx")
+
+        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
+        export_quantized_module_to_onnx(
+            model=model.cpu(),
+            onnx_filename=qdq_onnx_path,
+            input_shape=input_shape,
+            input_size=input_shape,
+            train=False,
+        )
+        logger.info(f"Exported QAT ONNX to {qdq_onnx_path}")
+        return res
+
+    def ptq(
+        self,
+        calib_loader: DataLoader,
+        model: nn.Module,
+        valid_loader: DataLoader,
+        valid_metrics_list: List[torchmetrics.Metric],
+        quantization_params: Dict = None,
+        deepcopy_model_for_export: bool = False,
+    ):
+        """
+        Performs post-training quantization (calibration of the model).
+
+        :param calib_loader: DataLoader, data loader for calibration.
+
+        :param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which is
+        set in previous self.train(..) call (default=None).
+
+        :param valid_loader: DataLoader, data loader for validation, used for validating the calibrated model.
+            When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None).
+
+        :param quantization_params: Mapping with the following entries (defaults shown below):
+            selective_quantizer_params:
+              calibrator_w: "max"        # calibrator type for weights, acceptable types are ["max", "histogram"]
+              calibrator_i: "histogram"  # calibrator type for inputs, acceptable types are ["max", "histogram"]
+              per_channel: True          # per-channel quantization of weights, activations stay per-tensor by default
+              learn_amax: False          # enable learnable amax in all TensorQuantizers using straight-through estimator
+              skip_modules:              # optional list of module names (strings) to skip from quantization
+
+            calib_params:
+              histogram_calib_method: "percentile"  # calibration method for all "histogram" calibrators,
+                                                                # acceptable types are ["percentile", "entropy", "mse"],
+                                                                # "max" calibrators always use "max"
+
+              percentile: 99.99                     # percentile for all histogram calibrators with method "percentile",
+                                                    # other calibrators are not affected
+
+              num_calib_batches:                    # number of batches to use for calibration, if None, 512 / batch_size will be used
+              verbose: False                        # if calibrator should be verbose
+
+
+              When None, the above default config is used (default=None)
+
+
+
+        :param valid_metrics_list:  (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model.
+
+        :param deepcopy_model_for_export: bool, Whether to export deepcopy(model). Necessary in case further training is
+            performed and prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks).
+
+        :return: Validation results of the calibrated model.
+        """
+
+        if quantization_params is None:
+            quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
+            logger.info(f"Using default quantization params: {quantization_params}")
+
+        selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params")
+        calib_params = get_param(quantization_params, "calib_params")
+        model.to(device_config.device)
+        # QUANTIZE MODEL
+        model.eval()
+        fuse_repvgg_blocks_residual_branches(model)
+        q_util = SelectiveQuantizer(
+            default_quant_modules_calibrator_weights=get_param(selective_quantizer_params, "calibrator_w"),
+            default_quant_modules_calibrator_inputs=get_param(selective_quantizer_params, "calibrator_i"),
+            default_per_channel_quant_weights=get_param(selective_quantizer_params, "per_channel"),
+            default_learn_amax=get_param(selective_quantizer_params, "learn_amax"),
+            verbose=get_param(calib_params, "verbose"),
+        )
+        q_util.register_skip_quantization(layer_names=get_param(selective_quantizer_params, "skip_modules"))
+        q_util.quantize_module(model)
+        # CALIBRATE MODEL
+        logger.info("Calibrating model...")
+        calibrator = QuantizationCalibrator(
+            verbose=get_param(calib_params, "verbose"),
+            torch_hist=True,
+        )
+        calibrator.calibrate_model(
+            model,
+            method=get_param(calib_params, "histogram_calib_method"),
+            calib_data_loader=calib_loader,
+            num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_loader),
+            percentile=get_param(calib_params, "percentile", 99.99),
+        )
+        calibrator.reset_calibrators(model)  # release memory taken by calibrators
+        # VALIDATE PTQ MODEL AND PRINT SUMMARY
+        logger.info("Validating PTQ model...")
+        valid_metrics_dict = self.test(model=model, test_loader=valid_loader, test_metrics_list=valid_metrics_list)
+        results = ["PTQ Model Validation Results"]
+        results += [f"   - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
+        logger.info("\n".join(results))
+
+        input_shape = next(iter(valid_loader))[0].shape
+        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
+        qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_ptq.onnx")
+
+        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
+        export_quantized_module_to_onnx(
+            model=model.cpu(),
+            onnx_filename=qdq_onnx_path,
+            input_shape=input_shape,
+            input_size=input_shape,
+            train=False,
+            deepcopy_model=deepcopy_model_for_export,
+        )
+
+        return valid_metrics_dict
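As a companion to the quantize_from_config path, a minimal sketch of the new Trainer.ptq entry point used directly, assuming a CUDA device and the pytorch-quantization extra are installed. The quantization_params mapping mirrors the defaults documented in the docstrings above, with num_calib_batches reduced for speed; the CIFAR-10 loaders and ResNet18 follow the new test below:

from torchvision.transforms import Normalize, ToTensor

from super_gradients import Trainer
from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18

trainer = Trainer("ptq_with_explicit_quantization_params")
net = ResNet18(num_classes=10, arch_params={})

# Calibration uses validation-style transforms so activation statistics are not skewed.
val_transforms = [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]
calib_loader = cifar10_train(dataset_params={"transforms": val_transforms}, dataloader_params={"batch_size": 256})
valid_loader = cifar10_val(dataset_params={"transforms": val_transforms}, dataloader_params={"batch_size": 256})

# Mirrors default_quantization_params; num_calib_batches kept small for a quick run.
quantization_params = {
    "selective_quantizer_params": {
        "calibrator_w": "max",
        "calibrator_i": "histogram",
        "per_channel": True,
        "learn_amax": False,
        "skip_modules": None,
    },
    "calib_params": {
        "histogram_calib_method": "percentile",
        "percentile": 99.99,
        "num_calib_batches": 16,
        "verbose": False,
    },
}

valid_metrics = trainer.ptq(
    calib_loader=calib_loader,
    model=net,
    valid_loader=valid_loader,
    valid_metrics_list=[Accuracy(), Top5()],
    quantization_params=quantization_params,
)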
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import torch
 from torch.onnx import TrainingMode
 
@@ -14,7 +16,9 @@ except (ImportError, NameError, ModuleNotFoundError) as import_err:
     _imported_pytorch_quantization_failure = import_err
 
 
-def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, **kwargs):
+def export_quantized_module_to_onnx(
+    model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, deepcopy_model=False, **kwargs
+):
     """
     Method for exporting onnx after QAT.
 
@@ -23,10 +27,15 @@ def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str,
     :param model: torch.nn.Module, model to export
     :param onnx_filename: str, target path for the onnx file,
     :param input_shape: tuple, input shape (usually BCHW)
+    :param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and
+     prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks).
     """
     if _imported_pytorch_quantization_failure is not None:
         raise _imported_pytorch_quantization_failure
 
+    if deepcopy_model:
+        model = deepcopy(model)
+
     use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
     quant_nn.TensorQuantizer.use_fb_fake_quant = True
 
@@ -2,6 +2,7 @@ import sys
 import unittest
 
 from tests.recipe_training_tests.automatic_batch_selection_single_gpu_test import TestAutoBatchSelectionSingleGPU
+from tests.recipe_training_tests.coded_qat_launch_test import CodedQATLuanchTest
 from tests.recipe_training_tests.shortened_recipes_accuracy_test import ShortenedRecipesAccuracyTests
 
 
@@ -17,6 +18,7 @@ class CoreUnitTestSuiteRunner:
         _add_modules_to_unit_tests_suite - Adds unit tests to the Unit Tests Test Suite
             :return:
         """
+        self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(CodedQATLuanchTest))
         self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(ShortenedRecipesAccuracyTests))
         self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestAutoBatchSelectionSingleGPU))
 
import unittest

from torchvision.transforms import Normalize, ToTensor, RandomHorizontalFlip, RandomCrop

from super_gradients import Trainer
from super_gradients.training import modify_params_for_qat
from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18


class CodedQATLuanchTest(unittest.TestCase):
    def test_qat_launch(self):
        trainer = Trainer("test_launch_qat_with_minimal_changes")
        net = ResNet18(num_classes=10, arch_params={})
        train_params = {
            "max_epochs": 10,
            "lr_updates": [],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy(), Top5()],
            "valid_metrics_list": [Accuracy(), Top5()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
            "ema": True,
        }

        train_dataset_params = {
            "transforms": [
                RandomCrop(size=32, padding=4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
            ]
        }

        train_dataloader_params = {"batch_size": 256}

        val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]}

        val_dataloader_params = {"batch_size": 256}

        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

        trainer.train(
            model=net,
            training_params=train_params,
            train_loader=train_loader,
            valid_loader=valid_loader,
        )

        train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat(
            train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
        )

        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

        trainer.qat(
            model=net,
            training_params=train_params,
            train_loader=train_loader,
            valid_loader=valid_loader,
            calib_loader=train_loader,
        )

    def test_ptq_launch(self):
        trainer = Trainer("test_launch_ptq_with_minimal_changes")
        net = ResNet18(num_classes=10, arch_params={})
        train_params = {
            "max_epochs": 10,
            "lr_updates": [],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy(), Top5()],
            "valid_metrics_list": [Accuracy(), Top5()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
            "ema": True,
        }

        train_dataset_params = {
            "transforms": [
                RandomCrop(size=32, padding=4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
            ]
        }

        train_dataloader_params = {"batch_size": 256}

        val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]}

        val_dataloader_params = {"batch_size": 256}

        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

        trainer.train(
            model=net,
            training_params=train_params,
            train_loader=train_loader,
            valid_loader=valid_loader,
        )

        train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat(
            train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params
        )

        train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
        valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params)

        trainer.ptq(model=net, valid_loader=valid_loader, calib_loader=train_loader, valid_metrics_list=train_params["valid_metrics_list"])


if __name__ == "__main__":
    unittest.main()