Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#755 Feature/sg 672 add onnx simplifier

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-672_add_onnx_simplifier
@@ -35,3 +35,4 @@ stringcase>=1.2.0
 numpy<=1.23
 numpy<=1.23
 rapidfuzz
 rapidfuzz
 json-tricks==3.16.1
 json-tricks==3.16.1
+onnx-simplifier>=0.3.6,<1.0
Discard
@@ -17,3 +17,4 @@ pre_process: # Preprocessing pipeline, will be resolved by TransformsFactory(),
 post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
 prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
 torch_onnx_export_kwargs: # kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
 torch_onnx_export_kwargs: # kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
+simplify: True # whether to apply onnx simplifier method, same as `python -m onnxsim onnx_path onnx_sim_path. When true, the simplified models will be saved in out_path.
Discard
@@ -15,6 +15,8 @@ from super_gradients.common.environment.cfg_utils import load_experiment_cfg
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 import os
 import os
 import pathlib
 import pathlib
+from onnxsim import simplify
+import onnx
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -55,6 +57,7 @@ def convert_to_onnx(
     post_process: torch.nn.Module = None,
     post_process: torch.nn.Module = None,
     prep_model_for_conversion_kwargs=None,
     prep_model_for_conversion_kwargs=None,
     torch_onnx_export_kwargs=None,
     torch_onnx_export_kwargs=None,
+    simplify: bool = True,
 ):
 ):
     """
     """
     Exports model to ONNX.
     Exports model to ONNX.
@@ -67,6 +70,8 @@ def convert_to_onnx(
     :param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
     :param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
      prior to torch.onnx.export call.
      prior to torch.onnx.export call.
     :param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
     :param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
+    :param simplify: bool,whether to apply onnx simplifier method, same as `python -m onnxsim onnx_path onnx_sim_path.
+     When true, the simplified model will be saved in out_path (default=True).
 
 
     :return: out_path
     :return: out_path
     """
     """
@@ -80,6 +85,8 @@ def convert_to_onnx(
     complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
     complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
 
 
     torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
     torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
+    if simplify:
+        onnx_simplify(out_path, out_path)
     return out_path
     return out_path
 
 
 
 
@@ -129,3 +136,15 @@ def convert_from_config(cfg: DictConfig) -> str:
     out_path = models.convert_to_onnx(model=model, **cfg)
     out_path = models.convert_to_onnx(model=model, **cfg)
     logger.info(f"Successfully exported model at {out_path}")
     logger.info(f"Successfully exported model at {out_path}")
     return out_path
     return out_path
+
+
+def onnx_simplify(onnx_path: str, onnx_sim_path: str):
+    """
+    onnx simplifier method, same as `python -m onnxsim onnx_path onnx_sim_path
+    :param onnx_path: path to onnx model
+    :param onnx_sim_path: path for output onnx simplified model
+    """
+    model_sim, check = simplify(model=onnx_path)
+    if not check:
+        raise RuntimeError("Simplified ONNX model could not be validated")
+    onnx.save_model(model_sim, onnx_sim_path)
Discard