Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#537 Quantization infra mods for different calibrators and learnable amax

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/AL-706-selective-qat
1 changed files with 16 additions and 4 deletions
  1. 16
    4
      tests/unit_tests/quantization_utility_tests.py
@@ -412,7 +412,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         module = MyModel()
         module = MyModel()
 
 
         # TEST
         # TEST
-        q_util = SelectiveQuantizer(default_quant_modules_calib_method="max")
+        q_util = SelectiveQuantizer(default_quant_modules_calib_method_inputs="max", default_quant_modules_calib_method_weights="max")
         q_util.quantize_module(module)
         q_util.quantize_module(module)
 
 
         x = torch.rand(1, 3, 32, 32)
         x = torch.rand(1, 3, 32, 32)
@@ -730,7 +730,7 @@ class QuantizationUtilityTest(unittest.TestCase):
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                 ),
                 ),
             },
             },
-            default_per_channel_quant_modules=True,
+            default_per_channel_quant_weights=True,
         )
         )
 
 
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
@@ -754,8 +754,9 @@ class QuantizationUtilityTest(unittest.TestCase):
             torch.testing.assert_close(y_sg, y_pyquant)
             torch.testing.assert_close(y_sg, y_pyquant)
 
 
     def test_sg_resnet_sg_vanilla_quantization_matches_pytorch_quantization(self):
     def test_sg_resnet_sg_vanilla_quantization_matches_pytorch_quantization(self):
-
         # SG SELECTIVE QUANTIZATION
         # SG SELECTIVE QUANTIZATION
+        from super_gradients.training.models.classification_models.resnet import Bottleneck
+
         sq = SelectiveQuantizer(
         sq = SelectiveQuantizer(
             custom_mappings={
             custom_mappings={
                 torch.nn.Conv2d: QuantizedMetadata(
                 torch.nn.Conv2d: QuantizedMetadata(
@@ -779,16 +780,27 @@ class QuantizationUtilityTest(unittest.TestCase):
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                 ),
                 ),
             },
             },
-            default_per_channel_quant_modules=True,
+            default_per_channel_quant_weights=True,
         )
         )
 
 
+        # SG registers non-naive QuantBottleneck that will have different behaviour, pop it for testing purposes
+        if Bottleneck in sq.mapping_instructions:
+            sq.mapping_instructions.pop(Bottleneck)
+
         resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
         resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
 
         # PYTORCH-QUANTIZATION
         # PYTORCH-QUANTIZATION
         quant_desc_input = QuantDescriptor(calib_method="histogram")
         quant_desc_input = QuantDescriptor(calib_method="histogram")
+        quant_desc_weights = QuantDescriptor(calib_method="max", axis=0)
+
         quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
         quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
+        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weights)
+
         quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
         quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
+        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weights)
+
+        quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
 
         quant_modules.initialize()
         quant_modules.initialize()
         resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
         resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
Discard