Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#585 PLFM-3331 Register experiments with model name

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/plfm-3331_model_name_in_experiment
@@ -1,4 +1,5 @@
 import os
 import os
+from typing import Optional
 
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
@@ -36,6 +37,7 @@ class DeciPlatformSGLogger(BaseSGLogger):
         save_tensorboard_remote: bool = True,
         save_tensorboard_remote: bool = True,
         save_logs_remote: bool = True,
         save_logs_remote: bool = True,
         monitor_system: bool = True,
         monitor_system: bool = True,
+        model_name: Optional[str] = None,
     ):
     ):
 
 
         if _imported_deci_lab_failure is not None:
         if _imported_deci_lab_failure is not None:
@@ -59,7 +61,13 @@ class DeciPlatformSGLogger(BaseSGLogger):
 
 
         self.platform_client = DeciPlatformClient()
         self.platform_client = DeciPlatformClient()
         self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
         self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
-        self.platform_client.register_experiment(name=experiment_name)
+        if model_name is None:
+            logger.warning(
+                "'model_name' parameter not passed. "
+                "The experiment won't be connected to an architecture in the Deci platform. "
+                "To pass a model_name, please use the 'sg_logger_params.model_name' field in the training recipe."
+            )
+        self.platform_client.register_experiment(name=experiment_name, model_name=model_name if model_name else None)
         self.checkpoints_dir_path = checkpoints_dir_path
         self.checkpoints_dir_path = checkpoints_dir_path
 
 
     @multi_process_safe
     @multi_process_safe
Discard