Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#516 Remove imports from factory.__init__

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/SG-000-remove_imports_from_factories
@@ -4,7 +4,7 @@ import torch
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.factories import ActivationsTypeFactory
+from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
 from super_gradients.training import models
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -12,28 +12,27 @@ from torch import nn
 
 
 
 
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
-
     def test_training_with_factories(self):
     def test_training_with_factories(self):
         trainer = Trainer("test_train_with_factories")
         trainer = Trainer("test_train_with_factories")
         net = models.get("resnet18", num_classes=5)
         net = models.get("resnet18", num_classes=5)
-        train_params = {"max_epochs": 2,
-                        "lr_updates": [1],
-                        "lr_decay_factor": 0.1,
-                        "lr_mode": "step",
-                        "lr_warmup_epochs": 0,
-                        "initial_lr": 0.1,
-                        "loss": "cross_entropy",
-                        "optimizer": "torch.optim.ASGD",  # use an optimizer by factory
-                        "criterion_params": {},
-                        "optimizer_params": {"lambd": 0.0001, "alpha": 0.75},
-                        "train_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
-                        "valid_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
-                        "metric_to_watch": "Accuracy",
-                        "greater_metric_to_watch_is_better": True}
-
-        trainer.train(model=net, training_params=train_params,
-                      train_loader=classification_test_dataloader(),
-                      valid_loader=classification_test_dataloader())
+        train_params = {
+            "max_epochs": 2,
+            "lr_updates": [1],
+            "lr_decay_factor": 0.1,
+            "lr_mode": "step",
+            "lr_warmup_epochs": 0,
+            "initial_lr": 0.1,
+            "loss": "cross_entropy",
+            "optimizer": "torch.optim.ASGD",  # use an optimizer by factory
+            "criterion_params": {},
+            "optimizer_params": {"lambd": 0.0001, "alpha": 0.75},
+            "train_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
+            "valid_metrics_list": ["Accuracy", "Top5"],  # use a metric by factory
+            "metric_to_watch": "Accuracy",
+            "greater_metric_to_watch_is_better": True,
+        }
+
+        trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
 
 
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.train_metrics.Accuracy, Accuracy)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
@@ -50,5 +49,5 @@ class FactoriesTest(unittest.TestCase):
         self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
         self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file