Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#643 PPYolo-E

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-344-PP-Yolo-E-Training-Replicate-Recipe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. import os
  2. import io
  3. from contextlib import contextmanager
  4. from typing import Optional
  5. from super_gradients.common.abstractions.abstract_logger import get_logger
  6. from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger, EXPERIMENT_LOGS_PREFIX, LOGGER_LOGS_PREFIX, CONSOLE_LOGS_PREFIX
  7. from super_gradients.common.environment.ddp_utils import multi_process_safe
  8. from super_gradients.common.plugins.deci_client import DeciClient
  9. from contextlib import redirect_stdout
  10. logger = get_logger(__name__)
  11. TENSORBOARD_EVENTS_PREFIX = "events.out.tfevents"
  12. class DeciPlatformSGLogger(BaseSGLogger):
  13. """Logger responsible to push logs and tensorboard artifacts to Deci platform."""
  14. def __init__(
  15. self,
  16. project_name: str,
  17. experiment_name: str,
  18. storage_location: str,
  19. resumed: bool,
  20. training_params: dict,
  21. checkpoints_dir_path: str,
  22. tb_files_user_prompt: bool = False,
  23. launch_tensorboard: bool = False,
  24. tensorboard_port: int = None,
  25. save_checkpoints_remote: bool = True,
  26. save_tensorboard_remote: bool = True,
  27. save_logs_remote: bool = True,
  28. monitor_system: bool = True,
  29. model_name: Optional[str] = None,
  30. ):
  31. super().__init__(
  32. project_name=project_name,
  33. experiment_name=experiment_name,
  34. storage_location=storage_location,
  35. resumed=resumed,
  36. training_params=training_params,
  37. checkpoints_dir_path=checkpoints_dir_path,
  38. tb_files_user_prompt=tb_files_user_prompt,
  39. launch_tensorboard=launch_tensorboard,
  40. tensorboard_port=tensorboard_port,
  41. save_checkpoints_remote=save_checkpoints_remote,
  42. save_tensorboard_remote=save_tensorboard_remote,
  43. save_logs_remote=save_logs_remote,
  44. monitor_system=monitor_system,
  45. )
  46. if model_name is None:
  47. logger.warning(
  48. "'model_name' parameter not passed. "
  49. "The experiment won't be connected to an architecture in the Deci platform. "
  50. "To pass a model_name, please use the 'sg_logger_params.model_name' field in the training recipe."
  51. )
  52. self.platform_client = DeciClient()
  53. self.platform_client.register_experiment(name=experiment_name, model_name=model_name if model_name else None)
  54. self.checkpoints_dir_path = checkpoints_dir_path
  55. @multi_process_safe
  56. def upload(self):
  57. """
  58. Upload both to the destination specified by the user (base behavior), and to Deci platform.
  59. """
  60. # Upload to the destination specified by the user
  61. super(DeciPlatformSGLogger, self).upload()
  62. # Upload to Deci platform
  63. if not os.path.isdir(self.checkpoints_dir_path):
  64. raise ValueError("Provided directory does not exist")
  65. self._upload_latest_file_starting_with(start_with=TENSORBOARD_EVENTS_PREFIX)
  66. self._upload_latest_file_starting_with(start_with=EXPERIMENT_LOGS_PREFIX)
  67. self._upload_latest_file_starting_with(start_with=LOGGER_LOGS_PREFIX)
  68. self._upload_latest_file_starting_with(start_with=CONSOLE_LOGS_PREFIX)
  69. self._upload_folder_files(folder_name=".hydra")
  70. @multi_process_safe
  71. def _upload_latest_file_starting_with(self, start_with: str):
  72. """
  73. Upload the most recent file starting with a specific prefix to the Deci platform.
  74. :param start_with: prefix of the file to upload
  75. """
  76. files_path = [
  77. os.path.join(self.checkpoints_dir_path, file_name) for file_name in os.listdir(self.checkpoints_dir_path) if file_name.startswith(start_with)
  78. ]
  79. most_recent_file_path = max(files_path, key=os.path.getctime)
  80. self._save_experiment_file(file_path=most_recent_file_path)
  81. @multi_process_safe
  82. def _upload_folder_files(self, folder_name: str):
  83. """
  84. Upload all the files of a given folder.
  85. :param folder_name: Name of the folder that contains the files to upload
  86. """
  87. folder_path = os.path.join(self.checkpoints_dir_path, folder_name)
  88. if not os.path.exists(folder_path):
  89. return
  90. for file in os.listdir(folder_path):
  91. self._save_experiment_file(file_path=f"{folder_path}/{file}")
  92. def _save_experiment_file(self, file_path: str):
  93. with log_stdout(): # TODO: remove when platform_client remove prints from save_experiment_file
  94. self.platform_client.save_experiment_file(file_path=file_path)
  95. logger.info(f"File saved to Deci platform: {file_path}")
  96. @contextmanager
  97. def log_stdout():
  98. """Redirect stdout to DEBUG."""
  99. buffer = io.StringIO()
  100. with redirect_stdout(buffer):
  101. yield
  102. redirected_str = buffer.getvalue()
  103. if redirected_str:
  104. logger.debug(msg=redirected_str)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file