Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

pretrained_models_unit_test.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import shutil
import tempfile
import unittest
from unittest import mock

import numpy as np
import torch

import super_gradients
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params
  15. class PretrainedModelsUnitTest(unittest.TestCase):
  16. def setUp(self) -> None:
  17. super_gradients.init_trainer()
  18. self.imagenet_pretrained_models = [Models.RESNET50, "repvgg_a0", "regnetY800"]
  19. def test_pretrained_resnet50_imagenet(self):
  20. trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
  21. model = models.get(Models.RESNET50, pretrained_weights="imagenet")
  22. trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
  23. def test_pretrained_regnetY800_imagenet(self):
  24. trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
  25. model = models.get(Models.REGNETY800, pretrained_weights="imagenet")
  26. trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
  27. def test_pretrained_repvgg_a0_imagenet(self):
  28. trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
  29. model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
  30. trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
  31. def test_pretrained_models_load_preprocessing_params(self):
  32. """
  33. Test that checks whether preprocessing params from pretrained model load correctly.
  34. """
  35. state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
  36. with tempfile.TemporaryDirectory() as td:
  37. checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
  38. torch.save(state, checkpoint_path)
  39. MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
  40. PRETRAINED_NUM_CLASSES["test"] = 80
  41. model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
  42. # .predict() would fail it model has no preprocessing params
  43. self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))
  44. def tearDown(self) -> None:
  45. if os.path.exists("~/.cache/torch/hub/"):
  46. shutil.rmtree("~/.cache/torch/hub/")
  47. if __name__ == "__main__":
  48. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...