Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#356 Feature/sg 216 remove dataset interface

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-216_remove_dataset_interface
@@ -1,6 +1,6 @@
 import unittest
 import unittest
-from super_gradients import Trainer, \
-    ClassificationTestDatasetInterface
+from super_gradients import Trainer
+from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.models import ResNet18
 from super_gradients.training.models import ResNet18
 from torch.optim import SGD
 from torch.optim import SGD
@@ -11,8 +11,6 @@ from deci_lab_client.models import Metric, QuantizationLevel, ModelMetadata, Opt
 class DeciLabUploadTest(unittest.TestCase):
 class DeciLabUploadTest(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
         self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
         self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
-        dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
-        self.trainer.connect_dataset_interface(dataset)
 
 
     def test_train_with_deci_lab_integration(self):
     def test_train_with_deci_lab_integration(self):
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
@@ -49,7 +47,8 @@ class DeciLabUploadTest(unittest.TestCase):
                         "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
                         "phase_callbacks": [model_conversion_callback, deci_lab_callback]}
         self.optimizer = SGD(params=net.parameters(), lr=0.1)
         self.optimizer = SGD(params=net.parameters(), lr=0.1)
 
 
-        self.trainer.train(model=net, training_params=train_params)
+        self.trainer.train(model=net, training_params=train_params,
+                           train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
 
 
         # CLEANUP
         # CLEANUP
 
 
Discard
Tip!

Press p or to see the previous file or, n or to see the next file