Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

max_batches_loop_break_test.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  1. import unittest
  2. from super_gradients.training import Trainer
  3. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  4. from super_gradients.training.metrics import Accuracy, Top5
  5. from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
  6. from super_gradients.training.models import LeNet
  7. class LastBatchIdxCollector(PhaseCallback):
  8. def __init__(self, train: bool = True):
  9. phase = Phase.TRAIN_BATCH_END if train else Phase.VALIDATION_BATCH_END
  10. super().__init__(phase=phase)
  11. self.last_batch_idx = 0
  12. def __call__(self, context: PhaseContext):
  13. self.last_batch_idx = context.batch_idx
  14. class MaxBatchesLoopBreakTest(unittest.TestCase):
  15. def test_max_train_batches_loop_break(self):
  16. last_batch_collector = LastBatchIdxCollector()
  17. train_params = {
  18. "max_epochs": 2,
  19. "lr_updates": [1],
  20. "lr_decay_factor": 0.1,
  21. "lr_mode": "StepLRScheduler",
  22. "lr_warmup_epochs": 0,
  23. "initial_lr": 0.1,
  24. "loss": "CrossEntropyLoss",
  25. "optimizer": "SGD",
  26. "criterion_params": {},
  27. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  28. "train_metrics_list": [Accuracy(), Top5()],
  29. "valid_metrics_list": [Accuracy(), Top5()],
  30. "metric_to_watch": "Accuracy",
  31. "greater_metric_to_watch_is_better": True,
  32. "phase_callbacks": [last_batch_collector],
  33. "max_train_batches": 3,
  34. }
  35. # Define Model
  36. net = LeNet()
  37. trainer = Trainer("test_max_batches_break_train")
  38. trainer.train(
  39. model=net,
  40. training_params=train_params,
  41. train_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
  42. valid_loader=classification_test_dataloader(),
  43. )
  44. # ASSERT LAST BATCH IDX IS 2
  45. print(last_batch_collector.last_batch_idx)
  46. self.assertTrue(last_batch_collector.last_batch_idx == 2)
  47. def test_max_valid_batches_loop_break(self):
  48. last_batch_collector = LastBatchIdxCollector(train=False)
  49. train_params = {
  50. "max_epochs": 2,
  51. "lr_updates": [1],
  52. "lr_decay_factor": 0.1,
  53. "lr_mode": "StepLRScheduler",
  54. "lr_warmup_epochs": 0,
  55. "initial_lr": 0.1,
  56. "loss": "CrossEntropyLoss",
  57. "optimizer": "SGD",
  58. "criterion_params": {},
  59. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  60. "train_metrics_list": [Accuracy(), Top5()],
  61. "valid_metrics_list": [Accuracy(), Top5()],
  62. "metric_to_watch": "Accuracy",
  63. "greater_metric_to_watch_is_better": True,
  64. "phase_callbacks": [last_batch_collector],
  65. "max_valid_batches": 3,
  66. }
  67. # Define Model
  68. net = LeNet()
  69. trainer = Trainer("test_max_batches_break_val")
  70. trainer.train(
  71. model=net,
  72. training_params=train_params,
  73. train_loader=classification_test_dataloader(),
  74. valid_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
  75. )
  76. # ASSERT LAST BATCH IDX IS 2
  77. self.assertTrue(last_batch_collector.last_batch_idx == 2)
# Allow running this test module directly (e.g. `python max_batches_loop_break_test.py`).
if __name__ == "__main__":
    unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...