|
@@ -15,6 +15,8 @@ from super_gradients.common.environment.cfg_utils import load_experiment_cfg
|
|
from super_gradients.training.utils.sg_trainer_utils import parse_args
|
|
from super_gradients.training.utils.sg_trainer_utils import parse_args
|
|
import os
|
|
import os
|
|
import pathlib
|
|
import pathlib
|
|
|
|
+from onnxsim import simplify
|
|
|
|
+import onnx
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@@ -55,6 +57,7 @@ def convert_to_onnx(
|
|
post_process: torch.nn.Module = None,
|
|
post_process: torch.nn.Module = None,
|
|
prep_model_for_conversion_kwargs=None,
|
|
prep_model_for_conversion_kwargs=None,
|
|
torch_onnx_export_kwargs=None,
|
|
torch_onnx_export_kwargs=None,
|
|
|
|
+ simplify: bool = True,
|
|
):
|
|
):
|
|
"""
|
|
"""
|
|
Exports model to ONNX.
|
|
Exports model to ONNX.
|
|
@@ -67,6 +70,8 @@ def convert_to_onnx(
|
|
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
|
|
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
|
|
prior to torch.onnx.export call.
|
|
prior to torch.onnx.export call.
|
|
:param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
|
|
:param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
|
|
|
|
+ :param simplify: bool,whether to apply onnx simplifier method, same as `python -m onnxsim onnx_path onnx_sim_path.
|
|
|
|
+ When true, the simplified model will be saved in out_path (default=True).
|
|
|
|
|
|
:return: out_path
|
|
:return: out_path
|
|
"""
|
|
"""
|
|
@@ -80,6 +85,8 @@ def convert_to_onnx(
|
|
complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
|
|
complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
|
|
|
|
|
|
torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
|
|
torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
|
|
|
|
+ if simplify:
|
|
|
|
+ onnx_simplify(out_path, out_path)
|
|
return out_path
|
|
return out_path
|
|
|
|
|
|
|
|
|
|
@@ -129,3 +136,15 @@ def convert_from_config(cfg: DictConfig) -> str:
|
|
out_path = models.convert_to_onnx(model=model, **cfg)
|
|
out_path = models.convert_to_onnx(model=model, **cfg)
|
|
logger.info(f"Successfully exported model at {out_path}")
|
|
logger.info(f"Successfully exported model at {out_path}")
|
|
return out_path
|
|
return out_path
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def onnx_simplify(onnx_path: str, onnx_sim_path: str):
|
|
|
|
+ """
|
|
|
|
+ onnx simplifier method, same as `python -m onnxsim onnx_path onnx_sim_path
|
|
|
|
+ :param onnx_path: path to onnx model
|
|
|
|
+ :param onnx_sim_path: path for output onnx simplified model
|
|
|
|
+ """
|
|
|
|
+ model_sim, check = simplify(model=onnx_path)
|
|
|
|
+ if not check:
|
|
|
|
+ raise RuntimeError("Simplified ONNX model could not be validated")
|
|
|
|
+ onnx.save_model(model_sim, onnx_sim_path)
|