cifar10_trainer_test.py

import unittest

import super_gradients
from super_gradients import Trainer
from super_gradients.training import models
from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface


class TestCifar10Trainer(unittest.TestCase):
    def test_train_cifar10(self):
        # Initialize the super_gradients environment before creating a Trainer.
        super_gradients.init_trainer()
        # Experiment name "test"; checkpoints are stored on the local filesystem.
        trainer = Trainer("test", model_checkpoints_location='local')
        # Attach the library-provided CIFAR-10 dataset to the trainer.
        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
        trainer.connect_dataset_interface(cifar_10_dataset_interface)
        # ResNet-18 variant for CIFAR inputs, with 10 output classes.
        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
        # A single epoch is enough for an end-to-end smoke test.
        trainer.train(model=model, training_params={"max_epochs": 1})


if __name__ == '__main__':
    unittest.main()
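The test above needs the CIFAR-10 data and a full training epoch, which can be slow in CI. As a hedged sketch of a faster complement, the same call sequence can be checked with `unittest.mock` without touching super_gradients at all. The helper `run_cifar10_smoke` and the test class `TestCifar10CallSequence` are hypothetical names introduced only for illustration; the helper mirrors the order of calls in `test_train_cifar10`, but the objects passed in are `MagicMock` stand-ins, so this checks wiring only, not real training.

```python
import unittest
from unittest import mock


def run_cifar10_smoke(trainer_cls, dataset_interface_cls, get_model):
    """Hypothetical helper mirroring the call order of test_train_cifar10.

    The arguments stand in for Trainer, LibraryDatasetInterface and
    models.get, so they can be replaced by mocks (assumption: the real
    objects are called exactly as in the original test).
    """
    trainer = trainer_cls("test", model_checkpoints_location='local')
    dataset_interface = dataset_interface_cls(name="cifar10")
    trainer.connect_dataset_interface(dataset_interface)
    model = get_model("resnet18_cifar", arch_params={"num_classes": 10})
    trainer.train(model=model, training_params={"max_epochs": 1})
    return trainer


class TestCifar10CallSequence(unittest.TestCase):
    def test_call_sequence(self):
        # Mock stand-ins; nothing from super_gradients is imported here.
        trainer_cls = mock.MagicMock(name="Trainer")
        dataset_interface_cls = mock.MagicMock(name="LibraryDatasetInterface")
        get_model = mock.MagicMock(name="models.get")

        trainer = run_cifar10_smoke(trainer_cls, dataset_interface_cls, get_model)

        # The returned trainer is the instance produced by trainer_cls(...).
        self.assertIs(trainer, trainer_cls.return_value)
        trainer_cls.assert_called_once_with("test", model_checkpoints_location='local')
        dataset_interface_cls.assert_called_once_with(name="cifar10")
        trainer.connect_dataset_interface.assert_called_once_with(
            dataset_interface_cls.return_value
        )
        get_model.assert_called_once_with("resnet18_cifar", arch_params={"num_classes": 10})
        trainer.train.assert_called_once_with(
            model=get_model.return_value, training_params={"max_epochs": 1}
        )
        # The dataset must be connected before training starts.
        call_names = [c[0] for c in trainer.method_calls]
        self.assertEqual(call_names, ["connect_dataset_interface", "train"])
```

Passing the three collaborators as arguments (dependency injection) keeps the sketch runnable without super_gradients installed and avoids patching module paths; run it with `python -m unittest` like any other test module.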