|
@@ -6,6 +6,8 @@ from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.metrics.all_metrics import METRICS
|
|
from super_gradients.training.metrics.all_metrics import METRICS
|
|
from super_gradients.training.losses.all_losses import LOSSES
|
|
from super_gradients.training.losses.all_losses import LOSSES
|
|
from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
|
|
from super_gradients.modules.detection_modules import ALL_DETECTION_MODULES
|
|
|
|
+from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
|
|
|
|
+from super_gradients.training.transforms.all_transforms import TRANSFORMS
|
|
|
|
|
|
|
|
|
|
def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
|
|
def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
|
|
@@ -45,3 +47,5 @@ register_detection_module = create_register_decorator(registry=ALL_DETECTION_MOD
|
|
register_metric = create_register_decorator(registry=METRICS)
|
|
register_metric = create_register_decorator(registry=METRICS)
|
|
register_loss = create_register_decorator(registry=LOSSES)
|
|
register_loss = create_register_decorator(registry=LOSSES)
|
|
register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
|
|
register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
|
|
|
|
+register_callback = create_register_decorator(registry=CALLBACKS)
|
|
|
|
+register_transform = create_register_decorator(registry=TRANSFORMS)
|