Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#614 Feature/sg 493 modelnames instead of strings

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-493_modelnames_instead_of_strings
@@ -3,6 +3,8 @@ import torch
 import torchvision
 import torchvision
 from torch import nn
 from torch import nn
 
 
+from super_gradients.common.object_names import Models
+
 try:
 try:
     import super_gradients
     import super_gradients
     from pytorch_quantization import nn as quant_nn
     from pytorch_quantization import nn as quant_nn
@@ -787,7 +789,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         if Bottleneck in sq.mapping_instructions:
         if Bottleneck in sq.mapping_instructions:
             sq.mapping_instructions.pop(Bottleneck)
             sq.mapping_instructions.pop(Bottleneck)
 
 
-        resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_sg: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
 
         # PYTORCH-QUANTIZATION
         # PYTORCH-QUANTIZATION
@@ -803,7 +805,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
         quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
 
         quant_modules.initialize()
         quant_modules.initialize()
-        resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
+        resnet_pyquant: nn.Module = super_gradients.training.models.get(Models.RESNET50, pretrained_weights="imagenet", num_classes=1000)
 
 
         quant_modules.deactivate()
         quant_modules.deactivate()
 
 
Discard
Tip!

Press p or to see the previous file or, n or to see the next file