Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#13 refactor: Use width multiplier on YOLO modules from outside

Merged
Kate Feingold merged 2 commits into Deci-AI:master from deci-ai:feature/ALG-78_width-mult-yolo-modules
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  1. import unittest
  2. from super_gradients.training import SgModel
  3. from super_gradients.training.metrics import Accuracy
  4. from super_gradients.training.datasets import ClassificationTestDatasetInterface
  5. from super_gradients.training.models import LeNet
  6. from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
  7. class TestLRCallback(PhaseCallback):
  8. """
  9. Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
  10. the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first
  11. one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
  12. """
  13. def __init__(self, lr_placeholder):
  14. super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
  15. self.lr_placeholder = lr_placeholder
  16. def __call__(self, context: PhaseContext):
  17. self.lr_placeholder.append(context.optimizer.param_groups[0]['lr'])
  18. class LRWarmupTest(unittest.TestCase):
  19. def setUp(self) -> None:
  20. self.dataset_params = {"batch_size": 4}
  21. self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
  22. self.arch_params = {'num_classes': 10}
  23. def test_lr_warmup(self):
  24. # Define Model
  25. net = LeNet()
  26. model = SgModel("lr_warmup_test", model_checkpoints_location='local')
  27. model.connect_dataset_interface(self.dataset)
  28. model.build_model(net, arch_params=self.arch_params)
  29. lrs = []
  30. phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
  31. train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
  32. "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
  33. "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  34. "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
  35. "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
  36. "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
  37. expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
  38. model.train(train_params)
  39. self.assertListEqual(lrs, expected_lrs)
  40. def test_lr_warmup_with_lr_scheduling(self):
  41. # Define Model
  42. net = LeNet()
  43. model = SgModel("lr_warmup_test", model_checkpoints_location='local')
  44. model.connect_dataset_interface(self.dataset)
  45. model.build_model(net, arch_params=self.arch_params)
  46. lrs = []
  47. phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
  48. train_params = {"max_epochs": 5, "cosine_final_lr_ratio": 0.2, "lr_mode": "cosine",
  49. "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
  50. "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  51. "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
  52. "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
  53. "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
  54. expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
  55. model.train(train_params)
  56. # ALTHOUGH NOT SEEN IN HERE, THE 4TH EPOCH USES LR=1, SO THIS IS THE EXPECTED LIST AS WE COLLECT
  57. # THE LRS AFTER THE UPDATE
  58. self.assertListEqual(lrs, expected_lrs)
  59. if __name__ == '__main__':
  60. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file