Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
@@ -16,7 +16,7 @@ from super_gradients.training.dataloaders.dataloaders import (
 class TestCifarTrainer(unittest.TestCase):
 class TestCifarTrainer(unittest.TestCase):
     def test_train_cifar10_dataloader(self):
     def test_train_cifar10_dataloader(self):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         trainer.train(
         trainer.train(
@@ -35,7 +35,7 @@ class TestCifarTrainer(unittest.TestCase):
 
 
     def test_train_cifar100_dataloader(self):
     def test_train_cifar100_dataloader(self):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
         model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
         trainer.train(
         trainer.train(
Discard
Tip!

Press p or to see the previous file or, n or to see the next file