
#617 Feature/sg 000 break inner train loop

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-000_break_inner_train_loop
@@ -99,6 +99,12 @@ qat_params:
  num_calib_batches: 2 # int, number of batches to collect the statistics from.
  percentile: 99.99 # float, percentile value to use when quant_modules_calib_method='percentile'. Discarded when other methods are used (Default=99.99).

+max_train_batches: None # For debugging: when not None, breaks out of the inner train loop
+# (i.e., iterating over train_loader) upon reaching this number of batches.
+
+max_valid_batches: None # For debugging: when not None, breaks out of the inner valid loop
+# (i.e., iterating over valid_loader) upon reaching this number of batches.
+
 sg_logger: base_sg_logger
 sg_logger_params:
   tb_files_user_prompt: False # Asks User for Tensorboard Deletion Prompt
@@ -68,6 +68,10 @@ DEFAULT_TRAINING_PARAMS = {
     "ckpt_name": "ckpt_latest.pth",
     "ckpt_name": "ckpt_latest.pth",
     "resume_strict_load": False,
     "resume_strict_load": False,
     "sync_bn": False,
     "sync_bn": False,
+    "max_train_batches": None,  # For debug- when not None- will break out of inner train loop
+    # (i.e iterating over train_loader) when reaching this number of batches.
+    "max_valid_batches": None,  # For debug- when not None- will break out of inner valid loop
+    # (i.e iterating over valid_loader) when reaching this number of batches.
 }

 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
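
For a sense of how these new defaults are consumed, here is a minimal sketch of overriding the caps for a single debug run. It is modeled on the unit test at the bottom of this PR and reuses the same Trainer, LeNet, and classification_test_dataloader helpers; the experiment name is hypothetical and the rest of the required training params are elided:

    from super_gradients.training import Trainer
    from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
    from super_gradients.training.models import LeNet

    trainer = Trainer("debug_short_epochs")  # hypothetical experiment name
    trainer.train(
        model=LeNet(),
        training_params={
            # ... the usual required params (max_epochs, loss, optimizer, metrics, ...) ...
            "max_train_batches": 3,  # each training epoch stops after 3 batches
            "max_valid_batches": 3,  # each validation pass stops after 3 batches
        },
        train_loader=classification_test_dataloader(),
        valid_loader=classification_test_dataloader(),
    )
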
@@ -193,6 +193,8 @@ class Trainer:

         self.train_monitored_values = {}
         self.valid_monitored_values = {}
+        self.max_train_batches = None
+        self.max_valid_batches = None

     @property
     def device(self) -> str:
@@ -445,7 +447,9 @@ class Trainer:

             # TODO: ITERATE BY MAX ITERS
             # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
-            if self._infinite_train_loader and batch_idx == len(self.train_loader) - 1:
+            if (self._infinite_train_loader and batch_idx == len(self.train_loader) - 1) or (
+                self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx
+            ):
                 break

         if not self.ddp_silent_mode:
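
The added clause shares a single break with the existing infinite-sampler guard. In isolation, the capping logic looks like the following sketch (stand-in names, not the Trainer's actual loop body):

    train_loader = range(10)  # stand-in for a real DataLoader
    max_train_batches = 3

    for batch_idx, batch in enumerate(train_loader):
        # ... forward / backward / optimizer step would go here ...
        # batch_idx is zero-based, so breaking once batch_idx >= max_train_batches - 1
        # means exactly max_train_batches batches were processed.
        if max_train_batches is not None and batch_idx >= max_train_batches - 1:
            break

    assert batch_idx == 2  # three batches (0, 1, 2) were processed
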
@@ -965,6 +969,13 @@ class Trainer:
                         percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
                          Discarded when other methods are used (Default=99.99).

+                -   `max_train_batches`: int, when not None, breaks out of the inner train loop (i.e., iterating over
+                      train_loader) upon reaching this number of batches. Useful for debugging (default=None).
+
+                -   `max_valid_batches`: int, when not None, breaks out of the inner valid loop (i.e., iterating over
+                      valid_loader) upon reaching this number of batches. Useful for debugging (default=None).
+

         :return:
         """
@@ -1143,6 +1154,21 @@ class Trainer:

         self.ckpt_best_name = self.training_params.ckpt_best_name

+        if self.training_params.max_train_batches is not None and (
+            self.training_params.max_train_batches > len(self.train_loader) or self.training_params.max_train_batches <= 0
+        ):
+            raise ValueError("max_train_batches must be positive and no greater than len(train_loader).")
+
+        self.max_train_batches = self.training_params.max_train_batches
+
+        if self.training_params.max_valid_batches is not None and (
+            self.training_params.max_valid_batches > len(self.valid_loader) or self.training_params.max_valid_batches <= 0
+        ):
+            raise ValueError("max_valid_batches must be positive and no greater than len(valid_loader).")
+
+        self.max_valid_batches = self.training_params.max_valid_batches
+
         # STATE ATTRIBUTE SET HERE FOR SUBSEQUENT TRAIN() CALLS
         self._first_backward = True

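
The same range check applies to both caps: the accepted values are 1 <= max_batches <= len(loader), with None disabling the cap. As a standalone illustration (a hypothetical helper, not part of this PR):

    def check_max_batches(max_batches, loader_len: int, name: str) -> None:
        # Mirrors the added validation: None disables the cap entirely.
        if max_batches is not None and not (0 < max_batches <= loader_len):
            raise ValueError(f"{name} must be positive and no greater than the loader length ({loader_len}).")

    check_max_batches(3, 4, "max_train_batches")    # ok: 1 <= 3 <= 4
    # check_max_batches(5, 4, "max_train_batches")  # would raise ValueError
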
@@ -1754,6 +1780,9 @@ class Trainer:

                     progress_bar_data_loader.set_postfix(**pbar_message_dict)

+                if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches - 1 <= batch_idx:
+                    break
+
         # NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET
         if not metrics_progress_verbose:
             # COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.
@@ -26,6 +26,7 @@ from tests.unit_tests.detection_utils_test import TestDetectionUtils
 from tests.unit_tests.detection_dataset_test import DetectionDatasetTest
 from tests.unit_tests.export_onnx_test import TestModelsONNXExport
 from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
+from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
@@ -117,6 +118,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionDatasetTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelsONNXExport))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaxBatchesLoopBreakTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
tests/unit_tests/max_batches_loop_break_test.py:

import unittest

from super_gradients.training import Trainer
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
from super_gradients.training.models import LeNet


class LastBatchIdxCollector(PhaseCallback):
    def __init__(self, train: bool = True):
        phase = Phase.TRAIN_BATCH_END if train else Phase.VALIDATION_BATCH_END
        super().__init__(phase=phase)
        self.last_batch_idx = 0

    def __call__(self, context: PhaseContext):
        self.last_batch_idx = context.batch_idx


class MaxBatchesLoopBreakTest(unittest.TestCase):
    def test_max_train_batches_loop_break(self):
        last_batch_collector = LastBatchIdxCollector()
        train_params = {
            "max_epochs": 2,
            "lr_updates": [1],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy(), Top5()],
            "valid_metrics_list": [Accuracy(), Top5()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
            "phase_callbacks": [last_batch_collector],
            "max_train_batches": 3,
        }

        # Define Model
        net = LeNet()
        trainer = Trainer("test_max_batches_break_train")
        trainer.train(
            model=net,
            training_params=train_params,
            train_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
            valid_loader=classification_test_dataloader(),
        )

        # ASSERT LAST BATCH IDX IS 2
        print(last_batch_collector.last_batch_idx)
        self.assertTrue(last_batch_collector.last_batch_idx == 2)

    def test_max_valid_batches_loop_break(self):
        last_batch_collector = LastBatchIdxCollector(train=False)
        train_params = {
            "max_epochs": 2,
            "lr_updates": [1],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy(), Top5()],
            "valid_metrics_list": [Accuracy(), Top5()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
            "phase_callbacks": [last_batch_collector],
            "max_valid_batches": 3,
        }

        # Define Model
        net = LeNet()
        trainer = Trainer("test_max_batches_break_val")
        trainer.train(
            model=net,
            training_params=train_params,
            train_loader=classification_test_dataloader(),
            valid_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
        )

        # ASSERT LAST BATCH IDX IS 2
        self.assertTrue(last_batch_collector.last_batch_idx == 2)


if __name__ == "__main__":
    unittest.main()
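
Why the tests expect last_batch_idx == 2: classification_test_dataloader(dataset_size=16, batch_size=4) yields 4 batches (indices 0-3), and with the cap set to 3 the loop breaks at index 2. A quick sanity check in plain arithmetic (no SuperGradients imports needed):

    dataset_size, batch_size, max_batches = 16, 4, 3
    num_batches = dataset_size // batch_size            # the capped loader yields batches 0, 1, 2, 3
    last_batch_idx = min(num_batches, max_batches) - 1
    assert last_batch_idx == 2                          # exactly 3 batches run; the callback records index 2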