@@ -3,9 +3,12 @@ import unittest
 import torch
 
 from super_gradients import Trainer
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories import ActivationsTypeFactory
 from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
+from torch import nn
 
 
 class FactoriesTest(unittest.TestCase):
@@ -36,6 +39,16 @@ class FactoriesTest(unittest.TestCase):
         self.assertIsInstance(trainer.valid_metrics.Top5, Top5)
         self.assertIsInstance(trainer.optimizer, torch.optim.ASGD)
 
+    def test_activations_factory(self):
+        class DummyModel(nn.Module):
+            @resolve_param("activation_in_head", ActivationsTypeFactory())
+            def __init__(self, activation_in_head):
+                super().__init__()
+                self.activation_in_head = activation_in_head()
+
+        model = DummyModel(activation_in_head="leaky_relu")
+        self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)
+
 
 if __name__ == '__main__':
     unittest.main()
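
For context, the new test pivots on `resolve_param` replacing the string `"leaky_relu"` with the `nn.LeakyReLU` class via `ActivationsTypeFactory`, so the decorated `__init__` receives a callable type rather than a string. Below is a minimal sketch of that mechanism; the `resolve_activation` decorator, the `_ACTIVATIONS` registry, and the `Head` module are hypothetical stand-ins for the library internals, which may differ:

```python
# Illustrative re-implementation of the resolve_param-style mechanism tested above.
# Only nn.LeakyReLU and the "leaky_relu" key mirror the diff itself; everything
# else is an assumption for demonstration.
import functools
import inspect

from torch import nn

# Hypothetical registry standing in for ActivationsTypeFactory's internal mapping.
_ACTIVATIONS = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "sigmoid": nn.Sigmoid}


def resolve_activation(param_name):
    """If the named argument is a string, replace it with the activation class it names."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Bind args to the wrapped function's signature so the parameter
            # can be looked up by name regardless of how it was passed.
            bound = inspect.signature(fn).bind(*args, **kwargs)
            value = bound.arguments.get(param_name)
            if isinstance(value, str):
                bound.arguments[param_name] = _ACTIVATIONS[value]
            return fn(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


class Head(nn.Module):
    @resolve_activation("activation")
    def __init__(self, activation):
        super().__init__()
        self.activation = activation()  # activation is now a class, not a string


# Mirrors the assertion in the new test: the string resolves to an nn.LeakyReLU instance.
assert isinstance(Head(activation="leaky_relu").activation, nn.LeakyReLU)
```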