Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#516 Remove imports from factory.__init__

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-remove_imports_from_factories
@@ -1,27 +0,0 @@
-from super_gradients.common.factories.callbacks_factory import CallbacksFactory
-from super_gradients.common.factories.list_factory import ListFactory
-from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.common.factories.type_factory import TypeFactory
-from super_gradients.common.factories.losses_factory import LossesFactory
-from super_gradients.common.factories.metrics_factory import MetricsFactory
-from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
-from super_gradients.common.factories.samplers_factory import SamplersFactory
-from super_gradients.common.factories.transforms_factory import TransformsFactory
-from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
-from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory
-
-
-# "DetectionModulesFactory" from super_gradients.common.factories.detection_modules_factory is omitted due to circular import issue
-__all__ = [
-    "CallbacksFactory",
-    "ListFactory",
-    "LossesFactory",
-    "MetricsFactory",
-    "OptimizersTypeFactory",
-    "SamplersFactory",
-    "TransformsFactory",
-    "ActivationsTypeFactory",
-    "TypeFactory",
-    "BaseFactory",
-    "BBoxFormatFactory",
-]
Discard
@@ -1,4 +1,4 @@
-from super_gradients.common.factories import BaseFactory
+from super_gradients.common.factories.base_factory import BaseFactory
 import super_gradients.training.models.segmentation_models.context_modules as context_modules
 import super_gradients.training.models.segmentation_models.context_modules as context_modules
 
 
 
 
Discard
@@ -1,4 +1,4 @@
-from super_gradients.common.factories import TypeFactory
+from super_gradients.common.factories.type_factory import TypeFactory
 from super_gradients.training.utils.optimizers import OPTIMIZERS
 from super_gradients.training.utils.optimizers import OPTIMIZERS
 
 
 
 
Discard
@@ -4,7 +4,8 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import Compose
 from torchvision.transforms import Compose
 
 
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ListFactory, TransformsFactory
+from super_gradients.common.factories.list_factory import ListFactory
+from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.registry.registry import register_dataloader
 from super_gradients.common.registry.registry import register_dataloader
 from super_gradients.training.dataloaders import get_data_loader
 from super_gradients.training.dataloaders import get_data_loader
 
 
Discard
@@ -3,7 +3,7 @@ from typing import List, Type, Tuple
 
 
 import torch
 import torch
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ActivationsTypeFactory
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from torch import nn, Tensor
 from torch import nn, Tensor
 
 
 from super_gradients.modules import RepVGGBlock, EffectiveSEBlock, ConvBNAct
 from super_gradients.modules import RepVGGBlock, EffectiveSEBlock, ConvBNAct
Discard
@@ -8,7 +8,8 @@ import torch
 from super_gradients.modules import ConvBNReLU
 from super_gradients.modules import ConvBNReLU
 from super_gradients.training.utils.module_utils import make_upsample_module
 from super_gradients.training.utils.module_utils import make_upsample_module
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ListFactory, TypeFactory
+from super_gradients.common.factories.list_factory import ListFactory
+from super_gradients.common.factories.type_factory import TypeFactory
 
 
 
 
 class AbstractUpFuseBlock(nn.Module, ABC):
 class AbstractUpFuseBlock(nn.Module, ABC):
Discard
@@ -13,7 +13,8 @@ from super_gradients.training.models.segmentation_models.stdc import STDCBlock
 from super_gradients.training.models import SgModule, HpmStruct
 from super_gradients.training.models import SgModule, HpmStruct
 from super_gradients.modules import ConvBNReLU
 from super_gradients.modules import ConvBNReLU
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ListFactory, TypeFactory
+from super_gradients.common.factories.list_factory import ListFactory
+from super_gradients.common.factories.type_factory import TypeFactory
 
 
 
 
 class AntiAliasDownsample(nn.Module):
 class AntiAliasDownsample(nn.Module):
Discard
@@ -17,7 +17,7 @@ from piptools.scripts.sync import _get_installed_distributions
 
 
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.distributed import DistributedSampler
 
 
-from super_gradients.common.factories import TypeFactory
+from super_gradients.common.factories.type_factory import TypeFactory
 from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
 from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
 
 
 from super_gradients.common.factories.callbacks_factory import CallbacksFactory
 from super_gradients.common.factories.callbacks_factory import CallbacksFactory
Discard
@@ -6,7 +6,7 @@ import unittest
 import numpy as np
 import numpy as np
 import torch
 import torch
 
 
-from super_gradients.common.factories import BBoxFormatFactory
+from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory
 from super_gradients.training.utils.bbox_formats import (
 from super_gradients.training.utils.bbox_formats import (
     CXCYWHCoordinateFormat,
     CXCYWHCoordinateFormat,
     NormalizedXYXYCoordinateFormat,
     NormalizedXYXYCoordinateFormat,
Discard
@@ -4,7 +4,7 @@ import torch
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ActivationsTypeFactory
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -12,28 +12,27 @@ from torch import nn
 
 
 
 
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
-
     def test_training_with_factories(self):
     def test_training_with_factories(self):
         trainer = Trainer("test_train_with_factories")
         trainer = Trainer("test_train_with_factories")
         net = models.get("resnet18", num_classes=5)
         net = models.get("resnet18", num_classes=5)
-        train_params = {"max_epochs": 2,
-                        "lr_updates": [1],
-                        "lr_decay_factor": 0.1,
-                        "lr_mode": "step",
-                        "lr_warmup_epochs": 0,
-                        "initial_lr": 0.1,
-                        "loss": "cross_entropy",
-                        "optimizer": "torch.optim.ASGD",  # use an optimizer by factory
-                        "criterion_params": {},
-                        "optimizer_params": {"lambd": 0.0001, "alpha": 0.75},
-                        "train_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
-                        "valid_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
-                        "metric_to_watch": "Accuracy",
-                        "greater_metric_to_watch_is_better": True}
-
-        trainer.train(model=net, training_params=train_params,
-                      train_loader=classification_test_dataloader(),
-                      valid_loader=classification_test_dataloader())
+        train_params = {
+            "max_epochs": 2,
+            "lr_updates": [1],
+            "lr_decay_factor": 0.1,
+            "lr_mode": "step",
+            "lr_warmup_epochs": 0,
+            "initial_lr": 0.1,
+            "loss": "cross_entropy",
+            "optimizer": "torch.optim.ASGD",  # use an optimizer by factory
+            "criterion_params": {},
+            "optimizer_params": {"lambd": 0.0001, "alpha": 0.75},
+            "train_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
+            "valid_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
+            "metric_to_watch": "Accuracy",
+            "greater_metric_to_watch_is_better": True,
+        }
+
+        trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
 
 
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
@@ -50,5 +49,5 @@ class FactoriesTest(unittest.TestCase):
         self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
         self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard