Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

load_checkpoint_test.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. import unittest
  2. import torch.nn.init
  3. from torch import nn
  4. from super_gradients.common.object_names import Models
  5. from super_gradients.training import models
  6. from super_gradients.training.utils.checkpoint_utils import transfer_weights
  7. class LoadCheckpointTest(unittest.TestCase):
  8. def test_transfer_weights(self):
  9. class Foo(nn.Module):
  10. def __init__(self):
  11. super().__init__()
  12. self.fc1 = nn.Linear(10, 10)
  13. self.fc2 = nn.Linear(10, 10)
  14. torch.nn.init.zeros_(self.fc1.weight)
  15. torch.nn.init.zeros_(self.fc2.weight)
  16. class Bar(nn.Module):
  17. def __init__(self):
  18. super().__init__()
  19. self.fc1 = nn.Linear(10, 11)
  20. self.fc2 = nn.Linear(10, 10)
  21. torch.nn.init.ones_(self.fc1.weight)
  22. torch.nn.init.ones_(self.fc2.weight)
  23. foo = Foo()
  24. bar = Bar()
  25. self.assertFalse((foo.fc2.weight == bar.fc2.weight).all())
  26. transfer_weights(foo, bar.state_dict())
  27. self.assertTrue((foo.fc2.weight == bar.fc2.weight).all())
  28. def test_checkpoint_path_url(self):
  29. m1 = models.get(Models.YOLO_NAS_S, num_classes=80, checkpoint_path="https://sghub.deci.ai/models/yolo_nas_s_coco.pth")
  30. m2 = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
  31. m1_state = m1.state_dict()
  32. m2_state = m2.state_dict()
  33. self.assertTrue(m1_state.keys() == m2_state.keys())
  34. for k in m1_state.keys():
  35. self.assertTrue((m1_state[k] == m2_state[k]).all())
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...