Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#356 Feature/sg 216 remove dataset interface

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-216_remove_dataset_interface
@@ -4,7 +4,8 @@ import os
 
 
 from super_gradients.training import models
 from super_gradients.training import models
 
 
-from super_gradients import Trainer, ClassificationTestDatasetInterface
+from super_gradients import Trainer
+from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 
 
 
 
@@ -30,10 +31,7 @@ class LRTest(unittest.TestCase):
     @staticmethod
     @staticmethod
     def get_trainer(name=''):
     def get_trainer(name=''):
         trainer = Trainer(name, model_checkpoints_location='local')
         trainer = Trainer(name, model_checkpoints_location='local')
-        dataset_params = {"batch_size": 4}
-        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
-        trainer.connect_dataset_interface(dataset)
-        model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
+        model = models.get("resnet18_cifar", num_classes=5)
         return trainer, model
         return trainer, model
 
 
     def test_function_lr(self):
     def test_function_lr(self):
@@ -44,22 +42,25 @@ class LRTest(unittest.TestCase):
 
 
         # test if we are able that lr_function supports functions with this structure
         # test if we are able that lr_function supports functions with this structure
         training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
         training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
-        trainer.train(model=model, training_params=training_params)
-
+        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
+                      valid_loader=classification_test_dataloader())
         # test that we assert lr_function is callable
         # test that we assert lr_function is callable
         training_params = {**self.training_params, "lr_mode": "function"}
         training_params = {**self.training_params, "lr_mode": "function"}
         with self.assertRaises(AssertionError):
         with self.assertRaises(AssertionError):
-            trainer.train(model=model, training_params=training_params)
+            trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
+                          valid_loader=classification_test_dataloader())
 
 
     def test_cosine_lr(self):
     def test_cosine_lr(self):
         trainer, model = self.get_trainer(self.folder_name)
         trainer, model = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
         training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
-        trainer.train(model=model, training_params=training_params)
+        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
+                      valid_loader=classification_test_dataloader())
 
 
     def test_step_lr(self):
     def test_step_lr(self):
         trainer, model = self.get_trainer(self.folder_name)
         trainer, model = self.get_trainer(self.folder_name)
         training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
         training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
-        trainer.train(model=model, training_params=training_params)
+        trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
+                      valid_loader=classification_test_dataloader())
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
Tip!

Press p to see the previous file, or n to see the next file