|
@@ -22,8 +22,8 @@ class ExponentialWarmupLRCallback(LRCallbackBase):
|
|
warmup_epochs = self.training_params.lr_warmup_epochs
|
|
warmup_epochs = self.training_params.lr_warmup_epochs
|
|
lr_start = self.warmup_initial_lr
|
|
lr_start = self.warmup_initial_lr
|
|
lr_end = self.initial_lr
|
|
lr_end = self.initial_lr
|
|
- self.c1 = (lr_end - lr_start) / (np.exp(warmup_epochs) - 1.)
|
|
|
|
- self.c2 = (lr_start * np.exp(warmup_epochs) - lr_end) / (np.exp(warmup_epochs) - 1.)
|
|
|
|
|
|
+ self.c1 = (lr_end - lr_start) / (np.exp(warmup_epochs) - 1.0)
|
|
|
|
+ self.c2 = (lr_start * np.exp(warmup_epochs) - lr_end) / (np.exp(warmup_epochs) - 1.0)
|
|
|
|
|
|
def perform_scheduling(self, context):
|
|
def perform_scheduling(self, context):
|
|
self.lr = self.c1 * np.exp(context.epoch) + self.c2
|
|
self.lr = self.c1 * np.exp(context.epoch) + self.c2
|
|
@@ -34,7 +34,6 @@ class ExponentialWarmupLRCallback(LRCallbackBase):
|
|
|
|
|
|
|
|
|
|
class LRWarmupTest(unittest.TestCase):
|
|
class LRWarmupTest(unittest.TestCase):
|
|
-
|
|
|
|
def test_lr_warmup(self):
|
|
def test_lr_warmup(self):
|
|
# Define Model
|
|
# Define Model
|
|
net = LeNet()
|
|
net = LeNet()
|
|
@@ -43,18 +42,33 @@ class LRWarmupTest(unittest.TestCase):
|
|
lrs = []
|
|
lrs = []
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
|
|
|
|
- train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
|
|
|
|
- "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
|
|
|
|
- "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
- "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
|
|
|
|
- "metric_to_watch": "Accuracy",
|
|
|
|
- "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
|
|
|
|
- "warmup_mode": "linear_step"}
|
|
|
|
|
|
+ train_params = {
|
|
|
|
+ "max_epochs": 5,
|
|
|
|
+ "lr_updates": [],
|
|
|
|
+ "lr_decay_factor": 0.1,
|
|
|
|
+ "lr_mode": "step",
|
|
|
|
+ "lr_warmup_epochs": 3,
|
|
|
|
+ "initial_lr": 1,
|
|
|
|
+ "loss": "cross_entropy",
|
|
|
|
+ "optimizer": "SGD",
|
|
|
|
+ "criterion_params": {},
|
|
|
|
+ "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
+ "train_metrics_list": [Accuracy()],
|
|
|
|
+ "valid_metrics_list": [Accuracy()],
|
|
|
|
+ "metric_to_watch": "Accuracy",
|
|
|
|
+ "greater_metric_to_watch_is_better": True,
|
|
|
|
+ "ema": False,
|
|
|
|
+ "phase_callbacks": phase_callbacks,
|
|
|
|
+ "warmup_mode": "linear_step",
|
|
|
|
+ }
|
|
|
|
|
|
expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
|
|
expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
|
|
- trainer.train(model=net, training_params=train_params,
|
|
|
|
- train_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
- valid_loader=classification_test_dataloader(batch_size=4))
|
|
|
|
|
|
+ trainer.train(
|
|
|
|
+ model=net,
|
|
|
|
+ training_params=train_params,
|
|
|
|
+ train_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
+ valid_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
+ )
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
|
|
|
|
def test_lr_warmup_with_lr_scheduling(self):
|
|
def test_lr_warmup_with_lr_scheduling(self):
|
|
@@ -65,18 +79,32 @@ class LRWarmupTest(unittest.TestCase):
|
|
lrs = []
|
|
lrs = []
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
|
|
|
|
- train_params = {"max_epochs": 5, "cosine_final_lr_ratio": 0.2, "lr_mode": "cosine",
|
|
|
|
- "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
|
|
|
|
- "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
- "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
|
|
|
|
- "metric_to_watch": "Accuracy",
|
|
|
|
- "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
|
|
|
|
- "warmup_mode": "linear_step"}
|
|
|
|
|
|
+ train_params = {
|
|
|
|
+ "max_epochs": 5,
|
|
|
|
+ "cosine_final_lr_ratio": 0.2,
|
|
|
|
+ "lr_mode": "cosine",
|
|
|
|
+ "lr_warmup_epochs": 3,
|
|
|
|
+ "initial_lr": 1,
|
|
|
|
+ "loss": "cross_entropy",
|
|
|
|
+ "optimizer": "SGD",
|
|
|
|
+ "criterion_params": {},
|
|
|
|
+ "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
+ "train_metrics_list": [Accuracy()],
|
|
|
|
+ "valid_metrics_list": [Accuracy()],
|
|
|
|
+ "metric_to_watch": "Accuracy",
|
|
|
|
+ "greater_metric_to_watch_is_better": True,
|
|
|
|
+ "ema": False,
|
|
|
|
+ "phase_callbacks": phase_callbacks,
|
|
|
|
+ "warmup_mode": "linear_step",
|
|
|
|
+ }
|
|
|
|
|
|
expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
|
|
expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
|
|
- trainer.train(model=net, training_params=train_params,
|
|
|
|
- train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
- valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5))
|
|
|
|
|
|
+ trainer.train(
|
|
|
|
+ model=net,
|
|
|
|
+ training_params=train_params,
|
|
|
|
+ train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
+ valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
+ )
|
|
|
|
|
|
# ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
|
|
# ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
|
|
# THE LRS AFTER THE UPDATE
|
|
# THE LRS AFTER THE UPDATE
|
|
@@ -90,18 +118,34 @@ class LRWarmupTest(unittest.TestCase):
|
|
lrs = []
|
|
lrs = []
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
|
|
|
|
- train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
|
|
|
|
- "lr_warmup_epochs": 3, "loss": "cross_entropy", "optimizer": 'SGD',
|
|
|
|
- "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
- "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
|
|
|
|
- "metric_to_watch": "Accuracy",
|
|
|
|
- "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
|
|
|
|
- "warmup_mode": "linear_step", "initial_lr": 1, "warmup_initial_lr": 4.}
|
|
|
|
-
|
|
|
|
- expected_lrs = [4., 3., 2., 1., 1.]
|
|
|
|
- trainer.train(model=net, training_params=train_params,
|
|
|
|
- train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
- valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5))
|
|
|
|
|
|
+ train_params = {
|
|
|
|
+ "max_epochs": 5,
|
|
|
|
+ "lr_updates": [],
|
|
|
|
+ "lr_decay_factor": 0.1,
|
|
|
|
+ "lr_mode": "step",
|
|
|
|
+ "lr_warmup_epochs": 3,
|
|
|
|
+ "loss": "cross_entropy",
|
|
|
|
+ "optimizer": "SGD",
|
|
|
|
+ "criterion_params": {},
|
|
|
|
+ "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
+ "train_metrics_list": [Accuracy()],
|
|
|
|
+ "valid_metrics_list": [Accuracy()],
|
|
|
|
+ "metric_to_watch": "Accuracy",
|
|
|
|
+ "greater_metric_to_watch_is_better": True,
|
|
|
|
+ "ema": False,
|
|
|
|
+ "phase_callbacks": phase_callbacks,
|
|
|
|
+ "warmup_mode": "linear_step",
|
|
|
|
+ "initial_lr": 1,
|
|
|
|
+ "warmup_initial_lr": 4.0,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ expected_lrs = [4.0, 3.0, 2.0, 1.0, 1.0]
|
|
|
|
+ trainer.train(
|
|
|
|
+ model=net,
|
|
|
|
+ training_params=train_params,
|
|
|
|
+ train_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
+ valid_loader=classification_test_dataloader(batch_size=4, dataset_size=5),
|
|
|
|
+ )
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
|
|
|
|
def test_custom_lr_warmup(self):
|
|
def test_custom_lr_warmup(self):
|
|
@@ -112,20 +156,36 @@ class LRWarmupTest(unittest.TestCase):
|
|
lrs = []
|
|
lrs = []
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
|
|
|
|
|
|
- train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
|
|
|
|
- "lr_warmup_epochs": 3, "loss": "cross_entropy", "optimizer": 'SGD',
|
|
|
|
- "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
- "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
|
|
|
|
- "metric_to_watch": "Accuracy",
|
|
|
|
- "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
|
|
|
|
- "warmup_mode": ExponentialWarmupLRCallback, "initial_lr": 1., "warmup_initial_lr": 0.1}
|
|
|
|
|
|
+ train_params = {
|
|
|
|
+ "max_epochs": 5,
|
|
|
|
+ "lr_updates": [],
|
|
|
|
+ "lr_decay_factor": 0.1,
|
|
|
|
+ "lr_mode": "step",
|
|
|
|
+ "lr_warmup_epochs": 3,
|
|
|
|
+ "loss": "cross_entropy",
|
|
|
|
+ "optimizer": "SGD",
|
|
|
|
+ "criterion_params": {},
|
|
|
|
+ "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
|
|
|
|
+ "train_metrics_list": [Accuracy()],
|
|
|
|
+ "valid_metrics_list": [Accuracy()],
|
|
|
|
+ "metric_to_watch": "Accuracy",
|
|
|
|
+ "greater_metric_to_watch_is_better": True,
|
|
|
|
+ "ema": False,
|
|
|
|
+ "phase_callbacks": phase_callbacks,
|
|
|
|
+ "warmup_mode": ExponentialWarmupLRCallback,
|
|
|
|
+ "initial_lr": 1.0,
|
|
|
|
+ "warmup_initial_lr": 0.1,
|
|
|
|
+ }
|
|
|
|
|
|
expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
|
|
expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
|
|
- trainer.train(model=net, training_params=train_params,
|
|
|
|
- train_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
- valid_loader=classification_test_dataloader(batch_size=4))
|
|
|
|
|
|
+ trainer.train(
|
|
|
|
+ model=net,
|
|
|
|
+ training_params=train_params,
|
|
|
|
+ train_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
+ valid_loader=classification_test_dataloader(batch_size=4),
|
|
|
|
+ )
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
self.assertListEqual(lrs, expected_lrs)
|
|
|
|
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
|
|
|
+if __name__ == "__main__":
|
|
unittest.main()
|
|
unittest.main()
|