#636 Hotfix/SG-000: Reduce import loops risk

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-reduce_import_loops_risk
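Every file touched below follows the same pattern: imports that used to go through the package-level re-exports in super_gradients.common (for example, from super_gradients.common import AWSConnector) now import directly from the defining submodule. A minimal sketch of why this reduces import-loop risk, using hypothetical module names rather than code from this PR:

# pkg/__init__.py -- hypothetical stand-in for a package that re-exports its submodules
from pkg.a import A  # importing pkg now always executes pkg/a.py first

# pkg/a.py
class A:
    pass

# pkg/b.py, importing through the package (the style this PR removes):
from pkg import A  # runs pkg/__init__.py; if anything __init__ pulls in ever
                   # imports pkg.b back, the cycle hits a half-initialized
                   # package and the import fails

# pkg/b.py, importing from the defining module (the style this PR adopts):
from pkg.a import A  # binds directly to pkg/a.py, skipping the __init__ round-trip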
@@ -1,20 +1,21 @@
 import json
 import logging
 
-from super_gradients.common import AWSConnector
-from super_gradients.common import explicit_params_validation
+from super_gradients.common.aws_connection.aws_connector import AWSConnector
+from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 
 
 class AWSSecretsManagerConnector:
     """
     AWSSecretsManagerConnector - This class handles the AWS Secrets Manager connection
     """
+
     __slots__ = []  # Making the class immutable for runtime safety
     current_environment_client = None
-    DECI_ENVIRONMENTS = ['research', 'development', 'staging', 'production']
+    DECI_ENVIRONMENTS = ["research", "development", "staging", "production"]
 
     @staticmethod
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def get_secret_value_for_secret_key(aws_env: str, secret_name: str, secret_key: str) -> str:
         """
         get_secret_value_for_secret_key - Gets a Secret Value from AWS Secrets Manager for the Provided Key
@@ -26,21 +27,19 @@ class AWSSecretsManagerConnector:
         current_class_name = __class__.__name__
         logger = logging.getLogger(current_class_name)
         secret_key = secret_key.upper()
-        aws_secrets_dict = AWSSecretsManagerConnector.__get_secrets_manager_dict_for_secret_name(
-            aws_env=aws_env, secret_name=secret_name)
+        aws_secrets_dict = AWSSecretsManagerConnector.__get_secrets_manager_dict_for_secret_name(aws_env=aws_env, secret_name=secret_name)
 
-        secret_key = '.'.join([aws_env.upper(), secret_key])
+        secret_key = ".".join([aws_env.upper(), secret_key])
         if secret_key not in aws_secrets_dict.keys():
-            error = f'[{current_class_name}] - Secret Key ({secret_key}) not Found in AWS Secret: ' + secret_name
+            error = f"[{current_class_name}] - Secret Key ({secret_key}) not Found in AWS Secret: " + secret_name
             logger.error(error)
             raise EnvironmentError(error)
         else:
             return aws_secrets_dict[secret_key]
 
     @staticmethod
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def get_secret_values_dict_for_secret_key_properties(env: str, secret_key: str, secret_name: str,
-                                                         db_properties_set: set = None) -> dict:
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def get_secret_values_dict_for_secret_key_properties(env: str, secret_key: str, secret_name: str, db_properties_set: set = None) -> dict:
         """
         get_config_dict - Returns the config dict of the properties from the properties dict
             :param  env:                The environment to open the dict for
@@ -51,32 +50,34 @@ class AWSSecretsManagerConnector:
         """
         """
         current_class_name = __class__.__name__
         current_class_name = __class__.__name__
         logger = logging.getLogger(current_class_name)
         logger = logging.getLogger(current_class_name)
-        aws_secrets_dict = AWSSecretsManagerConnector.__get_secrets_manager_dict_for_secret_name(
-            aws_env=env, secret_name=secret_name)
+        aws_secrets_dict = AWSSecretsManagerConnector.__get_secrets_manager_dict_for_secret_name(aws_env=env, secret_name=secret_name)
 
 
         aws_env_safe_secrets = {}
         aws_env_safe_secrets = {}
         # FILL THE DICT VALUES FROM THE AWS SECRETS RESPONSE
         # FILL THE DICT VALUES FROM THE AWS SECRETS RESPONSE
         if db_properties_set:
         if db_properties_set:
             for secret_key_property in db_properties_set:
             for secret_key_property in db_properties_set:
-                secret_key_to_retrieve = '.'.join([env.upper(), secret_key, secret_key_property])
+                secret_key_to_retrieve = ".".join([env.upper(), secret_key, secret_key_property])
                 if secret_key_to_retrieve not in aws_secrets_dict:
                 if secret_key_to_retrieve not in aws_secrets_dict:
-                    error = f'[{current_class_name}] - Error retrieving data from AWS Secrets Manager for Secret Key "{secret_name}": ' \
-                            f'The secret property "{secret_key_property}" Does Not Exist'
+                    error = (
+                        f'[{current_class_name}] - Error retrieving data from AWS Secrets Manager for Secret Key "{secret_name}": '
+                        f'The secret property "{secret_key_property}" Does Not Exist'
+                    )
                     logger.error(error)
                     logger.error(error)
                     raise EnvironmentError(error)
                     raise EnvironmentError(error)
                 else:
                 else:
-                    env_stripped_key_name = secret_key_to_retrieve.lstrip(env.upper()).lstrip('.')
+                    env_stripped_key_name = secret_key_to_retrieve.lstrip(env.upper()).lstrip(".")
                     aws_env_safe_secrets[env_stripped_key_name] = aws_secrets_dict[secret_key_to_retrieve]
                     aws_env_safe_secrets[env_stripped_key_name] = aws_secrets_dict[secret_key_to_retrieve]
         else:
         else:
             # "db_properties_set" is not specified - validating and returning all the secret keys and values for
             # "db_properties_set" is not specified - validating and returning all the secret keys and values for
             # the secret name.
             # the secret name.
             for secret_key_name, secret_value in aws_secrets_dict.items():
             for secret_key_name, secret_value in aws_secrets_dict.items():
-                secret_key_to_retrieve = '.'.join([env.upper(), secret_key])
-                assert secret_key_name.startswith(
-                    env.upper()), f'The secret key property "{secret_key_name}", found in secret named {secret_name}, is not following the convention of ' \
-                                  f'environment prefix. please add the environment prefix "{env.upper()}" to property "{secret_key_name}"'
+                secret_key_to_retrieve = ".".join([env.upper(), secret_key])
+                assert secret_key_name.startswith(env.upper()), (
+                    f'The secret key property "{secret_key_name}", found in secret named {secret_name}, is not following the convention of '
+                    f'environment prefix. please add the environment prefix "{env.upper()}" to property "{secret_key_name}"'
+                )
                 if secret_key_name.startswith(secret_key_to_retrieve):
                 if secret_key_name.startswith(secret_key_to_retrieve):
-                    env_stripped_key_name = secret_key_name.lstrip(env.upper()).lstrip('.')
+                    env_stripped_key_name = secret_key_name.lstrip(env.upper()).lstrip(".")
                     aws_env_safe_secrets[env_stripped_key_name] = secret_value
                     aws_env_safe_secrets[env_stripped_key_name] = secret_value
         return aws_env_safe_secrets
         return aws_env_safe_secrets
 
 
@@ -95,18 +96,19 @@ class AWSSecretsManagerConnector:
 
         try:
             if not AWSSecretsManagerConnector.current_environment_client:
-                logger.debug('Initializing a new secrets manager client...')
+                logger.debug("Initializing a new secrets manager client...")
                 AWSSecretsManagerConnector.current_environment_client = AWSConnector.get_aws_client_for_service_name(
-                    profile_name=aws_env,
-                    service_name='secretsmanager')
+                    profile_name=aws_env, service_name="secretsmanager"
+                )
             logger.debug(f'Fetching the secret "{secret_name}" in env "{aws_env}"')
             aws_secrets = AWSSecretsManagerConnector.current_environment_client.get_secret_value(SecretId=secrets_path)
-            aws_secrets_dict = json.loads(aws_secrets['SecretString'])
+            aws_secrets_dict = json.loads(aws_secrets["SecretString"])
             return aws_secrets_dict
 
         except Exception as ex:
-            error = '[' + current_class_name + '] - Caught Exception while trying to connect to aws to get credentials from secrets manager: ' + '"' + str(
-                ex) + '"' + ' for ' + str(secrets_path)
+            error = (
+                f'[{current_class_name}] - Caught Exception while trying to connect to aws to get credentials from secrets manager: "{ex}" for {secrets_path}'
+            )
             logger.error(error)
             raise EnvironmentError(error)
 
@@ -123,8 +125,8 @@ class AWSSecretsManagerConnector:
 
         # Checking for lowercase exact match, in order to prevent any implicit usage of the environments.
         if aws_env not in AWSSecretsManagerConnector.DECI_ENVIRONMENTS:
-            logger.critical('[' + current_class_name + ' ] -  wrong environment param... Exiting')
-            raise Exception('[' + current_class_name + '] - wrong environment param')
+            logger.critical("[" + current_class_name + " ] -  wrong environment param... Exiting")
+            raise Exception("[" + current_class_name + "] - wrong environment param")
 
-        secrets_path = '/'.join([aws_env, secret_name])
+        secrets_path = "/".join([aws_env, secret_name])
         return secrets_path
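For orientation, a usage sketch for the connector above; the secret, key, and environment values are invented, and the import of AWSSecretsManagerConnector from its defining module is assumed rather than shown in this diff:

# Hypothetical call site. get_secret_value_for_secret_key upper-cases the key
# and prefixes it with the environment, so this looks up "PRODUCTION.DB_PASSWORD"
# inside the secret stored at "production/my-service-secrets".
password = AWSSecretsManagerConnector.get_secret_value_for_secret_key(
    aws_env="production",  # must be an exact lowercase match against DECI_ENVIRONMENTS
    secret_name="my-service-secrets",
    secret_key="db_password",
)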
@@ -5,8 +5,8 @@ from typing import List
 
 import botocore
 
-from super_gradients.common import AWSConnector
-from super_gradients.common import explicit_params_validation
+from super_gradients.common.aws_connection.aws_connector import AWSConnector
+from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.common.abstractions.abstract_logger import ILogger
 
 
@@ -26,10 +26,10 @@ class S3Connector(ILogger):
         super().__init__()
         self.env = env
         self.bucket_name = bucket_name
-        self.s3_client = AWSConnector.get_aws_client_for_service_name(profile_name=env, service_name='s3')
-        self.s3_resource = AWSConnector.get_aws_resource_for_service_name(profile_name=env, service_name='s3')
+        self.s3_client = AWSConnector.get_aws_client_for_service_name(profile_name=env, service_name="s3")
+        self.s3_resource = AWSConnector.get_aws_resource_for_service_name(profile_name=env, service_name="s3")
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def check_key_exists(self, s3_key_to_check: str) -> bool:
         """
         check_key_exists - Checks if an S3 key exists
@@ -39,16 +39,15 @@ class S3Connector(ILogger):
         try:
             self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key_to_check)
         except botocore.exceptions.ClientError as ex:
-            if ex.response['Error']['Code'] == "404":
+            if ex.response["Error"]["Code"] == "404":
                 return False
             else:
-                self._logger.error(
-                    'Failed to check key: ' + str(s3_key_to_check) + ' existence in bucket' + str(self.bucket_name))
+                self._logger.error("Failed to check key: " + str(s3_key_to_check) + " existence in bucket" + str(self.bucket_name))
                 return None
         else:
             return True
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def get_object_by_etag(self, bucket_relative_file_name: str, etag: str) -> object:
         """
         get_object_by_etag - Gets S3 object by it's ETag heder if it. exists
@@ -61,14 +60,13 @@ class S3Connector(ILogger):
             s3_object = self.s3_client.get_object(Bucket=self.bucket_name, Key=bucket_relative_file_name, IfMatch=etag)
             return s3_object
         except botocore.exceptions.ClientError as ex:
-            if ex.response['Error']['Code'] == "404":
+            if ex.response["Error"]["Code"] == "404":
                 return False
             else:
-                self._logger.error(
-                    'Failed to check ETag: ' + str(etag) + ' existence in bucket ' + str(self.bucket_name))
+                self._logger.error("Failed to check ETag: " + str(etag) + " existence in bucket " + str(self.bucket_name))
         return
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def create_bucket(self) -> bool:
         """
         Creates a bucket with the initialized bucket name.
@@ -77,22 +75,14 @@ class S3Connector(ILogger):
         """
         """
         try:
         try:
             # TODO: Change bucket_owner_arn to the company's proper IAM Role
             # TODO: Change bucket_owner_arn to the company's proper IAM Role
-            self._logger.info('Creating Bucket: ' + self.bucket_name)
-            create_bucket_response = self.s3_client.create_bucket(
-                ACL='private',
-                Bucket=self.bucket_name
-            )
-            self._logger.info(f'Successfully created bucket: {create_bucket_response}')
+            self._logger.info("Creating Bucket: " + self.bucket_name)
+            create_bucket_response = self.s3_client.create_bucket(ACL="private", Bucket=self.bucket_name)
+            self._logger.info(f"Successfully created bucket: {create_bucket_response}")
 
 
             # Changing the bucket public access block to be private (disable public access)
             # Changing the bucket public access block to be private (disable public access)
-            self._logger.debug('Disabling public access to the bucket...')
+            self._logger.debug("Disabling public access to the bucket...")
             self.s3_client.put_public_access_block(
             self.s3_client.put_public_access_block(
-                PublicAccessBlockConfiguration={
-                    'BlockPublicAcls': True,
-                    'IgnorePublicAcls': True,
-                    'BlockPublicPolicy': True,
-                    'RestrictPublicBuckets': True
-                },
+                PublicAccessBlockConfiguration={"BlockPublicAcls": True, "IgnorePublicAcls": True, "BlockPublicPolicy": True, "RestrictPublicBuckets": True},
                 Bucket=self.bucket_name,
                 Bucket=self.bucket_name,
             )
             )
             return create_bucket_response
             return create_bucket_response
@@ -100,7 +90,7 @@ class S3Connector(ILogger):
             self._logger.fatal(f'Failed to create bucket "{self.bucket_name}": {err}')
             raise
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def delete_bucket(self):
         """
         Deletes a bucket with the initialized bucket name.
@@ -108,28 +98,28 @@ class S3Connector(ILogger):
         :raises ClientError: If the creation failed for any reason.
         """
         try:
-            self._logger.info('Deleting Bucket: ' + self.bucket_name + ' from S3')
+            self._logger.info("Deleting Bucket: " + self.bucket_name + " from S3")
             bucket = self.s3_resource.Bucket(self.bucket_name)
             bucket.objects.all().delete()
             bucket.delete()
-            self._logger.debug('Successfully Deleted Bucket: ' + self.bucket_name + ' from S3')
+            self._logger.debug("Successfully Deleted Bucket: " + self.bucket_name + " from S3")
         except botocore.exceptions.ClientError as ex:
-            self._logger.fatal(f'Failed to delete bucket {self.bucket_name}: {ex}')
+            self._logger.fatal(f"Failed to delete bucket {self.bucket_name}: {ex}")
             raise ex
         return True
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def get_object_metadata(self, s3_key: str):
         try:
             return self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
         except botocore.exceptions.ClientError as ex:
-            if ex.response['Error']['Code'] == '404':
-                msg = '[' + sys._getframe().f_code.co_name + '] - Key does not exist in bucket)'
+            if ex.response["Error"]["Code"] == "404":
+                msg = "[" + sys._getframe().f_code.co_name + "] - Key does not exist in bucket)"
                 self._logger.error(msg)
                 raise KeyNotExistInBucketError(msg)
             raise ex
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def delete_key(self, s3_key_to_delete: str) -> bool:
         """
         delete_key - Deletes a Key from an S3 Bucket
@@ -137,21 +127,20 @@ class S3Connector(ILogger):
             :return: True/False if the operation succeeded/failed
         """
         try:
-            self._logger.debug('Deleting Key: ' + s3_key_to_delete + ' from S3 bucket: ' + self.bucket_name)
+            self._logger.debug("Deleting Key: " + s3_key_to_delete + " from S3 bucket: " + self.bucket_name)
             obj_status = self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key_to_delete)
         except botocore.exceptions.ClientError as ex:
-            if ex.response['Error']['Code'] == '404':
-                self._logger.error('[' + sys._getframe().f_code.co_name + '] - Key does not exist in bucket)')
+            if ex.response["Error"]["Code"] == "404":
+                self._logger.error("[" + sys._getframe().f_code.co_name + "] - Key does not exist in bucket)")
             return False
 
-        if obj_status['ContentLength']:
-            self._logger.debug(
-                '[' + sys._getframe().f_code.co_name + '] - Deleting file s3://' + self.bucket_name + s3_key_to_delete)
+        if obj_status["ContentLength"]:
+            self._logger.debug("[" + sys._getframe().f_code.co_name + "] - Deleting file s3://" + self.bucket_name + s3_key_to_delete)
             self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key_to_delete)
 
         return True
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def upload_file_from_stream(self, file, key: str):
         """
         upload_file - Uploads a file to S3 via boto3 interface
@@ -161,17 +150,15 @@ class S3Connector(ILogger):
             :return True/False if the operation succeeded/failed
         """
         try:
-            self._logger.debug('Uploading Key: ' + key + ' to S3 bucket: ' + self.bucket_name)
+            self._logger.debug("Uploading Key: " + key + " to S3 bucket: " + self.bucket_name)
             buffer = BytesIO(file)
             self.upload_buffer(key, buffer)
             return True
         except Exception as ex:
-            self._logger.critical(
-                '[' + sys._getframe().f_code.co_name + '] - Caught Exception while trying to upload file ' + str(
-                    key) + 'to S3' + str(ex))
+            self._logger.critical("[" + sys._getframe().f_code.co_name + "] - Caught Exception while trying to upload file " + str(key) + "to S3" + str(ex))
             return False
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def upload_file(self, filename_to_upload: str, key: str):
         """
         upload_file - Uploads a file to S3 via boto3 interface
@@ -181,18 +168,16 @@ class S3Connector(ILogger):
             :return True/False if the operation succeeded/failed
         """
         try:
-            self._logger.debug('Uploading Key: ' + key + ' to S3 bucket: ' + self.bucket_name)
+            self._logger.debug("Uploading Key: " + key + " to S3 bucket: " + self.bucket_name)
 
             self.s3_client.upload_file(Bucket=self.bucket_name, Filename=filename_to_upload, Key=key)
             return True
 
         except Exception as ex:
-            self._logger.critical(
-                '[' + sys._getframe().f_code.co_name + '] - Caught Exception while trying to upload file ' + str(
-                    filename_to_upload) + 'to S3' + str(ex))
+            self._logger.critical(f"[{sys._getframe().f_code.co_name}] - Caught Exception while trying to upload file {filename_to_upload} to S3 {ex}")
             return False
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def download_key(self, target_path: str, key_to_download: str) -> bool:
         """
         download_file - Downloads a key from S3 using boto3 to the provided filename
@@ -202,22 +187,19 @@ class S3Connector(ILogger):
             :return:                   True/False if the operation succeeded/failed
         """
         try:
-            self._logger.debug('Uploading Key: ' + key_to_download + ' from S3 bucket: ' + self.bucket_name)
+            self._logger.debug("Uploading Key: " + key_to_download + " from S3 bucket: " + self.bucket_name)
             self.s3_client.download_file(Bucket=self.bucket_name, Filename=target_path, Key=key_to_download)
         except botocore.exceptions.ClientError as ex:
-            if ex.response['Error']['Code'] == '404':
-                self._logger.error('[' + sys._getframe().f_code.co_name + '] - Key does exist in bucket)')
+            if ex.response["Error"]["Code"] == "404":
+                self._logger.error("[" + sys._getframe().f_code.co_name + "] - Key does exist in bucket)")
             else:
-                self._logger.critical(
-                    '[' + sys._getframe().f_code.co_name + '] - Caught Exception while trying to download key ' + str(
-                        key_to_download) + ' from S3 ' + str(ex))
+                self._logger.critical(f"[{sys._getframe().f_code.co_name}] - Caught Exception while trying to download key {key_to_download} from S3 {ex}")
             return False
 
         return True
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def download_keys_by_prefix(self, s3_bucket_path_prefix: str, local_download_dir: str,
-                                s3_file_path_prefix: str = ''):
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def download_keys_by_prefix(self, s3_bucket_path_prefix: str, local_download_dir: str, s3_file_path_prefix: str = ""):
         """
         download_keys_by_prefix - Download all of the keys who match the provided in-bucket path prefix and file prefix
             :param s3_bucket_path_prefix:   The S3 "folder" to download from
@@ -226,22 +208,21 @@ class S3Connector(ILogger):
         :return:
         """
         if not os.path.isdir(local_download_dir):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
 
-        paginator = self.s3_client.get_paginator('list_objects')
-        prefix = s3_bucket_path_prefix if not s3_file_path_prefix else s3_bucket_path_prefix + '/' + s3_file_path_prefix
+        paginator = self.s3_client.get_paginator("list_objects")
+        prefix = s3_bucket_path_prefix if not s3_file_path_prefix else s3_bucket_path_prefix + "/" + s3_file_path_prefix
         page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
 
-        for item in page_iterator.search('Contents'):
+        for item in page_iterator.search("Contents"):
             if item is not None:
-                if item['Key'] == s3_bucket_path_prefix:
+                if item["Key"] == s3_bucket_path_prefix:
                     continue
-            key_to_download = item['Key']
-            local_filename = key_to_download.split('/')[-1]
-            self.download_key(target_path=local_download_dir + '/' + local_filename,
-                              key_to_download=key_to_download)
+            key_to_download = item["Key"]
+            local_filename = key_to_download.split("/")[-1]
+            self.download_key(target_path=local_download_dir + "/" + local_filename, key_to_download=key_to_download)
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def download_file_by_path(self, s3_file_path: str, local_download_dir: str):
         """
         :param s3_file_path: str - path ot s3 file e.g./ "s3://x/y.zip"
@@ -250,39 +231,38 @@ class S3Connector(ILogger):
         """
         """
 
 
         if not os.path.isdir(local_download_dir):
         if not os.path.isdir(local_download_dir):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
 
 
-        local_filename = s3_file_path.split('/')[-1]
-        self.download_key(target_path=local_download_dir + '/' + local_filename,
-                          key_to_download=s3_file_path)
+        local_filename = s3_file_path.split("/")[-1]
+        self.download_key(target_path=local_download_dir + "/" + local_filename, key_to_download=s3_file_path)
         return local_filename
         return local_filename
 
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def empty_folder_content_by_path_prefix(self, s3_bucket_path_prefix) -> list:
     def empty_folder_content_by_path_prefix(self, s3_bucket_path_prefix) -> list:
         """
         """
         empty_folder_content_by_path_prefix - Deletes all of the files in the specified bucket path
         empty_folder_content_by_path_prefix - Deletes all of the files in the specified bucket path
             :param s3_bucket_path_prefix: The "folder" to empty
             :param s3_bucket_path_prefix: The "folder" to empty
             :returns: Errors list
             :returns: Errors list
         """
         """
-        paginator = self.s3_client.get_paginator('list_objects')
+        paginator = self.s3_client.get_paginator("list_objects")
         page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=s3_bucket_path_prefix)
         page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=s3_bucket_path_prefix)
 
 
         files_dict_to_delete = dict(Objects=[])
         files_dict_to_delete = dict(Objects=[])
         errors_list = []
         errors_list = []
 
 
-        for item in page_iterator.search('Contents'):
+        for item in page_iterator.search("Contents"):
             if item is not None:
             if item is not None:
-                if item['Key'] == s3_bucket_path_prefix:
+                if item["Key"] == s3_bucket_path_prefix:
                     continue
                     continue
-                files_dict_to_delete['Objects'].append(dict(Key=item['Key']))
+                files_dict_to_delete["Objects"].append(dict(Key=item["Key"]))
 
 
                 # IF OBJECTS LIMIT HAS BEEN REACHED, FLUSH
                 # IF OBJECTS LIMIT HAS BEEN REACHED, FLUSH
-                if len(files_dict_to_delete['Objects']) >= 1000:
+                if len(files_dict_to_delete["Objects"]) >= 1000:
                     self._delete_files_left_in_list(errors_list, files_dict_to_delete)
                     self._delete_files_left_in_list(errors_list, files_dict_to_delete)
                     files_dict_to_delete = dict(Objects=[])
                     files_dict_to_delete = dict(Objects=[])
 
 
         # DELETE THE FILES LEFT IN THE LIST
         # DELETE THE FILES LEFT IN THE LIST
-        if len(files_dict_to_delete['Objects']):
+        if len(files_dict_to_delete["Objects"]):
             self._delete_files_left_in_list(errors_list, files_dict_to_delete)
             self._delete_files_left_in_list(errors_list, files_dict_to_delete)
 
 
         return errors_list
         return errors_list
@@ -291,13 +271,11 @@ class S3Connector(ILogger):
         try:
             s3_response = self.s3_client.delete_objects(Bucket=self.bucket_name, Delete=files_dict_to_delete)
         except Exception as ex:
-            self._logger.critical(
-                '[' + sys._getframe().f_code.co_name + '] - Caught Exception while trying to delete keys ' + 'from S3 ' + str(
-                    ex))
-        if 'Errors' in s3_response:
-            errors_list.append(s3_response['Errors'])
+            self._logger.critical("[" + sys._getframe().f_code.co_name + "] - Caught Exception while trying to delete keys " + "from S3 " + str(ex))
+        if "Errors" in s3_response:
+            errors_list.append(s3_response["Errors"])
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def upload_buffer(self, new_key_name: str, buffer_to_write: StringIO):
         """
         Uploads a buffer into a file in S3 with the provided key name.
@@ -307,32 +285,28 @@ class S3Connector(ILogger):
         """
         """
         self.s3_resource.Object(self.bucket_name, new_key_name).put(Body=buffer_to_write.getvalue())
         self.s3_resource.Object(self.bucket_name, new_key_name).put(Body=buffer_to_write.getvalue())
 
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def list_bucket_objects(self, prefix: str = None) -> List[dict]:
     def list_bucket_objects(self, prefix: str = None) -> List[dict]:
         """
         """
         Gets a list of dictionaries, representing files in the S3 bucket that is passed in the constructor (self.bucket).
         Gets a list of dictionaries, representing files in the S3 bucket that is passed in the constructor (self.bucket).
         :param prefix: A prefix filter for the files names.
         :param prefix: A prefix filter for the files names.
         :return: the objects, dict as received from botocore.
         :return: the objects, dict as received from botocore.
         """
         """
-        paginator = self.s3_client.get_paginator('list_objects')
+        paginator = self.s3_client.get_paginator("list_objects")
         if prefix:
         if prefix:
             page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
             page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
         else:
         else:
             page_iterator = paginator.paginate(Bucket=self.bucket_name)
             page_iterator = paginator.paginate(Bucket=self.bucket_name)
 
 
         bucket_objects = []
         bucket_objects = []
-        for item in page_iterator.search('Contents'):
-            if not item or item['Key'] == self.bucket_name:
+        for item in page_iterator.search("Contents"):
+            if not item or item["Key"] == self.bucket_name:
                 continue
                 continue
             bucket_objects.append(item)
             bucket_objects.append(item)
         return bucket_objects
         return bucket_objects
 
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def create_presigned_upload_url(self,
-                                    object_name: str,
-                                    fields=None,
-                                    conditions=None,
-                                    expiration=3600):
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def create_presigned_upload_url(self, object_name: str, fields=None, conditions=None, expiration=3600):
         """Generate a presigned URL S3 POST request to upload a file
         """Generate a presigned URL S3 POST request to upload a file
         :param bucket_name: string
         :param bucket_name: string
         :param object_name: string
         :param object_name: string
@@ -348,14 +322,10 @@ class S3Connector(ILogger):
 
         if file_already_exist:
             raise FileExistsError(f"The key {object_name} already exists in bucket {self.bucket_name}")
 
-        response = self.s3_client.generate_presigned_post(self.bucket_name,
-                                                          object_name,
-                                                          Fields=fields,
-                                                          Conditions=conditions,
-                                                          ExpiresIn=expiration)
+        response = self.s3_client.generate_presigned_post(self.bucket_name, object_name, Fields=fields, Conditions=conditions, ExpiresIn=expiration)
         return response
 
-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def create_presigned_download_url(self, bucket_name: str, object_name: str, expiration=3600):
         """Generate a presigned URL S3 Get request to download a file
         :param bucket_name: string
@@ -364,27 +334,17 @@ class S3Connector(ILogger):
         :return: URL encoded with the credentials in the query, to be fetched using any HTTP client.
         """
         # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-presigned-urls.html
-        response = self.s3_client.generate_presigned_url('get_object',
-                                                         Params={'Bucket': bucket_name,
-                                                                 'Key': object_name},
-                                                         ExpiresIn=expiration)
+        response = self.s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket_name, "Key": object_name}, ExpiresIn=expiration)
         return response
 
     @staticmethod
     def convert_content_length_to_mb(content_length):
-        return round(float(f'{content_length / (1e+6):2f}'), 2)
-
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def copy_key(self,
-                 destination_bucket_name: str,
-                 source_key: str,
-                 destination_key: str):
-        self._logger.info(
-            f'Copying the bucket object {self.bucket_name}:{source_key} to {destination_bucket_name}/{destination_key}')
-        copy_source = {
-            'Bucket': self.bucket_name,
-            'Key': source_key
-        }
+        return round(float(f"{content_length / (1e+6):2f}"), 2)
+
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def copy_key(self, destination_bucket_name: str, source_key: str, destination_key: str):
+        self._logger.info(f"Copying the bucket object {self.bucket_name}:{source_key} to {destination_bucket_name}/{destination_key}")
+        copy_source = {"Bucket": self.bucket_name, "Key": source_key}
 
         # Copying the key
         bucket = self.s3_resource.Bucket(destination_bucket_name)
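A short usage sketch for S3Connector under the new direct import (bucket and key names are invented; the constructor arguments follow the __init__ shown above):

from super_gradients.common.data_connection.s3_connector import S3Connector

# Hypothetical example: probe for a key, then download it if present.
connector = S3Connector(env="development", bucket_name="my-models-bucket")
if connector.check_key_exists("checkpoints/model.pth"):
    connector.download_key(target_path="/tmp/model.pth", key_to_download="checkpoints/model.pth")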
@@ -1,7 +1,8 @@
 import os
 import sys
 
-from super_gradients.common import S3Connector, explicit_params_validation
+from super_gradients.common.data_connection.s3_connector import S3Connector
+from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import ILogger
 
@@ -1,12 +1,11 @@
 import os
-from super_gradients.common import S3Connector
-from super_gradients.common import explicit_params_validation
+from super_gradients.common.data_connection.s3_connector import S3Connector
+from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 import zipfile
 
 
 class DatasetDataInterface:
-
-    def __init__(self, env: str, data_connection_source: str = 's3'):
+    def __init__(self, env: str, data_connection_source: str = "s3"):
         """
 
         :param env: str "development"/"production"
@@ -16,7 +15,7 @@ class DatasetDataInterface:
         self.s3_connector = None
         self.data_connection_source = data_connection_source
 
-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def load_remote_dataset_file(self, remote_file: str, local_dir: str, overwrite_local_dataset: bool = False) -> str:
         """
 
@@ -29,7 +28,7 @@ class DatasetDataInterface:
         dataset_full_path = local_dir
         bucket = remote_file.split("/")[2]
         file_path = "/".join(remote_file.split("/")[3:])
-        if self.data_connection_source == 's3':
+        if self.data_connection_source == "s3":
             self.s3_connector = S3Connector(self.env, bucket)
 
             # DELETE THE LOCAL VERSION ON THE MACHINE
@@ -45,7 +44,7 @@ class DatasetDataInterface:
                 os.mkdir(local_dir)
 
             local_file = self.s3_connector.download_file_by_path(file_path, local_dir)
-            with zipfile.ZipFile(local_dir + "/" + local_file, 'r') as zip_ref:
+            with zipfile.ZipFile(local_dir + "/" + local_file, "r") as zip_ref:
                 zip_ref.extractall(local_dir + "/")
             os.remove(local_dir + "/" + local_file)
 
def normalize_path(path: str) -> str:
    """Normalize the directory of file path. Replace the Windows-style (\\) path separators with unix ones (/).
    This is necessary when running on Windows since Hydra compose fails to find a configuration file is the config
    directory contains backward slash symbol.
    :param path: Input path string
    :return: Output path string with all \\ symbols replaces with /.
    """
    return path.replace("\\", "/")
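This is the new module the import rewrites below point at: normalize_path moves out of super_gradients.training.utils.hydra_utils into super_gradients.common.environment.path_utils, consistent with the PR's goal of reducing import loops. A quick behavior sketch (paths are invented):

from super_gradients.common.environment.path_utils import normalize_path

print(normalize_path("C:\\Users\\me\\configs"))  # -> C:/Users/me/configs
print(normalize_path("configs/recipe.yaml"))     # -> configs/recipe.yaml (unchanged)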
@@ -14,7 +14,7 @@ from torch import nn
 
 from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.training.utils.hydra_utils import normalize_path
+from super_gradients.common.environment.path_utils import normalize_path
 
 logger = get_logger(__name__)
 
@@ -9,7 +9,7 @@ import numpy as np
 import psutil
 import torch
 from PIL import Image
-from super_gradients.common import ADNNModelRepositoryDataInterfaces
+from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.decorators.code_save_decorator import saved_codes
 from super_gradients.common.environment.ddp_utils import multi_process_safe
@@ -17,7 +17,7 @@ from super_gradients.training.datasets.detection_datasets.pascal_voc_detection i
     PascalVOCDetectionDataset,
 )
 from super_gradients.training.utils import get_param
-from super_gradients.training.utils.hydra_utils import normalize_path
+from super_gradients.common.environment.path_utils import normalize_path
 from super_gradients.training.datasets import ImageNetDataset
 from super_gradients.training.datasets.detection_datasets import COCODetectionDataset
 from super_gradients.training.datasets.classification_datasets.cifar import (
@@ -20,7 +20,6 @@ from torchvision.transforms import transforms, InterpolationMode, RandomResizedC
 from tqdm import tqdm
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.training.datasets.auto_augment import rand_augment_transform
 from super_gradients.training.utils.detection_utils import DetectionVisualization, Anchors
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
@@ -402,7 +401,7 @@ class DatasetStatisticsTensorboardLogger:
         "max_batches": 30,
         "max_batches": 30,
     }
     }
 
 
-    def __init__(self, sg_logger: AbstractSGLogger, summary_params: dict = DEFAULT_SUMMARY_PARAMS):
+    def __init__(self, sg_logger, summary_params: dict = DEFAULT_SUMMARY_PARAMS):
         self.sg_logger = sg_logger
         self.sg_logger = sg_logger
         self.summary_params = {**DatasetStatisticsTensorboardLogger.DEFAULT_SUMMARY_PARAMS, **summary_params}
         self.summary_params = {**DatasetStatisticsTensorboardLogger.DEFAULT_SUMMARY_PARAMS, **summary_params}
 
 
@@ -4,7 +4,7 @@ from typing import Tuple, Type, Optional
 import hydra
 import torch
 
-from super_gradients.common import StrictLoad
+from super_gradients.common.data_types.enum.strict_load import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
 from super_gradients.training import utils as core_utils
 from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
@@ -3,7 +3,7 @@ import torch.nn as nn
 from typing import Union, List
 from super_gradients.modules import ConvBNReLU
 from super_gradients.training.utils.module_utils import make_upsample_module
-from super_gradients.common import UpsampleMode
+from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
 from super_gradients.training.models.segmentation_models.stdc import AbstractSTDCBackbone, STDC1Backbone, STDC2Backbone
 from super_gradients.training.models.segmentation_models.common import SegmentationHead
 from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
@@ -3,7 +3,7 @@ import pkg_resources
 from hydra import compose, initialize_config_dir
 from hydra.core.global_hydra import GlobalHydra
 from super_gradients.training.utils.utils import override_default_params_without_nones
-from super_gradients.training.utils.hydra_utils import normalize_path
+from super_gradients.common.environment.path_utils import normalize_path
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from typing import Dict
 
@@ -5,9 +5,11 @@ import pkg_resources
 import torch
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
-from super_gradients.training.pretrained_models import MODEL_URLS
 from super_gradients.common.environment import environment_config
+from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
+from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
+from super_gradients.training.pretrained_models import MODEL_URLS
+
 
 try:
     from torch.hub import download_url_to_file, load_state_dict_from_url
@@ -8,6 +8,7 @@ from hydra import initialize_config_dir, compose
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf, open_dict, DictConfig
 
+from super_gradients.common.environment.path_utils import normalize_path
 from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
 
 
@@ -56,17 +57,6 @@ def add_params_to_cfg(cfg: DictConfig, params: List[str]):
         cfg.merge_with(new_cfg)
 
 
-def normalize_path(path: str) -> str:
-    """Normalize the directory of file path. Replace the Windows-style (\\) path separators with unix ones (/).
-    This is necessary when running on Windows since Hydra compose fails to find a configuration file is the config
-    directory contains backward slash symbol.
-
-    :param path: Input path string
-    :return: Output path string with all \\ symbols replaces with /.
-    """
-    return path.replace("\\", "/")
-
-
 def load_arch_params(config_name: str) -> DictConfig:
     """
     :param config_name: name of a yaml with arch parameters
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from omegaconf.listconfig import ListConfig
 
-from super_gradients.common import UpsampleMode
+from super_gradients.common.data_types.enum.upsample_mode import UpsampleMode
 
 
 class MultiOutputModule(nn.Module):
@@ -15,7 +15,7 @@ from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODet
 from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
 from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
-from super_gradients.training.utils.hydra_utils import normalize_path
+from super_gradients.common.environment.path_utils import normalize_path
 
 
 class COCODetectionDataset6Channels(COCODetectionDataset):
@@ -8,7 +8,7 @@ from hydra import initialize_config_dir, compose
 from hydra.core.global_hydra import GlobalHydra
 
 from super_gradients.training.models.detection_models.csp_resnet import CSPResNet
-from super_gradients.training.utils.hydra_utils import normalize_path
+from super_gradients.common.environment.path_utils import normalize_path
 
 
 class PPYoloETests(unittest.TestCase):