Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#587 Feature/sg 521 gpu tests

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-521_gpu_tests
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. import unittest
  2. import torch
  3. from super_gradients import Trainer
  4. from super_gradients.training import models
  5. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  6. from super_gradients.training.metrics import Accuracy
  7. class CallTrainTwiceTest(unittest.TestCase):
  8. """
  9. CallTrainTwiceTest
  10. Purpose is to call train twice and see nothing crashes. Should be ran with available GPUs (when possible)
  11. so when calling train again we see there's no change in the model's device.
  12. """
  13. def test_call_train_twice(self):
  14. trainer = Trainer("external_criterion_test")
  15. dataloader = classification_test_dataloader(batch_size=10)
  16. model = models.get("resnet18", num_classes=5)
  17. train_params = {
  18. "max_epochs": 2,
  19. "lr_updates": [1],
  20. "lr_decay_factor": 0.1,
  21. "lr_mode": "step",
  22. "lr_warmup_epochs": 0,
  23. "initial_lr": 0.1,
  24. "loss": torch.nn.CrossEntropyLoss(),
  25. "optimizer": "SGD",
  26. "criterion_params": {},
  27. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  28. "train_metrics_list": [Accuracy()],
  29. "valid_metrics_list": [Accuracy()],
  30. "metric_to_watch": "Accuracy",
  31. "greater_metric_to_watch_is_better": True,
  32. }
  33. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
  34. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
  35. if __name__ == "__main__":
  36. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file