Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#356 Feature/sg 216 remove dataset interface

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-216_remove_dataset_interface
@@ -4,11 +4,9 @@ import re
 
 
 from super_gradients.training import models
 from super_gradients.training import models
 
 
-from super_gradients import (
-    Trainer,
-    ClassificationTestDatasetInterface,
-    SegmentationTestDatasetInterface,
-)
+from super_gradients import Trainer
+from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, \
+    classification_test_dataloader
 from super_gradients.training.utils.callbacks import ModelConversionCheckCallback
 from super_gradients.training.utils.callbacks import ModelConversionCheckCallback
 from super_gradients.training.metrics import Accuracy, Top5, IoU
 from super_gradients.training.metrics import Accuracy, Top5, IoU
 from super_gradients.training.losses.stdc_loss import STDCLoss
 from super_gradients.training.losses.stdc_loss import STDCLoss
@@ -16,7 +14,6 @@ from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
 
 
 from deci_lab_client.models import ModelMetadata, HardwareType, FrameworkType
 from deci_lab_client.models import ModelMetadata, HardwareType, FrameworkType
 
 
-
 checkpoint_dir = "/Users/daniel/Documents/LALA"
 checkpoint_dir = "/Users/daniel/Documents/LALA"
 
 
 
 
@@ -44,6 +41,8 @@ def generate_model_metadata(architecture: str, task: Task):
 
 
 CLASSIFICATION = ["efficientnet_b0", "regnetY200", "regnetY400", "regnetY600", "regnetY800", "mobilenet_v3_large"]
 CLASSIFICATION = ["efficientnet_b0", "regnetY200", "regnetY400", "regnetY600", "regnetY800", "mobilenet_v3_large"]
 SEMANTIC_SEGMENTATION = ["ddrnet_23", "stdc1_seg", "stdc2_seg", "regseg48"]
 SEMANTIC_SEGMENTATION = ["ddrnet_23", "stdc1_seg", "stdc2_seg", "regseg48"]
+
+
 # TODO: ADD YOLOX ARCHITECTURES AND TESTS
 # TODO: ADD YOLOX ARCHITECTURES AND TESTS
 
 
 
 
@@ -70,13 +69,12 @@ class ConversionCallbackTest(unittest.TestCase):
                 "phase_callbacks": phase_callbacks,
                 "phase_callbacks": phase_callbacks,
             }
             }
 
 
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
-            dataset = ClassificationTestDatasetInterface(dataset_params={"batch_size": 10})
-
-            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
+            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+                              ckpt_root_dir=checkpoint_dir)
             model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             try:
             try:
-                trainer.train(model=model, training_params=train_params)
+                trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(),
+                              valid_loader=classification_test_dataloader())
             except Exception as e:
             except Exception as e:
                 self.fail(f"Model training didn't succeed due to {e}")
                 self.fail(f"Model training didn't succeed due to {e}")
             else:
             else:
@@ -104,10 +102,9 @@ class ConversionCallbackTest(unittest.TestCase):
 
 
         for architecture in SEMANTIC_SEGMENTATION:
         for architecture in SEMANTIC_SEGMENTATION:
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
-            dataset = SegmentationTestDatasetInterface(dataset_params={"batch_size": 10})
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local", ckpt_root_dir=checkpoint_dir)
-            trainer.connect_dataset_interface(dataset, data_loader_num_workers=0)
-            model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
+            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+                              ckpt_root_dir=checkpoint_dir)
+            model = models.get(name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
 
 
             phase_callbacks = [
             phase_callbacks = [
                 ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
                 ModelConversionCheckCallback(model_meta_data=model_meta_data, opset_version=11, rtol=1, atol=1),
@@ -131,7 +128,8 @@ class ConversionCallbackTest(unittest.TestCase):
             train_params.update(custom_config)
             train_params.update(custom_config)
 
 
             try:
             try:
-                trainer.train(model=model, training_params=train_params)
+                trainer.train(model=model, training_params=train_params, train_loader=segmentation_test_dataloader(image_size=512),
+                              valid_loader=segmentation_test_dataloader(image_size=512))
             except Exception as e:
             except Exception as e:
                 self.fail(f"Model training didn't succeed for {architecture} due to {e}")
                 self.fail(f"Model training didn't succeed for {architecture} due to {e}")
             else:
             else:
Discard
Tip!

Press p or to see the previous file or, n or to see the next file