|
@@ -52,7 +52,7 @@ class ConvertableCompletePipelineModel(torch.nn.Module):
|
|
def convert_to_onnx(
|
|
def convert_to_onnx(
|
|
model: torch.nn.Module,
|
|
model: torch.nn.Module,
|
|
out_path: str,
|
|
out_path: str,
|
|
- input_shape: tuple,
|
|
|
|
|
|
+ input_shape: tuple = None,
|
|
pre_process: torch.nn.Module = None,
|
|
pre_process: torch.nn.Module = None,
|
|
post_process: torch.nn.Module = None,
|
|
post_process: torch.nn.Module = None,
|
|
prep_model_for_conversion_kwargs=None,
|
|
prep_model_for_conversion_kwargs=None,
|
|
@@ -64,7 +64,7 @@ def convert_to_onnx(
|
|
|
|
|
|
:param model: torch.nn.Module, model to export to ONNX.
|
|
:param model: torch.nn.Module, model to export to ONNX.
|
|
:param out_path: str, destination path for the .onnx file.
|
|
:param out_path: str, destination path for the .onnx file.
|
|
- :param input_shape: tuple, input shape, excluding batch_size (i.e (3, 224, 224)).
|
|
|
|
|
|
+ :param input_shape: DEPRECATED USE input_size KWARG IN prep_model_for_conversion_kwargs INSTEAD.
|
|
:param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
|
|
:param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
|
|
:param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
|
|
:param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
|
|
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
|
|
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
|
|
@@ -79,6 +79,14 @@ def convert_to_onnx(
|
|
raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
|
|
raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
|
|
torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
|
|
torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
|
|
prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
|
|
prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
|
|
|
|
+
|
|
|
|
+ if input_shape is not None:
|
|
|
|
+ logger.warning(
|
|
|
|
+ "input_shape is deprecated and will be removed in the next major release." " Use the input_size kwarg in prep_model_for_conversion_kwargs instead"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ prep_model_for_conversion_kwargs["input_size"] = (1, *input_shape)
|
|
|
|
+
|
|
onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
|
|
onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
|
|
if not out_path.endswith(".onnx"):
|
|
if not out_path.endswith(".onnx"):
|
|
out_path = out_path + ".onnx"
|
|
out_path = out_path + ".onnx"
|