|
@@ -1,4 +1,5 @@
|
|
import os
|
|
import os
|
|
|
|
+from typing import Optional
|
|
|
|
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
|
|
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
|
|
@@ -36,6 +37,7 @@ class DeciPlatformSGLogger(BaseSGLogger):
|
|
save_tensorboard_remote: bool = True,
|
|
save_tensorboard_remote: bool = True,
|
|
save_logs_remote: bool = True,
|
|
save_logs_remote: bool = True,
|
|
monitor_system: bool = True,
|
|
monitor_system: bool = True,
|
|
|
|
+ model_name: Optional[str] = None,
|
|
):
|
|
):
|
|
|
|
|
|
if _imported_deci_lab_failure is not None:
|
|
if _imported_deci_lab_failure is not None:
|
|
@@ -59,7 +61,13 @@ class DeciPlatformSGLogger(BaseSGLogger):
|
|
|
|
|
|
self.platform_client = DeciPlatformClient()
|
|
self.platform_client = DeciPlatformClient()
|
|
self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
|
|
self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
|
|
- self.platform_client.register_experiment(name=experiment_name)
|
|
|
|
|
|
+ if model_name is None:
|
|
|
|
+ logger.warning(
|
|
|
|
+ "'model_name' parameter not passed. "
|
|
|
|
+ "The experiment won't be connected to an architecture in the Deci platform. "
|
|
|
|
+ "To pass a model_name, please use the 'sg_logger_params.model_name' field in the training recipe."
|
|
|
|
+ )
|
|
|
|
+ self.platform_client.register_experiment(name=experiment_name, model_name=model_name if model_name else None)
|
|
self.checkpoints_dir_path = checkpoints_dir_path
|
|
self.checkpoints_dir_path = checkpoints_dir_path
|
|
|
|
|
|
@multi_process_safe
|
|
@multi_process_safe
|