Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#751 input_size passed to prep_ model for conversion

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bug/SG-669_models_convert_input_shape_bug
@@ -12,7 +12,7 @@ strict_load: no_key_matching # One of [On, Off, no_key_matching] (case insensiti
 
 
 # CONVERSION RELATED PARAMS
 # CONVERSION RELATED PARAMS
 out_path: # str, Destination path for the .onnx file. When None- will be set to the checkpoint_path.replace(".ckpt",".onnx").
 out_path: # str, Destination path for the .onnx file. When None- will be set to the checkpoint_path.replace(".ckpt",".onnx").
-input_shape: # input shape, not including batch_size. Always channels first (i.e (3, 224, 224)).
+input_shape: # DEPRECATED USE input_size KWARG IN prep_model_for_conversion_kwargs INSTEAD.
 pre_process: # Preprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 pre_process: # Preprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
 prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
 prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
Discard
@@ -52,7 +52,7 @@ class ConvertableCompletePipelineModel(torch.nn.Module):
 def convert_to_onnx(
 def convert_to_onnx(
     model: torch.nn.Module,
     model: torch.nn.Module,
     out_path: str,
     out_path: str,
-    input_shape: tuple,
+    input_shape: tuple = None,
     pre_process: torch.nn.Module = None,
     pre_process: torch.nn.Module = None,
     post_process: torch.nn.Module = None,
     post_process: torch.nn.Module = None,
     prep_model_for_conversion_kwargs=None,
     prep_model_for_conversion_kwargs=None,
@@ -64,7 +64,7 @@ def convert_to_onnx(
 
 
     :param model: torch.nn.Module, model to export to ONNX.
     :param model: torch.nn.Module, model to export to ONNX.
     :param out_path: str, destination path for the .onnx file.
     :param out_path: str, destination path for the .onnx file.
-    :param input_shape: tuple, input shape, excluding batch_size (i.e (3, 224, 224)).
+    :param input_shape: DEPRECATED USE input_size KWARG IN prep_model_for_conversion_kwargs INSTEAD.
     :param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
     :param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
     :param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
     :param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
     :param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
     :param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
@@ -79,6 +79,14 @@ def convert_to_onnx(
         raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
         raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
     torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
     torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
     prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
     prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
+
+    if input_shape is not None:
+        logger.warning(
+            "input_shape is deprecated and will be removed in the next major release." " Use the input_size kwarg in prep_model_for_conversion_kwargs instead"
+        )
+
+    prep_model_for_conversion_kwargs["input_size"] = (1, *input_shape)
+
     onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
     onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
     if not out_path.endswith(".onnx"):
     if not out_path.endswith(".onnx"):
         out_path = out_path + ".onnx"
         out_path = out_path + ".onnx"
Discard