Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

registry_test.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. import unittest
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torchmetrics
  6. from torch.nn.modules.loss import _Loss
  7. from super_gradients.common.registry.registry import ARCHITECTURES
  8. from super_gradients.common.registry.registry import METRICS, LOSSES
  9. from super_gradients.common.registry import register_model, register_metric, register_loss
  10. class RegistryTest(unittest.TestCase):
  11. def setUp(self):
  12. @register_model("myconvnet")
  13. class MyConvNet(nn.Module):
  14. def __init__(self, num_classes):
  15. super().__init__()
  16. self.conv1 = nn.Conv2d(3, 6, 5)
  17. self.pool = nn.MaxPool2d(2, 2)
  18. self.conv2 = nn.Conv2d(6, 16, 5)
  19. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  20. self.fc2 = nn.Linear(120, 84)
  21. self.fc3 = nn.Linear(84, num_classes)
  22. def forward(self, x):
  23. x = self.pool(F.relu(self.conv1(x)))
  24. x = self.pool(F.relu(self.conv2(x)))
  25. x = torch.flatten(x, 1)
  26. x = F.relu(self.fc1(x))
  27. x = F.relu(self.fc2(x))
  28. x = self.fc3(x)
  29. return x
  30. @register_model()
  31. def myconvnet_for_cifar10():
  32. return MyConvNet(num_classes=10)
  33. @register_metric("custom_accuracy") # Will be registered as "custom_accuracy"
  34. class CustomAccuracy(torchmetrics.Accuracy):
  35. def update(self, preds: torch.Tensor, target: torch.Tensor):
  36. if target.shape == preds.shape:
  37. target = target.argmax(1) # Supports smooth labels
  38. super().update(preds=preds.argmax(1), target=target)
  39. @register_loss("custom_rsquared_loss")
  40. class CustomRSquaredLoss(_Loss):
  41. def forward(self, output, target):
  42. criterion_mse = nn.MSELoss()
  43. return 1 - criterion_mse(output, target).item() / torch.var(target).item()
  44. def tearDown(self):
  45. ARCHITECTURES.pop("myconvnet", None)
  46. ARCHITECTURES.pop("myconvnet_for_cifar10", None)
  47. METRICS.pop("custom_accuracy", None)
  48. LOSSES.pop("custom_rsquared_loss", None)
  49. def test_cls_is_registered(self):
  50. assert ARCHITECTURES["myconvnet"]
  51. assert METRICS["custom_accuracy"]
  52. assert LOSSES["custom_rsquared_loss"]
  53. def test_fn_is_registered(self):
  54. assert ARCHITECTURES["myconvnet_for_cifar10"]
  55. def test_is_instantiable(self):
  56. assert ARCHITECTURES["myconvnet_for_cifar10"]()
  57. assert ARCHITECTURES["myconvnet"](num_classes=10)
  58. assert METRICS["custom_accuracy"]()
  59. assert LOSSES["custom_rsquared_loss"]()
  60. def test_model_outputs(self):
  61. torch.manual_seed(0)
  62. model_1 = ARCHITECTURES["myconvnet_for_cifar10"]()
  63. torch.manual_seed(0)
  64. model_2 = ARCHITECTURES["myconvnet"](num_classes=10)
  65. dummy_input = torch.randn(1, 3, 32, 32, requires_grad=False)
  66. x = model_1(dummy_input)
  67. y = model_2(dummy_input)
  68. assert torch.equal(x, y)
  69. def test_existing_key(self):
  70. with self.assertRaises(Exception):
  71. @register_model()
  72. def myconvnet_for_cifar10():
  73. return
  74. if __name__ == "__main__":
  75. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...