Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
  1. import torch
  2. from torch.onnx import TrainingMode
  3. from super_gradients.common.abstractions.abstract_logger import get_logger
  4. logger = get_logger(__name__)
  5. try:
  6. from pytorch_quantization import nn as quant_nn
  7. _imported_pytorch_quantization_failure = None
  8. except (ImportError, NameError, ModuleNotFoundError) as import_err:
  9. logger.warning("Failed to import pytorch_quantization")
  10. _imported_pytorch_quantization_failure = import_err
  11. def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, **kwargs):
  12. """
  13. Method for exporting onnx after QAT.
  14. :param model: torch.nn.Module, model to export
  15. :param onnx_filename: str, target path for the onnx file,
  16. :param input_shape: tuple, input shape (usually BCHW)
  17. """
  18. if _imported_pytorch_quantization_failure is not None:
  19. raise _imported_pytorch_quantization_failure
  20. use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
  21. quant_nn.TensorQuantizer.use_fb_fake_quant = True
  22. # Export ONNX for multiple batch sizes
  23. logger.info("Creating ONNX file: " + onnx_filename)
  24. dummy_input = torch.randn(input_shape, device=next(model.parameters()).device)
  25. if train:
  26. training_mode = TrainingMode.TRAINING
  27. model.train()
  28. else:
  29. training_mode = TrainingMode.EVAL
  30. model.eval()
  31. if hasattr(model, "prep_model_for_conversion"):
  32. model.prep_model_for_conversion(**kwargs)
  33. torch.onnx.export(
  34. model, dummy_input, onnx_filename, verbose=False, opset_version=13, enable_onnx_checker=False, do_constant_folding=True, training=training_mode
  35. )
  36. # Restore functions of quant_nn back as expected
  37. quant_nn.TensorQuantizer.use_fb_fake_quant = use_fb_fake_quant_state
Discard
Tip!

Press p to see the previous file, or n to see the next file