Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#568 Refactored scheduler callbacks (epoch-based/step-based warmup)

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-525-step-based-warmup
1 changed file with 75 additions and 11 deletions
  1. 75
    11
      tests/unit_tests/lr_warmup_test.py
@@ -6,7 +6,19 @@ from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.models import LeNet
 from super_gradients.training.models import LeNet
-from super_gradients.training.utils.callbacks import TestLRCallback, LRCallbackBase, Phase
+from super_gradients.training.utils.callbacks import TestLRCallback, LRCallbackBase, Phase, Callback, PhaseContext, CosineLRCallback
+
+
+class CollectLRCallback(Callback):
+    """Test helper that records the observed learning rate during training.
+
+    Reads the LR of the first optimizer param group — once after every train
+    batch and once at the end of every train loader pass (i.e. per epoch) —
+    so tests can assert the exact schedule produced by warmup/LR callbacks.
+    """
+
+    def __init__(self):
+        self.per_step_learning_rates = []  # LR captured after each training step
+        self.per_epoch_learning_rates = []  # LR captured at the end of each epoch
+
+    def on_train_batch_end(self, context: PhaseContext) -> None:
+        self.per_step_learning_rates.append(context.optimizer.param_groups[0]["lr"])
+
+    def on_train_loader_end(self, context: PhaseContext) -> None:
+        self.per_epoch_learning_rates.append(context.optimizer.param_groups[0]["lr"])
 
 
 
 
 class ExponentialWarmupLRCallback(LRCallbackBase):
 class ExponentialWarmupLRCallback(LRCallbackBase):
@@ -59,7 +71,7 @@ class LRWarmupTest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
             "ema": False,
             "ema": False,
             "phase_callbacks": phase_callbacks,
             "phase_callbacks": phase_callbacks,
-            "warmup_mode": "linear_step",
+            "warmup_mode": "linear_epoch_step",
         }
         }
 
 
         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
@@ -95,7 +107,7 @@ class LRWarmupTest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
             "ema": False,
             "ema": False,
             "phase_callbacks": phase_callbacks,
             "phase_callbacks": phase_callbacks,
-            "warmup_mode": "linear_step",
+            "warmup_mode": "linear_epoch_step",
         }
         }
 
 
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
@@ -110,13 +122,65 @@ class LRWarmupTest(unittest.TestCase):
         # THE LRS AFTER THE UPDATE
         # THE LRS AFTER THE UPDATE
         self.assertListEqual(lrs, expected_lrs)
         self.assertListEqual(lrs, expected_lrs)
 
 
-    def test_warmup_initial_lr(self):
+    def test_warmup_linear_batch_step(self):
+        # Define model
+        net = LeNet()
+        trainer = Trainer("lr_warmup_test_per_step")
+
+        collect_lr_callback = CollectLRCallback()
+
+        warmup_initial_lr = 0.05
+        lr_warmup_steps = 100
+        initial_lr = 1
+        cosine_final_lr_ratio = 0.2
+        max_epochs = 5
+
+        train_params = {
+            "max_epochs": max_epochs,
+            "lr_mode": "cosine",
+            "cosine_final_lr_ratio": cosine_final_lr_ratio,
+            "warmup_initial_lr": warmup_initial_lr,
+            "warmup_mode": "linear_batch_step",
+            "lr_warmup_steps": lr_warmup_steps,
+            "initial_lr": 1,
+            "loss": "cross_entropy",
+            "optimizer": "SGD",
+            "criterion_params": {},
+            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
+            "train_metrics_list": [Accuracy()],
+            "valid_metrics_list": [Accuracy()],
+            "metric_to_watch": "Accuracy",
+            "greater_metric_to_watch_is_better": True,
+            "ema": False,
+            "phase_callbacks": [collect_lr_callback],
+        }
+
+        train_loader = classification_test_dataloader(batch_size=4, dataset_size=1024)
+        valid_loader = classification_test_dataloader(batch_size=4, dataset_size=5)
+
+        expected_warmup_lrs = np.linspace(warmup_initial_lr, initial_lr, lr_warmup_steps).tolist()
+        total_steps = max_epochs * len(train_loader) - lr_warmup_steps
+
+        expected_cosine_lrs = CosineLRCallback.compute_learning_rate(
+            step=np.arange(0, total_steps), total_steps=total_steps, initial_lr=initial_lr, final_lr_ratio=cosine_final_lr_ratio
+        )
+
+        trainer.train(
+            model=net,
+            training_params=train_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+        )
+
+        np.testing.assert_allclose(collect_lr_callback.per_step_learning_rates[:100], expected_warmup_lrs, rtol=1e-4)
+        np.testing.assert_allclose(collect_lr_callback.per_step_learning_rates[100:], expected_cosine_lrs, rtol=1e-4)
+
+    def test_warmup_linear_epoch_step(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
         trainer = Trainer("test_warmup_initial_lr")
         trainer = Trainer("test_warmup_initial_lr")
 
 
-        lrs = []
-        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
+        collect_lr_callback = CollectLRCallback()
 
 
         train_params = {
         train_params = {
             "max_epochs": 5,
             "max_epochs": 5,
@@ -124,6 +188,8 @@ class LRWarmupTest(unittest.TestCase):
             "lr_decay_factor": 0.1,
             "lr_decay_factor": 0.1,
             "lr_mode": "step",
             "lr_mode": "step",
             "lr_warmup_epochs": 3,
             "lr_warmup_epochs": 3,
+            "initial_lr": 1,
+            "warmup_initial_lr": 4.0,
             "loss": "cross_entropy",
             "loss": "cross_entropy",
             "optimizer": "SGD",
             "optimizer": "SGD",
             "criterion_params": {},
             "criterion_params": {},
@@ -133,10 +199,8 @@ class LRWarmupTest(unittest.TestCase):
             "metric_to_watch": "Accuracy",
             "metric_to_watch": "Accuracy",
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
             "ema": False,
             "ema": False,
-            "phase_callbacks": phase_callbacks,
-            "warmup_mode": "linear_step",
-            "initial_lr": 1,
-            "warmup_initial_lr": 4.0,
+            "phase_callbacks": [collect_lr_callback],
+            "warmup_mode": "linear_epoch_step",
         }
         }
 
 
         expected_lrs = [4.0, 3.0, 2.0, 1.0, 1.0]
         expected_lrs = [4.0, 3.0, 2.0, 1.0, 1.0]
@@ -146,7 +210,7 @@ class LRWarmupTest(unittest.TestCase):
             train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
             train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
             valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
             valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
         )
         )
-        self.assertListEqual(lrs, expected_lrs)
+        self.assertListEqual(collect_lr_callback.per_epoch_learning_rates, expected_lrs)
 
 
     def test_custom_lr_warmup(self):
     def test_custom_lr_warmup(self):
         # Define model
         # Define model
Discard