Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_models_factory.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
  1. import unittest
  2. import torch
  3. from super_gradients.training import models
  4. # This is a subset of all the models, since some cannot be instantiated with models.get() without explicit arch_params
  5. MODELS = [
  6. "vit_base",
  7. "vit_large",
  8. "vit_huge",
  9. "beit_base_patch16_224",
  10. "beit_large_patch16_224",
  11. "custom_densenet",
  12. "densenet121",
  13. "densenet161",
  14. "densenet169",
  15. "densenet201",
  16. "efficientnet_b0",
  17. "efficientnet_b1",
  18. "efficientnet_b2",
  19. "efficientnet_b3",
  20. "efficientnet_b4",
  21. "efficientnet_b5",
  22. "efficientnet_b6",
  23. "efficientnet_b7",
  24. "efficientnet_b8",
  25. "efficientnet_l2",
  26. "mobilenet_v2",
  27. "mobile_net_v2_135",
  28. "mobilenet_v3_large",
  29. "mobilenet_v3_small",
  30. "resnet18",
  31. "resnet18_cifar",
  32. "resnet34",
  33. "resnet50",
  34. "resnet50_3343",
  35. "resnet101",
  36. "resnet152",
  37. "resnext50",
  38. "resnext101",
  39. "shufflenet_v2_x0_5",
  40. "shufflenet_v2_x1_0",
  41. "shufflenet_v2_x1_5",
  42. "shufflenet_v2_x2_0",
  43. "csp_darknet53",
  44. "ppyoloe_s",
  45. "ppyoloe_m",
  46. "ppyoloe_l",
  47. "ppyoloe_x",
  48. "darknet53",
  49. "ssd_mobilenet_v1",
  50. "ssd_lite_mobilenet_v2",
  51. "regnetY200",
  52. "regnetY400",
  53. "regnetY600",
  54. "regnetY800",
  55. "yolox_n",
  56. "yolox_t",
  57. "yolox_s",
  58. "yolox_m",
  59. "yolox_l",
  60. "yolox_x",
  61. "yolo_nas_s",
  62. "yolo_nas_m",
  63. "yolo_nas_l",
  64. "shelfnet18_lw",
  65. "shelfnet34_lw",
  66. # "shelfnet50_3343", # FIXME: seems to not work correctly
  67. # "shelfnet50", # FIXME: seems to not work correctly
  68. # "shelfnet101", # FIXME: seems to not work correctly
  69. "stdc1_classification",
  70. "stdc2_classification",
  71. "stdc1_seg75",
  72. "stdc1_seg50",
  73. "stdc1_seg",
  74. "stdc2_seg75",
  75. "stdc2_seg50",
  76. "stdc2_seg",
  77. "ddrnet_39",
  78. "ddrnet_23",
  79. "ddrnet_23_slim",
  80. "pp_lite_b_seg75",
  81. "pp_lite_b_seg50",
  82. "pp_lite_b_seg",
  83. "pp_lite_t_seg75",
  84. "pp_lite_t_seg50",
  85. "pp_lite_t_seg",
  86. "regseg48",
  87. "segformer_b0",
  88. "segformer_b1",
  89. "segformer_b2",
  90. "segformer_b3",
  91. "segformer_b4",
  92. "segformer_b5",
  93. "dekr_w32_no_dc",
  94. "yolo_nas_pose_n",
  95. "yolo_nas_pose_s",
  96. "yolo_nas_pose_m",
  97. "yolo_nas_pose_l",
  98. ]
  99. def can_model_forward(model, input_channels: int) -> bool:
  100. """Checks if the given model can perform a forward pass on inputs of certain sizes."""
  101. input_sizes = [(224, 224), (512, 512)] # We check different sizes because some model only support one or the other
  102. for h, w in input_sizes:
  103. try:
  104. model(torch.rand(2, input_channels, h, w))
  105. return True
  106. except Exception:
  107. continue
  108. return False
  109. class DynamicModelTests(unittest.TestCase):
  110. def test_models(self):
  111. # TODO: replace `MODELS` with `ARCHITECTURES.keys()` once all models can be instantiated with
  112. # TODO models.get() without explicit arch_params without any explicit arch_params
  113. for model_name in MODELS:
  114. with self.subTest(model_name=model_name):
  115. model = models.get(model_name, num_classes=20, num_input_channels=4)
  116. self.assertEqual(model.get_input_channels(), 4)
  117. self.assertTrue(can_model_forward(model=model, input_channels=4))
  118. model.replace_input_channels(51)
  119. self.assertEqual(model.get_input_channels(), 51)
  120. self.assertTrue(can_model_forward(model=model, input_channels=51))
  121. if __name__ == "__main__":
  122. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...