Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

max_batches_loop_break_test.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  1. import unittest
  2. from super_gradients.training import Trainer
  3. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  4. from super_gradients.training.metrics import Accuracy, Top5
  5. from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
  6. from super_gradients.training.models import LeNet
  7. class LastBatchIdxCollector(PhaseCallback):
  8. def __init__(self, train: bool = True):
  9. phase = Phase.TRAIN_BATCH_END if train else Phase.VALIDATION_BATCH_END
  10. super().__init__(phase=phase)
  11. self.last_batch_idx = 0
  12. def __call__(self, context: PhaseContext):
  13. self.last_batch_idx = context.batch_idx
  14. class MaxBatchesLoopBreakTest(unittest.TestCase):
  15. def test_max_train_batches_loop_break(self):
  16. last_batch_collector = LastBatchIdxCollector()
  17. train_params = {
  18. "max_epochs": 2,
  19. "lr_updates": [1],
  20. "lr_decay_factor": 0.1,
  21. "lr_mode": "StepLRScheduler",
  22. "lr_warmup_epochs": 0,
  23. "initial_lr": 0.1,
  24. "loss": "CrossEntropyLoss",
  25. "optimizer": "SGD",
  26. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  27. "train_metrics_list": [Accuracy(), Top5()],
  28. "valid_metrics_list": [Accuracy(), Top5()],
  29. "metric_to_watch": "Accuracy",
  30. "greater_metric_to_watch_is_better": True,
  31. "phase_callbacks": [last_batch_collector],
  32. "max_train_batches": 3,
  33. }
  34. # Define Model
  35. net = LeNet()
  36. trainer = Trainer("test_max_batches_break_train")
  37. trainer.train(
  38. model=net,
  39. training_params=train_params,
  40. train_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
  41. valid_loader=classification_test_dataloader(),
  42. )
  43. # ASSERT LAST BATCH IDX IS 2
  44. print(last_batch_collector.last_batch_idx)
  45. self.assertTrue(last_batch_collector.last_batch_idx == 2)
  46. def test_max_valid_batches_loop_break(self):
  47. last_batch_collector = LastBatchIdxCollector(train=False)
  48. train_params = {
  49. "max_epochs": 2,
  50. "lr_updates": [1],
  51. "lr_decay_factor": 0.1,
  52. "lr_mode": "StepLRScheduler",
  53. "lr_warmup_epochs": 0,
  54. "initial_lr": 0.1,
  55. "loss": "CrossEntropyLoss",
  56. "optimizer": "SGD",
  57. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  58. "train_metrics_list": [Accuracy(), Top5()],
  59. "valid_metrics_list": [Accuracy(), Top5()],
  60. "metric_to_watch": "Accuracy",
  61. "greater_metric_to_watch_is_better": True,
  62. "phase_callbacks": [last_batch_collector],
  63. "max_valid_batches": 3,
  64. }
  65. # Define Model
  66. net = LeNet()
  67. trainer = Trainer("test_max_batches_break_val")
  68. trainer.train(
  69. model=net,
  70. training_params=train_params,
  71. train_loader=classification_test_dataloader(),
  72. valid_loader=classification_test_dataloader(dataset_size=16, batch_size=4),
  73. )
  74. # ASSERT LAST BATCH IDX IS 2
  75. self.assertTrue(last_batch_collector.last_batch_idx == 2)
# Allow running this test module directly (outside a test-runner invocation).
if __name__ == "__main__":
    unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...