Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#313 Feature/sg 187 rename sg model

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-187_rename_sg_model
@@ -3,7 +3,7 @@ import unittest
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CoCoDetectionDatasetInterface
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 
 
-from super_gradients.training import SgModel
+from super_gradients.training import Trainer
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
 from super_gradients.training.utils.detection_utils import CrowdDetectionCollateFN, DetectionCollateFN, \
     DetectionTargetsFormat
     DetectionTargetsFormat
@@ -53,11 +53,11 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                                                                 "area_thr": 0
                                                                 "area_thr": 0
                                                                 })
                                                                 })
 
 
-        model = SgModel('dataset_statistics_visual_test',
-                        model_checkpoints_location='local',
-                        post_prediction_callback=YoloPostPredictionCallback())
-        model.connect_dataset_interface(dataset, data_loader_num_workers=8)
-        model.build_model("yolox_s")
+        trainer = Trainer('dataset_statistics_visual_test',
+                          model_checkpoints_location='local',
+                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer.connect_dataset_interface(dataset, data_loader_num_workers=8)
+        trainer.build_model("yolox_s")
 
 
         training_params = {"max_epochs": 1,  # we dont really need the actual training to run
         training_params = {"max_epochs": 1,  # we dont really need the actual training to run
                            "lr_mode": "cosine",
                            "lr_mode": "cosine",
@@ -74,7 +74,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
                            "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                            "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
                            "metric_to_watch": "mAP@0.50:0.95",
                            "metric_to_watch": "mAP@0.50:0.95",
                            }
                            }
-        model.train(training_params=training_params)
+        trainer.train(training_params=training_params)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
Discard
Tip!

Press p or to see the previous file or, n or to see the next file